添加FASTNLP_NO_SYNC相关的设置

This commit is contained in:
x54-729 2022-04-16 05:50:53 +00:00
parent a25a73394b
commit fcd27cfc3f
2 changed files with 29 additions and 2 deletions

View File

@ -1,4 +1,5 @@
import io import io
import os
import pickle import pickle
_pickler = pickle.Pickler _pickler = pickle.Pickler
_unpickler = pickle.Unpickler _unpickler = pickle.Unpickler
@ -7,6 +8,7 @@ from typing import Any, List
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.envs.env import FASTNLP_NO_SYNC
if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
import torch import torch
from torch import distributed as dist from torch import distributed as dist
@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
>>> output >>> output
['foo', 12, {1: 2}] ['foo', 12, {1: 2}]
""" """
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]
if dist.get_rank() == dst:
object_gather_list = [None for _ in range(dist.get_world_size(group))]
else:
object_gather_list = None
if group is None: if group is None:
group = DEFAULT_TORCH_GROUP group = DEFAULT_TORCH_GROUP
@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
:param group: :param group:
:return: 返回的结果是 [obj0, obj1, ...]其中 obj_i 即为第 i rank 上的 obj :return: 返回的结果是 [obj0, obj1, ...]其中 obj_i 即为第 i rank 上的 obj
""" """
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]
if group is None: if group is None:
group = DEFAULT_TORCH_GROUP group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
:param group: :param group:
:return: :return:
""" """
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
if src == dist.get_rank(group):
return obj
else:
return None
if group is None: if group is None:
group = DEFAULT_TORCH_GROUP group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group) cur_rank = dist.get_rank(group)
@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None):
>>> output >>> output
['foo', 12, {1: 2}] ['foo', 12, {1: 2}]
""" """
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]
if dist.distributed_c10d._rank_not_in_group(group): if dist.distributed_c10d._rank_not_in_group(group):
return return
if _TORCH_GREATER_EQUAL_1_8: if _TORCH_GREATER_EQUAL_1_8:

View File

@ -29,7 +29,7 @@ from fastNLP.core.samplers import (
re_instantiate_sampler, re_instantiate_sampler,
conversion_between_reproducible_and_unrepeated_sampler, conversion_between_reproducible_and_unrepeated_sampler,
) )
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
from fastNLP.core.log import logger from fastNLP.core.log import logger
if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
@ -234,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver):
self.global_rank = paddledist.get_rank() self.global_rank = paddledist.get_rank()
def barrier(self): def barrier(self):
paddledist.barrier() if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行
paddledist.barrier()
def configure_fleet(self): def configure_fleet(self):
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
@ -451,6 +452,8 @@ class PaddleFleetDriver(PaddleDriver):
接收到的参数如果是 source 端则返回发射的内容既不是发送端又不是接收端则返回 None 接收到的参数如果是 source 端则返回发射的内容既不是发送端又不是接收端则返回 None
""" """
return return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC == 2 直接返回。
return
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
def all_gather(self, obj, group) -> List: def all_gather(self, obj, group) -> List:
@ -477,4 +480,6 @@ class PaddleFleetDriver(PaddleDriver):
:return: :return:
""" """
return return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # 如果 FASTNLP_NO_SYNC 表示不执行
return [obj]
return fastnlp_paddle_all_gather(obj, group=group) return fastnlp_paddle_all_gather(obj, group=group)