Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-01 11:48:09 +08:00
Add settings related to FASTNLP_NO_SYNC
This commit is contained in:
parent
a25a73394b
commit
fcd27cfc3f
@@ -1,4 +1,5 @@
 import io
+import os
 import pickle
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -7,6 +8,7 @@ from typing import Any, List
 from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
 from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.env import FASTNLP_NO_SYNC
 if _NEED_IMPORT_TORCH:
     import torch
     from torch import distributed as dist
@@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
+    if dist.get_rank() == dst:
+        object_gather_list = [None for _ in range(dist.get_world_size(group))]
+    else:
+        object_gather_list = None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP

@@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
     :param group:
     :return: The result is [obj0, obj1, ...], where obj_i is the obj from rank i.
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     if isinstance(obj, torch.Tensor):
@@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
     :param group:
     :return:
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        if src == dist.get_rank(group):
+            return obj
+        else:
+            return None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     cur_rank = dist.get_rank(group)
@@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None):
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if dist.distributed_c10d._rank_not_in_group(group):
         return
     if _TORCH_GREATER_EQUAL_1_8:
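Taken together, the hunks above gate each torch collective helper on the same environment check: when FASTNLP_NO_SYNC equals 2, the helper skips torch.distributed entirely and returns the local object (wrapped in a list for the gather/all_gather helpers, or obj/None for broadcast). Below is a minimal standalone sketch of that pattern; the helper name gather_object_or_skip and the gather_fn argument are made up for illustration, and FASTNLP_NO_SYNC is assumed to hold the plain string "FASTNLP_NO_SYNC" (in fastNLP it is imported from fastNLP.envs.env).

import os

# Assumed value of the constant; the diff imports it from fastNLP.envs.env.
FASTNLP_NO_SYNC = "FASTNLP_NO_SYNC"

def gather_object_or_skip(obj, gather_fn):
    # Mirrors the check added in the diff: when FASTNLP_NO_SYNC == 2,
    # behave as if this process were the only rank and return [obj] directly.
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]
    # Otherwise fall through to the real collective (e.g. dist.all_gather_object).
    return gather_fn(obj)

if __name__ == "__main__":
    os.environ[FASTNLP_NO_SYNC] = "2"
    # With sync disabled there is no process group to consult, and the local
    # object comes back as a one-element list, matching the helpers above.
    print(gather_object_or_skip({"loss": 0.5}, gather_fn=None))  # -> [{'loss': 0.5}]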
@@ -29,7 +29,7 @@ from fastNLP.core.samplers import (
     re_instantiate_sampler,
     conversion_between_reproducible_and_unrepeated_sampler,
 )
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
+from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -234,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver):
         self.global_rank = paddledist.get_rank()

     def barrier(self):
-        paddledist.barrier()
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1:  # only actually execute when FASTNLP_NO_SYNC is less than 1
+            paddledist.barrier()

     def configure_fleet(self):
         if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
@@ -451,6 +452,8 @@ class PaddleFleetDriver(PaddleDriver):
            the received object; on the source side, the content that was sent is returned; a rank that is neither sender nor receiver returns None.
        """
        return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # if FASTNLP_NO_SYNC == 2, return directly.
+            return
        return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)

    def all_gather(self, obj, group) -> List:
@@ -477,4 +480,6 @@ class PaddleFleetDriver(PaddleDriver):
        :return:
        """
        return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # if FASTNLP_NO_SYNC indicates not to sync, return directly
+            return [obj]
        return fastnlp_paddle_all_gather(obj, group=group)
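The PaddleFleetDriver changes apply the same switch at two thresholds: barrier() only issues paddledist.barrier() when FASTNLP_NO_SYNC is below 1, while broadcast_object and all_gather short-circuit only when it equals 2, so a value of 1 skips barriers but keeps the object collectives. A hedged sketch of that split follows, using a stand-in class; NoSyncAwareDriver, barrier_fn, and all_gather_fn are hypothetical names, and only the two environment checks mirror the diff.

import os
from typing import Any, Callable, List

FASTNLP_NO_SYNC = "FASTNLP_NO_SYNC"  # assumed value of the constant from fastNLP.envs.env

class NoSyncAwareDriver:
    # Illustrative stand-in for the driver changes above, not the real PaddleFleetDriver.
    def __init__(self, barrier_fn: Callable[[], None], all_gather_fn: Callable[[Any], List]):
        self._barrier_fn = barrier_fn        # e.g. paddledist.barrier in the real driver
        self._all_gather_fn = all_gather_fn  # e.g. fastnlp_paddle_all_gather in the real driver

    def barrier(self) -> None:
        # Only block on the collective barrier while syncing is not disabled (value < 1).
        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1:
            self._barrier_fn()

    def all_gather(self, obj: Any) -> List:
        # With FASTNLP_NO_SYNC == 2, pretend the world size is 1.
        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:
            return [obj]
        return self._all_gather_fn(obj)

if __name__ == "__main__":
    os.environ[FASTNLP_NO_SYNC] = "2"
    driver = NoSyncAwareDriver(barrier_fn=lambda: None, all_gather_fn=lambda o: [o])
    driver.barrier()                        # skipped: 2 is not < 1
    print(driver.all_gather("rank-local"))  # -> ['rank-local']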