mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 11:17:50 +08:00
设置init
This commit is contained in:
parent
b2e9507118
commit
5e24933521
@ -57,9 +57,37 @@ __all__ = [
|
||||
"TorchPaddleDriver",
|
||||
|
||||
# log
|
||||
"logger"
|
||||
"logger",
|
||||
"print",
|
||||
|
||||
#
|
||||
# metrics
|
||||
"Metric",
|
||||
"Accuracy",
|
||||
'SpanFPreRecMetric',
|
||||
'ClassifyFPreRecMetric',
|
||||
|
||||
# samplers
|
||||
'ReproducibleSampler',
|
||||
'RandomSampler',
|
||||
"SequentialSampler",
|
||||
"SortedSampler",
|
||||
'UnrepeatedSampler',
|
||||
'UnrepeatedRandomSampler',
|
||||
"UnrepeatedSortedSampler",
|
||||
"UnrepeatedSequentialSampler",
|
||||
"ReproduceBatchSampler",
|
||||
"BucketedBatchSampler",
|
||||
"ReproducibleBatchSampler",
|
||||
"RandomBatchSampler",
|
||||
|
||||
# utils
|
||||
"cache_results",
|
||||
"f_rich_progress",
|
||||
"auto_param_call",
|
||||
"seq_len_to_mask",
|
||||
|
||||
# vocabulary.py
|
||||
'Vocabulary'
|
||||
]
|
||||
from .callbacks import *
|
||||
from .collators import *
|
||||
@ -68,4 +96,7 @@ from .dataloaders import *
|
||||
from .dataset import *
|
||||
from .drivers import *
|
||||
from .log import *
|
||||
from .utils import *
|
||||
from .metrics import *
|
||||
from .samplers import *
|
||||
from .utils import *
|
||||
from .vocabulary import Vocabulary
|
@ -7,7 +7,7 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Tuple, Callable, Union
|
||||
|
||||
from fastNLP.core.utils import rank_zero_rm
|
||||
from ...envs.distributed import rank_zero_rm
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME
|
||||
from fastNLP.envs import rank_zero_call
|
||||
|
@ -8,6 +8,7 @@ __all__ = [
|
||||
"NullPadder",
|
||||
"RawNumberPadder",
|
||||
"RawSequencePadder",
|
||||
"RawTensorPadder",
|
||||
'TorchNumberPadder',
|
||||
'TorchSequencePadder',
|
||||
'TorchTensorPadder',
|
||||
|
@ -67,7 +67,7 @@ def _get_backend() -> str:
|
||||
# 方式 (2)
|
||||
for backend in CHECK_BACKEND:
|
||||
if backend in sys.modules:
|
||||
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.")
|
||||
logger.debug(f"sys.modules contains backend:{backend}.")
|
||||
return backend
|
||||
for key, module in sys.modules.items():
|
||||
catch_backend = _check_module(module)
|
||||
|
@ -9,6 +9,7 @@ __all__ = [
|
||||
|
||||
"RawNumberPadder",
|
||||
"RawSequencePadder",
|
||||
"RawTensorPadder",
|
||||
|
||||
'TorchNumberPadder',
|
||||
'TorchSequencePadder',
|
||||
|
@ -79,7 +79,7 @@ class NumpyTensorPadder(Padder):
|
||||
def pad(batch_field, pad_val, dtype):
|
||||
try:
|
||||
if not isinstance(batch_field[0], np.ndarray):
|
||||
batch_field = [np.array(field.tolist()) for field in batch_field]
|
||||
batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field]
|
||||
except AttributeError:
|
||||
raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), "
|
||||
f"it must have tolist() method.")
|
||||
|
@ -131,7 +131,7 @@ class PaddleTensorPadder(Padder):
|
||||
def pad(batch_field, pad_val, dtype):
|
||||
try:
|
||||
if not isinstance(batch_field[0], paddle.Tensor):
|
||||
batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field]
|
||||
batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field]
|
||||
except AttributeError:
|
||||
raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), "
|
||||
f"it must have tolist() method.")
|
||||
@ -143,8 +143,6 @@ class PaddleTensorPadder(Padder):
|
||||
tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
|
||||
for i, field in enumerate(batch_field):
|
||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
|
||||
if isinstance(field, np.ndarray):
|
||||
field = paddle.to_tensor(field)
|
||||
tensor[slices] = field
|
||||
return tensor
|
||||
|
||||
|
@ -114,7 +114,7 @@ class TorchTensorPadder(Padder):
|
||||
def pad(batch_field, pad_val, dtype):
|
||||
try:
|
||||
if not isinstance(batch_field[0], torch.Tensor):
|
||||
batch_field = [torch.tensor(field.tolist()) for field in batch_field]
|
||||
batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field]
|
||||
except AttributeError:
|
||||
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
|
||||
f"it must have tolist() method.")
|
||||
@ -124,8 +124,6 @@ class TorchTensorPadder(Padder):
|
||||
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
|
||||
for i, field in enumerate(batch_field):
|
||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
|
||||
if isinstance(field, np.ndarray):
|
||||
field = torch.from_numpy(field)
|
||||
tensor[slices] = field
|
||||
return tensor
|
||||
|
||||
|
@ -18,9 +18,9 @@ from fastNLP.core.utils import (
|
||||
auto_param_call,
|
||||
check_user_specific_params,
|
||||
paddle_move_data_to_device,
|
||||
is_in_paddle_dist,
|
||||
rank_zero_rm
|
||||
is_in_paddle_dist
|
||||
)
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.core.samplers import (
|
||||
ReproduceBatchSampler,
|
||||
ReproducibleSampler,
|
||||
|
@ -1,6 +1,8 @@
|
||||
__all__ = [
|
||||
'logger'
|
||||
'logger',
|
||||
"print"
|
||||
]
|
||||
|
||||
from .logger import logger
|
||||
from .print import print
|
||||
|
||||
|
@ -1,16 +1,11 @@
|
||||
__all__ = [
|
||||
"Metric",
|
||||
"Accuracy",
|
||||
'Backend',
|
||||
'AutoBackend',
|
||||
'PaddleBackend',
|
||||
'TorchBackend',
|
||||
'SpanFPreRecMetric',
|
||||
'ClassifyFPreRecMetric',
|
||||
]
|
||||
|
||||
from .metric import Metric
|
||||
from .accuracy import Accuracy
|
||||
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
|
||||
from .span_f1_pre_rec_metric import SpanFPreRecMetric
|
||||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
|
||||
|
@ -23,8 +23,6 @@ __all__ = [
|
||||
'Option',
|
||||
'deprecated',
|
||||
'seq_len_to_mask',
|
||||
'rank_zero_rm',
|
||||
'rank_zero_mkdir'
|
||||
]
|
||||
|
||||
from .cache_results import cache_results
|
||||
@ -36,7 +34,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
|
||||
from .torch_utils import torch_move_data_to_device
|
||||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
|
||||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
|
||||
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
|
||||
from ..dataloaders.utils import indice_collate_wrapper
|
||||
deprecated, seq_len_to_mask
|
||||
|
||||
|
||||
|
@ -22,8 +22,6 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK
|
||||
|
||||
|
||||
__all__ = [
|
||||
'get_fn_arg_names',
|
||||
@ -37,8 +35,6 @@ __all__ = [
|
||||
'Option',
|
||||
'deprecated',
|
||||
'seq_len_to_mask',
|
||||
'rank_zero_rm',
|
||||
'rank_zero_mkdir'
|
||||
]
|
||||
|
||||
|
||||
@ -609,54 +605,6 @@ def wait_filepath(path, exist=True):
|
||||
logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...")
|
||||
|
||||
|
||||
|
||||
def rank_zero_rm(path: Optional[Union[str, Path]]):
|
||||
"""
|
||||
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候
|
||||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
|
||||
该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。
|
||||
|
||||
:param path:
|
||||
:return:
|
||||
"""
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
if path is None:
|
||||
return
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
return
|
||||
_recursive_rm(path)
|
||||
|
||||
|
||||
def _recursive_rm(path: Path):
|
||||
if path.is_file() or path.is_symlink():
|
||||
if path.exists():
|
||||
try:
|
||||
path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
for sub_path in list(path.iterdir()):
|
||||
_recursive_rm(sub_path)
|
||||
path.rmdir()
|
||||
|
||||
|
||||
def rank_zero_mkdir(path: Optional[Union[str, Path]]):
|
||||
"""
|
||||
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数;
|
||||
该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。
|
||||
|
||||
"""
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
if path is None:
|
||||
return
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def get_class_that_defined_method(method):
|
||||
"""
|
||||
给定一个method,返回这个 method 的 class 的对象
|
||||
|
@ -3,12 +3,17 @@ r"""
|
||||
"""
|
||||
__all__ = [
|
||||
'dump_fastnlp_backend',
|
||||
'is_cur_env_distributed',
|
||||
'get_global_rank',
|
||||
'rank_zero_call',
|
||||
'all_rank_call_context',
|
||||
|
||||
# utils
|
||||
'get_gpu_count',
|
||||
'fastnlp_no_sync_context'
|
||||
|
||||
# distributed
|
||||
"rank_zero_rm",
|
||||
'rank_zero_call',
|
||||
'get_global_rank',
|
||||
'fastnlp_no_sync_context',
|
||||
'all_rank_call_context',
|
||||
'is_cur_env_distributed',
|
||||
]
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
from typing import Callable, Any, Optional
|
||||
from pathlib import Path
|
||||
from typing import Callable, Any, Optional, Union
|
||||
from contextlib import contextmanager
|
||||
|
||||
__all__ = [
|
||||
@ -8,7 +9,8 @@ __all__ = [
|
||||
'get_global_rank',
|
||||
'rank_zero_call',
|
||||
'all_rank_call_context',
|
||||
'fastnlp_no_sync_context'
|
||||
'fastnlp_no_sync_context',
|
||||
"rank_zero_rm"
|
||||
]
|
||||
|
||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC
|
||||
@ -96,3 +98,35 @@ def all_rank_call_context():
|
||||
os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank
|
||||
else:
|
||||
os.environ.pop(FASTNLP_GLOBAL_RANK)
|
||||
|
||||
|
||||
def rank_zero_rm(path: Optional[Union[str, Path]]):
|
||||
"""
|
||||
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候
|
||||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件;
|
||||
该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。
|
||||
|
||||
:param path:
|
||||
:return:
|
||||
"""
|
||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
|
||||
if path is None:
|
||||
return
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
return
|
||||
_recursive_rm(path)
|
||||
|
||||
|
||||
def _recursive_rm(path: Path):
|
||||
if path.is_file() or path.is_symlink():
|
||||
if path.exists():
|
||||
try:
|
||||
path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
for sub_path in list(path.iterdir()):
|
||||
_recursive_rm(sub_path)
|
||||
path.rmdir()
|
@ -22,7 +22,7 @@ FASTNLP_GLOBAL_RANK = "FASTNLP_GLOBAL_RANK"
|
||||
FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL"
|
||||
|
||||
|
||||
# todo 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp;
|
||||
# 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp;
|
||||
# FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。
|
||||
FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME"
|
||||
|
||||
@ -42,7 +42,7 @@ USER_CUDA_VISIBLE_DEVICES = 'USER_CUDA_VISIBLE_DEVICES'
|
||||
# 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1]
|
||||
FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK'
|
||||
|
||||
# todo 注释
|
||||
# 检测到当前脚本是通过类似 python -m torch.launch 启动的话设置这个变量为1
|
||||
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"
|
||||
|
||||
# fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行;
|
||||
|
@ -11,7 +11,7 @@ from fastNLP.core.controllers.trainer import Trainer
|
||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
|
||||
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
|
||||
from torchmetrics import Accuracy
|
||||
|
@ -20,7 +20,7 @@ from fastNLP.core.controllers.trainer import Trainer
|
||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
|
||||
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
|
||||
from torchmetrics import Accuracy
|
||||
|
@ -83,7 +83,7 @@ class TestCollator:
|
||||
assert raw_pad_batch == collator(dict_batch)
|
||||
collator = Collator(backend='raw')
|
||||
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
|
||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
|
||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}],
|
||||
[{'1'}, {'2'}]]
|
||||
findListDiff(raw_pad_lst, collator(list_batch))
|
||||
|
||||
@ -194,7 +194,7 @@ class TestCollator:
|
||||
collator.set_ignore('_0', '_3', '_1')
|
||||
collator.set_pad('_4', pad_val=None)
|
||||
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
|
||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
|
||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}],
|
||||
[{'1'}, {'2'}]]
|
||||
findListDiff(raw_pad_lst, collator(list_batch))
|
||||
|
||||
@ -210,7 +210,7 @@ class TestCollator:
|
||||
collator.set_pad('_2', backend='numpy')
|
||||
collator.set_pad('_4', backend='numpy', pad_val=100)
|
||||
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
|
||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
|
||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}],
|
||||
[{'1'}, {'2'}]]
|
||||
findListDiff(raw_pad_lst, collator(list_batch))
|
||||
|
||||
|
@ -13,7 +13,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
|
||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
|
||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
|
||||
from tests.helpers.utils import magic_argv_env_context, Capturing
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch.distributed as dist
|
||||
|
@ -12,7 +12,7 @@ from fastNLP.core.samplers import (
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
|
@ -7,7 +7,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
|
@ -12,7 +12,7 @@ from fastNLP.core.samplers import (
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
|
@ -7,7 +7,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
|
@ -7,7 +7,7 @@ import re
|
||||
import pytest
|
||||
|
||||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.core.log.logger import logger
|
||||
|
||||
from tests.helpers.utils import magic_argv_env_context, recover_logger
|
||||
|
@ -6,7 +6,7 @@ import sys
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
|
||||
|
||||
from fastNLP.core.utils.cache_results import cache_results
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
|
||||
|
||||
def get_subprocess_results(cmd):
|
||||
|
@ -3,7 +3,7 @@ import pytest
|
||||
|
||||
from fastNLP.envs.set_backend import dump_fastnlp_backend
|
||||
from tests.helpers.utils import Capturing
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
|
||||
|
||||
def test_dump_fastnlp_envs():
|
||||
|
@ -9,7 +9,7 @@ import numpy as np
|
||||
|
||||
from fastNLP.modules.mix_modules.mix_module import MixModule
|
||||
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle
|
||||
from fastNLP.core import rank_zero_rm
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
|
||||
|
||||
############################################################################
|
||||
|
Loading…
Reference in New Issue
Block a user