mirror of https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
Add FASTNLP_NO_SYNC-related settings
This commit is contained in:
parent a25a73394b
commit fcd27cfc3f
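What the diff below does, in short: the distributed helpers now consult the FASTNLP_NO_SYNC environment variable before performing any synchronization. A minimal sketch of the gating pattern, under my reading of the diff (the helper names no_sync_level, barrier, and all_gather here are illustrative, not fastNLP API):

    import os

    FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC'  # same env-var name the diff imports from fastNLP.envs.env

    def no_sync_level() -> int:
        # 0 (default): synchronize normally; >= 1: barriers are skipped; == 2: collective ops are skipped too
        return int(os.environ.get(FASTNLP_NO_SYNC, '0'))

    def barrier(real_barrier):
        # mirrors PaddleFleetDriver.barrier below: the real barrier runs only when the level is below 1
        if no_sync_level() < 1:
            real_barrier()

    def all_gather(obj, real_all_gather):
        # mirrors the gather helpers below: at level 2, behave like a world of size 1
        if no_sync_level() == 2:
            return [obj]
        return real_all_gather(obj)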
@@ -1,4 +1,5 @@
 import io
+import os
 import pickle
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -7,6 +8,7 @@ from typing import Any, List
 from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
 from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.env import FASTNLP_NO_SYNC
 if _NEED_IMPORT_TORCH:
     import torch
     from torch import distributed as dist
@@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
+    if dist.get_rank() == dst:
+        object_gather_list = [None for _ in range(dist.get_world_size(group))]
+    else:
+        object_gather_list = None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP

@@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
     :param group:
     :return: the result is [obj0, obj1, ...], where obj_i is the obj on rank i.
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     if isinstance(obj, torch.Tensor):
@@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
     :param group:
     :return:
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        if src == dist.get_rank(group):
+            return obj
+        else:
+            return None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     cur_rank = dist.get_rank(group)
@@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None):
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if dist.distributed_c10d._rank_not_in_group(group):
         return
     if _TORCH_GREATER_EQUAL_1_8:
@@ -29,7 +29,7 @@ from fastNLP.core.samplers import (
     re_instantiate_sampler,
     conversion_between_reproducible_and_unrepeated_sampler,
 )
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
+from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -234,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver):
         self.global_rank = paddledist.get_rank()

     def barrier(self):
-        paddledist.barrier()
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1:  # only actually synchronize when FASTNLP_NO_SYNC is less than 1
+            paddledist.barrier()

     def configure_fleet(self):
         if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
@@ -451,6 +452,8 @@ class PaddleFleetDriver(PaddleDriver):
            the received parameter; the source side gets back the content it sent; a rank that is neither sender nor receiver gets None.
        """
        return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # if FASTNLP_NO_SYNC == 2, return directly.
+            return
        return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)

     def all_gather(self, obj, group) -> List:
@@ -477,4 +480,6 @@ class PaddleFleetDriver(PaddleDriver):
        :return:
        """
        return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # if FASTNLP_NO_SYNC indicates no synchronization
+            return [obj]
        return fastnlp_paddle_all_gather(obj, group=group)
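For completeness, a hypothetical way to exercise the new switch from user code; per the diff, the variable is read at call time in each process, so it must be set before the synchronizing call:

    import os

    os.environ['FASTNLP_NO_SYNC'] = '1'  # skip barriers, keep collective ops
    # or
    os.environ['FASTNLP_NO_SYNC'] = '2'  # skip everything; per the torch helpers above, gathers
    #                                      return [local_obj] and broadcasts return obj only on src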