diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index f1421c38..439f5886 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -57,9 +57,37 @@ __all__ = [ "TorchPaddleDriver", # log - "logger" + "logger", + "print", - # + # metrics + "Metric", + "Accuracy", + 'SpanFPreRecMetric', + 'ClassifyFPreRecMetric', + + # samplers + 'ReproducibleSampler', + 'RandomSampler', + "SequentialSampler", + "SortedSampler", + 'UnrepeatedSampler', + 'UnrepeatedRandomSampler', + "UnrepeatedSortedSampler", + "UnrepeatedSequentialSampler", + "ReproduceBatchSampler", + "BucketedBatchSampler", + "ReproducibleBatchSampler", + "RandomBatchSampler", + + # utils + "cache_results", + "f_rich_progress", + "auto_param_call", + "seq_len_to_mask", + + # vocabulary.py + 'Vocabulary' ] from .callbacks import * from .collators import * @@ -68,4 +96,7 @@ from .dataloaders import * from .dataset import * from .drivers import * from .log import * -from .utils import * \ No newline at end of file +from .metrics import * +from .samplers import * +from .utils import * +from .vocabulary import Vocabulary \ No newline at end of file diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index 8c3f3811..25e66cb9 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -7,7 +7,7 @@ from copy import deepcopy from pathlib import Path from typing import Optional, Dict, Tuple, Callable, Union -from fastNLP.core.utils import rank_zero_rm +from ...envs.distributed import rank_zero_rm from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME from fastNLP.envs import rank_zero_call diff --git a/fastNLP/core/collators/__init__.py b/fastNLP/core/collators/__init__.py index 1e508689..3033c37e 100644 --- a/fastNLP/core/collators/__init__.py +++ b/fastNLP/core/collators/__init__.py @@ -8,6 +8,7 @@ __all__ = [ "NullPadder", "RawNumberPadder", "RawSequencePadder", + "RawTensorPadder", 'TorchNumberPadder', 'TorchSequencePadder', 'TorchTensorPadder', diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index 5c5abda4..9ea08d95 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -67,7 +67,7 @@ def _get_backend() -> str: # 方式 (2) for backend in CHECK_BACKEND: if backend in sys.modules: - logger.debug(f"sys.modules contains backend:{catch_backend[0]}.") + logger.debug(f"sys.modules contains backend:{backend}.") return backend for key, module in sys.modules.items(): catch_backend = _check_module(module) diff --git a/fastNLP/core/collators/padders/__init__.py b/fastNLP/core/collators/padders/__init__.py index 09a5ca8d..11ffc07b 100644 --- a/fastNLP/core/collators/padders/__init__.py +++ b/fastNLP/core/collators/padders/__init__.py @@ -9,6 +9,7 @@ __all__ = [ "RawNumberPadder", "RawSequencePadder", + "RawTensorPadder", 'TorchNumberPadder', 'TorchSequencePadder', diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py index 4d507f2e..1113c91a 100644 --- a/fastNLP/core/collators/padders/numpy_padder.py +++ b/fastNLP/core/collators/padders/numpy_padder.py @@ -79,7 +79,7 @@ class NumpyTensorPadder(Padder): def pad(batch_field, pad_val, dtype): try: if not isinstance(batch_field[0], np.ndarray): - batch_field = [np.array(field.tolist()) for field in batch_field] + batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] except AttributeError: raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 10d5a385..f7db6534 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -131,7 +131,7 @@ class PaddleTensorPadder(Padder): def pad(batch_field, pad_val, dtype): try: if not isinstance(batch_field[0], paddle.Tensor): - batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field] + batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] except AttributeError: raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") @@ -143,8 +143,6 @@ class PaddleTensorPadder(Padder): tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) - if isinstance(field, np.ndarray): - field = paddle.to_tensor(field) tensor[slices] = field return tensor diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index 18f414e8..f1940380 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -114,7 +114,7 @@ class TorchTensorPadder(Padder): def pad(batch_field, pad_val, dtype): try: if not isinstance(batch_field[0], torch.Tensor): - batch_field = [torch.tensor(field.tolist()) for field in batch_field] + batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] except AttributeError: raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") @@ -124,8 +124,6 @@ class TorchTensorPadder(Padder): tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) - if isinstance(field, np.ndarray): - field = torch.from_numpy(field) tensor[slices] = field return tensor diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 73342748..f3a739f0 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -18,9 +18,9 @@ from fastNLP.core.utils import ( auto_param_call, check_user_specific_params, paddle_move_data_to_device, - is_in_paddle_dist, - rank_zero_rm + is_in_paddle_dist ) +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.core.samplers import ( ReproduceBatchSampler, ReproducibleSampler, diff --git a/fastNLP/core/log/__init__.py b/fastNLP/core/log/__init__.py index 3cb6d4dc..d1d95f20 100644 --- a/fastNLP/core/log/__init__.py +++ b/fastNLP/core/log/__init__.py @@ -1,6 +1,8 @@ __all__ = [ - 'logger' + 'logger', + "print" ] from .logger import logger +from .print import print diff --git a/fastNLP/core/metrics/__init__.py b/fastNLP/core/metrics/__init__.py index 82bca331..f7d60606 100644 --- a/fastNLP/core/metrics/__init__.py +++ b/fastNLP/core/metrics/__init__.py @@ -1,16 +1,11 @@ __all__ = [ "Metric", "Accuracy", - 'Backend', - 'AutoBackend', - 'PaddleBackend', - 'TorchBackend', 'SpanFPreRecMetric', 'ClassifyFPreRecMetric', ] from .metric import Metric from .accuracy import Accuracy -from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend from .span_f1_pre_rec_metric import SpanFPreRecMetric from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index 9fb538a9..4af6a24a 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -23,8 +23,6 @@ __all__ = [ 'Option', 'deprecated', 'seq_len_to_mask', - 'rank_zero_rm', - 'rank_zero_mkdir' ] from .cache_results import cache_results @@ -36,7 +34,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device from .torch_utils import torch_move_data_to_device from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ - deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir -from ..dataloaders.utils import indice_collate_wrapper + deprecated, seq_len_to_mask diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 91b3c8f6..93f38e2a 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -22,8 +22,6 @@ import numpy as np from pathlib import Path from fastNLP.core.log import logger -from fastNLP.envs import FASTNLP_GLOBAL_RANK - __all__ = [ 'get_fn_arg_names', @@ -37,8 +35,6 @@ __all__ = [ 'Option', 'deprecated', 'seq_len_to_mask', - 'rank_zero_rm', - 'rank_zero_mkdir' ] @@ -609,54 +605,6 @@ def wait_filepath(path, exist=True): logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...") - -def rank_zero_rm(path: Optional[Union[str, Path]]): - """ - 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 - 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; - 该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 - - :param path: - :return: - """ - if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: - if path is None: - return - if isinstance(path, str): - path = Path(path) - if not path.exists(): - return - _recursive_rm(path) - - -def _recursive_rm(path: Path): - if path.is_file() or path.is_symlink(): - if path.exists(): - try: - path.unlink() - except Exception: - pass - return - for sub_path in list(path.iterdir()): - _recursive_rm(sub_path) - path.rmdir() - - -def rank_zero_mkdir(path: Optional[Union[str, Path]]): - """ - 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; - 该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 - - """ - if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: - if path is None: - return - if isinstance(path, str): - path = Path(path) - - path.mkdir(parents=True, exist_ok=True) - - def get_class_that_defined_method(method): """ 给定一个method,返回这个 method 的 class 的对象 diff --git a/fastNLP/envs/__init__.py b/fastNLP/envs/__init__.py index bc09c33b..6c5e857e 100644 --- a/fastNLP/envs/__init__.py +++ b/fastNLP/envs/__init__.py @@ -3,12 +3,17 @@ r""" """ __all__ = [ 'dump_fastnlp_backend', - 'is_cur_env_distributed', - 'get_global_rank', - 'rank_zero_call', - 'all_rank_call_context', + + # utils 'get_gpu_count', - 'fastnlp_no_sync_context' + + # distributed + "rank_zero_rm", + 'rank_zero_call', + 'get_global_rank', + 'fastnlp_no_sync_context', + 'all_rank_call_context', + 'is_cur_env_distributed', ] diff --git a/fastNLP/envs/distributed.py b/fastNLP/envs/distributed.py index 34515c2c..3d87c8b2 100644 --- a/fastNLP/envs/distributed.py +++ b/fastNLP/envs/distributed.py @@ -1,6 +1,7 @@ import os from functools import wraps -from typing import Callable, Any, Optional +from pathlib import Path +from typing import Callable, Any, Optional, Union from contextlib import contextmanager __all__ = [ @@ -8,7 +9,8 @@ __all__ = [ 'get_global_rank', 'rank_zero_call', 'all_rank_call_context', - 'fastnlp_no_sync_context' + 'fastnlp_no_sync_context', + "rank_zero_rm" ] from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC @@ -96,3 +98,35 @@ def all_rank_call_context(): os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank else: os.environ.pop(FASTNLP_GLOBAL_RANK) + + +def rank_zero_rm(path: Optional[Union[str, Path]]): + """ + 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 + 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; + 该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 + + :param path: + :return: + """ + if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: + if path is None: + return + if isinstance(path, str): + path = Path(path) + if not path.exists(): + return + _recursive_rm(path) + + +def _recursive_rm(path: Path): + if path.is_file() or path.is_symlink(): + if path.exists(): + try: + path.unlink() + except Exception: + pass + return + for sub_path in list(path.iterdir()): + _recursive_rm(sub_path) + path.rmdir() \ No newline at end of file diff --git a/fastNLP/envs/env.py b/fastNLP/envs/env.py index 74d833e0..9cc05a02 100644 --- a/fastNLP/envs/env.py +++ b/fastNLP/envs/env.py @@ -22,7 +22,7 @@ FASTNLP_GLOBAL_RANK = "FASTNLP_GLOBAL_RANK" FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL" -# todo 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; +# 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; # FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。 FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME" @@ -42,7 +42,7 @@ USER_CUDA_VISIBLE_DEVICES = 'USER_CUDA_VISIBLE_DEVICES' # 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1] FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' -# todo 注释 +# 检测到当前脚本是通过类似 python -m torch.launch 启动的话设置这个变量为1 FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 2de21825..60dcc862 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -11,7 +11,7 @@ from fastNLP.core.controllers.trainer import Trainer from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 08c6f8e2..9c32c20b 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -20,7 +20,7 @@ from fastNLP.core.controllers.trainer import Trainer from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index ba1e7e08..65101321 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -83,7 +83,7 @@ class TestCollator: assert raw_pad_batch == collator(dict_batch) collator = Collator(backend='raw') raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] findListDiff(raw_pad_lst, collator(list_batch)) @@ -194,7 +194,7 @@ class TestCollator: collator.set_ignore('_0', '_3', '_1') collator.set_pad('_4', pad_val=None) raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] findListDiff(raw_pad_lst, collator(list_batch)) @@ -210,7 +210,7 @@ class TestCollator: collator.set_pad('_2', backend='numpy') collator.set_pad('_4', backend='numpy', pad_val=100) raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] findListDiff(raw_pad_lst, collator(list_batch)) diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 102ab310..e3d90e9b 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -13,7 +13,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch.distributed as dist diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index a184bb11..3b3f15ec 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index b8ccd802..ba243106 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -7,7 +7,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH if _NEED_IMPORT_PADDLE: import paddle diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index d6f0ee77..0e3f99ad 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index ef60e2b6..9115ed19 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -7,7 +7,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch diff --git a/tests/core/log/test_logger_torch.py b/tests/core/log/test_logger_torch.py index 13a758e9..7d45782c 100644 --- a/tests/core/log/test_logger_torch.py +++ b/tests/core/log/test_logger_torch.py @@ -7,7 +7,7 @@ import re import pytest from fastNLP.envs.env import FASTNLP_LAUNCH_TIME -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.core.log.logger import logger from tests.helpers.utils import magic_argv_env_context, recover_logger diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py index 77c618bb..efef9f10 100644 --- a/tests/core/utils/test_cache_results.py +++ b/tests/core/utils/test_cache_results.py @@ -6,7 +6,7 @@ import sys sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) from fastNLP.core.utils.cache_results import cache_results -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm def get_subprocess_results(cmd): diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py index 170110ce..c45acd7b 100644 --- a/tests/envs/test_set_backend.py +++ b/tests/envs/test_set_backend.py @@ -3,7 +3,7 @@ import pytest from fastNLP.envs.set_backend import dump_fastnlp_backend from tests.helpers.utils import Capturing -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm def test_dump_fastnlp_envs(): diff --git a/tests/modules/mix_modules/_test_mix_module.py b/tests/modules/mix_modules/_test_mix_module.py index 700e0cfe..87206fd6 100644 --- a/tests/modules/mix_modules/_test_mix_module.py +++ b/tests/modules/mix_modules/_test_mix_module.py @@ -9,7 +9,7 @@ import numpy as np from fastNLP.modules.mix_modules.mix_module import MixModule from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm ############################################################################