Complete the fastnlp_paddle_all_gather and fastnlp_paddle_broadcast_object functions and their test cases
This commit is contained in:
parent 6c532829c5
commit 1c0e331bad
@@ -1,27 +1,25 @@
import io
import os
import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
import os
from typing import Any, List

from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import FASTNLP_NO_SYNC
if _NEED_IMPORT_TORCH:
    import torch
    from torch import distributed as dist
    if _TORCH_GREATER_EQUAL_1_8:
        try:
            from torch._C._distributed_c10d import ProcessGroupGloo
            from torch._C._distributed_c10d import _ProcessGroupWrapper
        except ImportError:
            pass


from fastNLP.core.utils import apply_to_collection
from fastNLP.core.utils import paddle_move_data_to_device

if _NEED_IMPORT_PADDLE:
    import paddle
    import paddle.distributed as dist
    from paddle.framework.io import (
        _is_state_dict,
        _build_saved_state_dict,
        _unpack_saved_dict,
        _pickle_save,
        _pack_loaded_dict,
        _ndarray_to_tensor,
        _parse_load_result,
    )

def _validate_output_list_for_rank(my_rank, dst, gather_list):
    if dst == my_rank:
@@ -35,48 +33,65 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
            "on non-destination ranks."
        )

def paddle_pickle_dump(obj, stream, protocol):
    """
    Reference to `paddle.save`
    """
    if _is_state_dict(obj):
        saved_obj = _build_saved_state_dict(obj)
        saved_obj = _unpack_saved_dict(saved_obj, protocol)
        pickle.dump(saved_obj, stream, protocol=protocol)
    else:
        _pickle_save(obj, stream, protocol)

def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
def paddle_pickle_load(stream):
    """
    Reference to `paddle.load`
    """
    load_result = pickle.load(stream)
    if isinstance(load_result, dict):
        load_result = _pack_loaded_dict(load_result)
        if "StructuredToParameterName@@" in load_result:

            for key in load_result["StructuredToParameterName@@"]:
                if isinstance(load_result[key], np.ndarray):
                    load_result[key] = _ndarray_to_tensor(
                        load_result[key], return_numpy=False)

            if "StructuredToParameterName@@" in load_result:
                del load_result["StructuredToParameterName@@"]
        else:
            load_result = _parse_load_result(load_result, return_numpy=False)

    else:
        load_result = _parse_load_result(load_result, return_numpy=False)

    return load_result

def _object_to_tensor(obj, device=None):
    f = io.BytesIO()
    paddle_pickle_dump(obj, f, protocol=2)
    byte_data = list(f.getvalue())
    byte_tensor = paddle.to_tensor(byte_data, dtype=paddle.int32)
    local_size = paddle.to_tensor([byte_tensor.numel()])
    if device is not None:
        byte_tensor = paddle_move_data_to_device(byte_tensor, device)
        local_size = paddle_move_data_to_device(local_size, device)
    return byte_tensor, local_size

def _tensor_to_object(tensor, tensor_size):
    buf = tensor.astype(paddle.uint8).detach().cpu().numpy().tobytes()[:tensor_size]
    return paddle_pickle_load(io.BytesIO(buf))

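These two helpers are the serialization layer used by every collective in this module: the object is pickled through paddle_pickle_dump into a byte string, wrapped in an int32 paddle tensor together with its length, and reconstructed on the other side by slicing off the first local_size bytes and running paddle_pickle_load. A single-process round-trip sketch, modelled on the new test file (no process group is needed for these helpers):

    import paddle
    from fastNLP.core.drivers.paddle_driver.dist_utils import _object_to_tensor, _tensor_to_object

    # illustrative payload; mirrors the dict used in the test
    obj = {"tensor": paddle.rand((3, 4)), "int": 2, "string": "test string"}
    byte_tensor, size = _object_to_tensor(obj, device="cpu")   # pickle bytes -> int32 tensor
    restored = _tensor_to_object(byte_tensor, size)            # first `size` bytes -> original object
    assert paddle.equal_all(restored["tensor"], obj["tensor"])
    assert restored["int"] == 2 and restored["string"] == "test string"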
def fastnlp_paddle_gather_object(obj, dst=0, group=None):
    """
    Gather objects from the other ranks onto the dst rank.

    Gathers picklable objects from the whole group in a single process.
    Similar to :func:`gather`, but Python objects can be passed in. Note that the
    object must be picklable in order to be gathered.

    Args:
        obj (Any): Input object. Must be picklable.
        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
            should be correctly sized as the size of the group for this
            collective and will contain the output. Must be ``None`` on non-dst
            ranks. (default is ``None``)
        dst (int, optional): Destination rank. (default is 0)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. On the ``dst`` rank, ``object_gather_list`` will contain the
        output of the collective.

    .. note:: Note that this API differs slightly from the gather collective
        since it does not provide an async_op handle and thus will be a blocking
        call.

    .. note:: Note that this API is not supported when using the NCCL backend.

    .. warning::
        :func:`gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.gather_object(
        >>> fastnlp_paddle_gather_object(
                gather_objects[dist.get_rank()],
                output if dist.get_rank() == 0 else None,
                dst=0
@@ -84,99 +99,58 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
        >>> # On rank 0
        >>> output
        ['foo', 12, {1: 2}]

    :param obj: the object to send; it must be picklable
    :param dst: the destination rank
    :param group: the process group on which to run this function
    :return: on dst, a list of length world_size whose entries are, in order, the obj from rank 0, rank 1, ...
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]

    if dist.get_rank() == dst:
        object_gather_list = [None for _ in range(dist.get_world_size(group))]
        object_gather_list = [None for _ in range(dist.get_world_size())]
    else:
        object_gather_list = None

    if group is None:
        group = DEFAULT_TORCH_GROUP
    # if group is None:
    #     TODO there is a bug in version 2.2
    #     group = dist.collective._get_global_group()

    if dist.distributed_c10d._rank_not_in_group(group):
    if group is not None and not group.is_member():
        return

    # Ensure object_gather_list is specified appropriately.
    my_rank = dist.get_rank()
    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
    # prevent the data from ending up on the sending gpu when it is unpickled.
    obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
    obj = paddle_move_data_to_device(obj, device="cpu")
    input_tensor, local_size = _object_to_tensor(obj)
    group_backend = dist.get_backend(group)
    current_device = torch.device("cpu")
    is_nccl_backend = group_backend == dist.Backend.NCCL
    if is_nccl_backend:
        current_device = torch.device('cuda', torch.cuda.current_device())
        input_tensor = input_tensor.to(current_device)
        local_size = local_size.to(current_device)
    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist.get_world_size(group=group)
    object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes. An all-gather is needed here despite this being a
    # gather, since each rank needs to broadcast a tensor of the same (maximal)
    # size.
    # currently paddle groups only support the nccl backend
    input_tensor = paddle_move_data_to_device(input_tensor, device=paddle.device.get_device())
    local_size = paddle_move_data_to_device(local_size, device=paddle.device.get_device())

    # gather all local_size values and find the maximum size
    object_size_list = []
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor.resize_(max_object_size)
    # Avoid populating output tensors if the result won't be gathered on this rank.
    if my_rank == dst:
        coalesced_output_tensor = torch.empty(
            max_object_size * group_size, dtype=torch.uint8, device=current_device
        )
        # Output tensors are nonoverlapping views of coalesced_output_tensor
        output_tensors = [
            coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
            for i in range(group_size)
        ]
    # All ranks call gather with equal-sized tensors.
    dist.gather(
        input_tensor,
        gather_list=output_tensors if my_rank == dst else None,
        dst=dst,
        group=group,
    )
    input_tensor.reshape_(max_object_size)
    # TODO no equivalent of torch.distributed.gather has been found in paddle so far
    output_tensors = []
    dist.all_gather(output_tensors, input_tensor, group)
    if my_rank != dst:
        return
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)  # type: ignore[call-overload]
        tensor = tensor.astype(paddle.uint8)
        tensor_size = object_size_list[i]
        object_gather_list[i] = _tensor_to_object(tensor, tensor_size)


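Since no counterpart of torch.distributed.gather was found in paddle (see the TODO above), the destination-only gather is emulated here: every rank takes part in an all_gather of the equally sized byte tensors, and ranks other than dst simply return before deserializing. A minimal hedged sketch of that pattern with plain tensors (the helper name is illustrative, not part of the library):

    import paddle.distributed as dist

    def gather_via_all_gather(tensor, dst=0, group=None):
        # every rank contributes its tensor; only dst keeps the gathered result
        outputs = []
        dist.all_gather(outputs, tensor, group=group)
        if dist.get_rank() != dst:
            return None
        return outputs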
def _object_to_tensor(obj, device=None):
    f = io.BytesIO()
    _pickler(f).dump(obj)
    byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
    # Otherwise, it will cause 100X slowdown.
    # See: https://github.com/pytorch/pytorch/issues/65696
    byte_tensor = torch.ByteTensor(byte_storage)
    local_size = torch.LongTensor([byte_tensor.numel()])
    if device is not None:
        byte_tensor = byte_tensor.to(device)
        local_size = local_size.to(device)
    return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size):
    buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
    return _unpickler(io.BytesIO(buf)).load()


def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
def send_recv_object(obj, src, cur_rank, device, group=None, use_calc_stream=True):
    # src rank send to all other ranks
    size = torch.LongTensor([0]).to(device)
    size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

    if cur_rank == src:
        world_size = dist.get_world_size(group=group)
        world_size = dist.get_world_size()
        tensor, size = _object_to_tensor(obj)
        tensor = tensor.to(device)
        size = size.to(device)
@@ -185,15 +159,15 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
        dist.broadcast(size, src, group=group)
        for subrank in range(world_size):
            if subrank != src:
                dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
                dist.send(tensor=tensor, dst=subrank, group=group, use_calc_stream=use_calc_stream)
    else:
        dist.broadcast(size, src, group=group)
        tensor = torch.ByteTensor([0] * size).to(device)
        dist.recv(tensor=tensor, src=src, group=group, tag=tag)
        tensor = paddle_move_data_to_device(paddle.to_tensor([0] * size), device)
        dist.recv(tensor=tensor, src=src, group=group, use_calc_stream=use_calc_stream)

    return _tensor_to_object(tensor.cpu(), size)

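send_recv_object pushes a single object from src to every other rank: the size is broadcast first, then the pickled payload goes out over point-to-point send/recv, and each rank (src included) returns the deserialized object. A hedged per-rank sketch, assuming an already initialized group and that the helper is imported from the same dist_utils module:

    import paddle
    import paddle.distributed as dist
    from fastNLP.core.drivers.paddle_driver.dist_utils import send_recv_object  # assumed import path

    cur_rank = dist.get_rank()
    payload = {"vocab_size": 30000} if cur_rank == 0 else None   # only the src rank needs the object
    received = send_recv_object(payload, src=0, cur_rank=cur_rank, device=paddle.device.get_device())
    # every rank now holds {"vocab_size": 30000}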
def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) -> List:
def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) -> List:
    """
    Allows data of any type to be all_gathered through this interface. Non-tensor data is transferred by serializing it with pickle and deserializing it on the receiving side.

@@ -220,178 +194,108 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]

    if group is None:
        group = DEFAULT_TORCH_GROUP
    if isinstance(obj, torch.Tensor):
        objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
    # if group is None:
    #     TODO there is a bug in version 2.2
    #     group = dist.collective._get_global_group()
    if isinstance(obj, paddle.Tensor):
        objs = []
        dist.all_gather(objs, obj, group=group)
    else:
        objs = [None for _ in range(dist.get_world_size(group))]
        objs = [None for _ in range(dist.get_world_size())]
        # prevent the data from ending up on the sending gpu when it is unpickled
        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
        if _TORCH_GREATER_EQUAL_1_8:
            dist.all_gather_object(objs, obj, group=group)
        else:
            objs = all_gather_object(objs, obj, group=group)
        obj = paddle_move_data_to_device(obj, "cpu")
        objs = all_gather_object(objs, obj, group=group)

    return objs


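fastnlp_paddle_all_gather is what the new test exercises most heavily: paddle tensors go through the native all_gather, everything else is pickled and routed through all_gather_object below. A per-rank usage sketch modelled on the test (it assumes an already initialized collective fleet group; the dict contents are illustrative):

    import paddle
    import paddle.distributed as dist
    from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather

    rank = dist.get_rank()
    obj = {'tensor': paddle.full(shape=(2,), fill_value=rank), 'int': rank, 'str': f'{rank}'}
    data = fastnlp_paddle_all_gather(obj)   # one entry per rank, ordered by rank
    assert data[rank]['int'] == rank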
def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
def fastnlp_paddle_broadcast_object(obj, src, device=None, group=None):
    """
    Broadcast the obj object on src to all other ranks.

    :param obj:
    :param src:
    :param obj: the object to send
    :param src: the rank to send from
    :param device:
    :param group:
    :param group: the communication group to use
    :return:
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        if src == dist.get_rank(group):
        if src == dist.get_rank():
            return obj
        else:
            return None

    if group is None:
        group = DEFAULT_TORCH_GROUP
    cur_rank = dist.get_rank(group)
    cur_rank = dist.get_rank()
    if cur_rank == src:
        # move any tensors to cpu first so they can be pickled; otherwise unpickling may put them back on the sender's device
        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
    if _TORCH_GREATER_EQUAL_1_8:
        if cur_rank != src:
            get_obj = [None]
            dist.broadcast_object_list(get_obj, src=src, group=group)
            return get_obj[0]
        else:
            dist.broadcast_object_list([obj], src=src, group=group)
            return obj
        obj = paddle_move_data_to_device(obj, "cpu")

    if device is None:
        device = torch.cuda.current_device()
        device = paddle.device.get_device()

    if cur_rank == src:
        tensor, size = _object_to_tensor(obj, device=device)
    else:
        size = torch.LongTensor([0]).to(device)
        size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

    dist.broadcast(size, src=src, group=group)
    if cur_rank != src:
        tensor = torch.empty(
            size.int().item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device=device
        tensor = paddle.empty(
            size.astype(paddle.int32),  # type: ignore[arg-type]
            dtype=paddle.int32,
        )
        dist.broadcast(tensor, src=src, group=group)

    return _tensor_to_object(tensor, tensor_size=size.item())


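fastnlp_paddle_broadcast_object mirrors the new test case test_fastnlp_paddle_broadcast_object: only the source rank needs to provide the object, and every rank gets back the deserialized copy. A per-rank sketch under the same assumptions (initialized fleet group; illustrative payload):

    import paddle
    import paddle.distributed as dist
    from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_broadcast_object

    src = 0
    obj = {'answer': 42, 'tag': 'ready'} if dist.get_rank() == src else None
    data = fastnlp_paddle_broadcast_object(obj, src=src, device=paddle.device.get_device())
    assert data['answer'] == 42   # holds on every rank after the broadcast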
def _check_for_nccl_backend(group):
    pg = group or dist.distributed_c10d._get_default_group()
    # It is not expected for PG to be wrapped many times, but support it just
    # in case
    while isinstance(pg, _ProcessGroupWrapper):
        pg = pg.wrapped_pg

    return (
        dist.is_nccl_available() and
        isinstance(pg, dist.ProcessGroupNCCL)
    )


def all_gather_object(object_list, obj, group=None):
    """
    Copied from the pytorch code so that it stays compatible with lower pytorch versions.

    Gathers picklable objects from the whole group into a list. Similar to
    :func:`all_gather`, but Python objects can be passed in. Note that the object
    must be picklable in order to be gathered.

    Args:
        object_list (list[Any]): Output list. It should be correctly sized as the
            size of the group for this collective and will contain the output.
        object (Any): Picklable Python object to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. If the calling rank is part of this group, the output of the
        collective will be populated into the input ``object_list``. If the
        calling rank is not part of the group, the passed in ``object_list`` will
        be unmodified.

    .. note:: Note that this API differs slightly from the :func:`all_gather`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
        >>> all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]

    :param object_list:
    :param obj:
    :param group:
    :return:
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]

    if dist.distributed_c10d._rank_not_in_group(group):
    if group is not None and not group.is_member():
        return
    if _TORCH_GREATER_EQUAL_1_8:
        current_device = torch.device("cpu")
        is_nccl_backend = _check_for_nccl_backend(group)
        if is_nccl_backend:
            # See note about using torch.cuda.current_device() here in docstring.
            # We cannot simply use my_rank since rank == device is not necessarily
            # true.
            current_device = torch.device("cuda", torch.cuda.current_device())
    else:
        current_device = torch.cuda.current_device()

    current_device = paddle.device.get_device()

    input_tensor, local_size = _object_to_tensor(obj, device=current_device)

    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist.get_world_size(group=group)
    object_sizes_tensor = torch.zeros(
        group_size, dtype=torch.long, device=current_device
    )
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # gather the tensor sizes and find the maximum
    object_size_list = []
    # Allgather tensor sizes
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * group_size, dtype=torch.uint8, device=current_device
    )
    # pad the tensor
    pad_dims = []
    pad_by = (max_object_size - local_size).detach().cpu()
    for val in reversed(pad_by):
        pad_dims.append(0)
        pad_dims.append(val.item())
    tensor_padded = paddle.nn.functional.pad(input_tensor, pad_dims)

    # Output tensors are nonoverlapping views of coalesced_output_tensor
    output_tensors = [
        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
        for i in range(group_size)
    ]
    dist.all_gather(output_tensors, input_tensor, group=group)
    output_tensors = []
    dist.all_gather(output_tensors, tensor_padded, group=group)
    dist.barrier()
    # Deserialize outputs back to object.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        if tensor.device != torch.device("cpu"):
        tensor = tensor.astype(paddle.uint8)
        if not tensor.place.is_cpu_place():
            tensor = tensor.cpu()
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)

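Because dist.all_gather requires every rank to contribute a tensor of identical shape, all_gather_object pads each serialized byte tensor up to the largest size in the group and then uses the separately gathered object_size_list to slice each entry back down before unpickling. A standalone sketch of just that pad-and-truncate step, mirroring the paddle.nn.functional.pad call above (the sizes are invented for illustration):

    import paddle
    import paddle.nn.functional as F

    # pretend this rank's pickled payload is 3 bytes and the longest payload in the group is 6
    payload = paddle.to_tensor([11, 22, 33], dtype=paddle.int32)
    max_object_size = 6
    padded = F.pad(payload, [0, max_object_size - payload.shape[0]])  # zero pad on the right -> length 6
    recovered = padded[:payload.shape[0]]                             # receiver: cut back to the true size
    assert list(recovered.numpy()) == [11, 22, 33]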
tests/core/drivers/paddle_driver/test_dist_utils.py (new file, 185 lines)
@@ -0,0 +1,185 @@
import os
import sys
import signal
import pytest
import traceback
os.environ["FASTNLP_BACKEND"] = "paddle"

import numpy as np

from fastNLP.core.drivers.paddle_driver.dist_utils import (
    _tensor_to_object,
    _object_to_tensor,
    fastnlp_paddle_all_gather,
    fastnlp_paddle_broadcast_object,
)
from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher
from tests.helpers.utils import magic_argv_env_context

import paddle
import paddle.distributed as dist

class TestDistUtilsTools:
    """
    Tests for a few of the utility functions
    """

    @pytest.mark.parametrize("device", (["cpu", 0]))
    def test_tensor_object_transfer_tensor(self, device):
        """
        Test whether the results of _tensor_to_object and _object_to_tensor can be converted back into each other
        """
        # tensor
        paddle_tensor = paddle.rand((3, 4, 5)).cpu()
        obj_tensor, size = _object_to_tensor(paddle_tensor, device=device)
        res = _tensor_to_object(obj_tensor, size)
        assert paddle.equal_all(res, paddle_tensor)

        # list
        paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
        obj_tensor, size = _object_to_tensor(paddle_list, device=device)
        res = _tensor_to_object(obj_tensor, size)
        assert isinstance(res, list)
        for before, after in zip(paddle_list, res):
            assert paddle.equal_all(after, before)

        # tuple
        paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
        paddle_tuple = tuple(paddle_list)
        obj_tensor, size = _object_to_tensor(paddle_tuple, device=device)
        res = _tensor_to_object(obj_tensor, size)
        assert isinstance(res, tuple)
        for before, after in zip(paddle_list, res):
            assert paddle.equal_all(after, before)

        # dict
        paddle_dict = {
            "tensor": paddle.rand((3, 4)),
            "list": [paddle.rand((6, 4, 2)) for i in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)) for i in range(10)],
                "tensor": paddle.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }
        obj_tensor, size = _object_to_tensor(paddle_dict, device=device)
        res = _tensor_to_object(obj_tensor, size)
        assert isinstance(res, dict)
        assert paddle.equal_all(res["tensor"], paddle_dict["tensor"])
        assert isinstance(res["list"], list)
        for before, after in zip(paddle_dict["list"], res["list"]):
            assert paddle.equal_all(after, before)

        assert isinstance(res["dict"], dict)
        assert paddle.equal_all(res["dict"]["tensor"], paddle_dict["dict"]["tensor"])
        for before, after in zip(paddle_dict["dict"]["list"], res["dict"]["list"]):
            assert paddle.equal_all(after, before)
        assert res["int"] == paddle_dict["int"]
        assert res["string"] == paddle_dict["string"]


class TestAllGatherAndBroadCast:

    @classmethod
    def setup_class(cls):
        devices = [0, 1, 2]
        output_from_new_proc = "only_error"

        launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc)
        cls.local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", "0"))
        if cls.local_rank == 0:
            launcher = FleetLauncher(devices, output_from_new_proc)
            launcher.launch()
        dist.fleet.init(is_collective=True)
        dist.barrier()

        # cls._pids = []
        # dist.all_gather(cls._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
        # local_world_size = paddle.to_tensor(cls.local_rank, dtype="int32")
        # dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
        # local_world_size = local_world_size.item() + 1

    def on_exception(self):
        if self._pids is not None:

            exc_type, exc_value, exc_traceback_obj = sys.exc_info()
            traceback.print_tb(exc_traceback_obj, file=sys.stderr)
            sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n")
            for pid in self._pids:
                pid = pid.item()
                if pid != os.getpid():
                    os.kill(pid, signal.SIGKILL)

    @magic_argv_env_context
    def test_fastnlp_paddle_all_gather(self):
        obj = {
            'tensor': paddle.full(shape=(2, ), fill_value=self.local_rank).cuda(),
            'numpy': np.full(shape=(2, ), fill_value=self.local_rank),
            'bool': self.local_rank % 2 == 0,
            'float': self.local_rank + 0.1,
            'int': self.local_rank,
            'dict': {
                'rank': self.local_rank
            },
            'list': [self.local_rank] * 2,
            'str': f'{self.local_rank}',
            'tensors': [paddle.full(shape=(2,), fill_value=self.local_rank).cuda(),
                        paddle.full(shape=(2,), fill_value=self.local_rank).cuda()]
        }
        data = fastnlp_paddle_all_gather(obj)
        world_size = int(os.environ['PADDLE_TRAINERS_NUM'])
        assert len(data) == world_size
        for i in range(world_size):
            assert (data[i]['tensor'] == i).sum() == 2
            assert (data[i]['numpy'] == i).sum() == 2
            assert data[i]['bool'] == (i % 2 == 0)
            assert np.allclose(data[i]['float'], i + 0.1)
            assert data[i]['int'] == i
            assert data[i]['dict']['rank'] == i
            assert data[i]['list'][0] == i
            assert data[i]['str'] == f'{i}'
            assert data[i]['tensors'][0][0] == i

        for obj in [1, True, 'xxx']:
            data = fastnlp_paddle_all_gather(obj)
            assert len(data) == world_size
            assert data[0] == data[1]

        dist.barrier()

    @magic_argv_env_context
    @pytest.mark.parametrize("src_rank", ([0, 1, 2]))
    def test_fastnlp_paddle_broadcast_object(self, src_rank):
        if self.local_rank == src_rank:
            obj = {
                'tensor': paddle.full(shape=(2, ), fill_value=self.local_rank).cuda(),
                'numpy': np.full(shape=(2, ), fill_value=self.local_rank),
                'bool': self.local_rank % 2 == 0,
                'float': self.local_rank + 0.1,
                'int': self.local_rank,
                'dict': {
                    'rank': self.local_rank
                },
                'list': [self.local_rank] * 2,
                'str': f'{self.local_rank}',
                'tensors': [paddle.full(shape=(2,), fill_value=self.local_rank).cuda(),
                            paddle.full(shape=(2,), fill_value=self.local_rank).cuda()]
            }
        else:
            obj = None
        data = fastnlp_paddle_broadcast_object(obj, src=src_rank, device=paddle.device.get_device())
        assert data['tensor'][0] == src_rank
        assert data['numpy'][0] == src_rank
        assert data['bool'] == (src_rank % 2 == 0)
        assert np.allclose(data['float'], src_rank + 0.1)
        assert data['int'] == src_rank
        assert data['dict']['rank'] == src_rank
        assert data['list'][0] == src_rank
        assert data['str'] == f'{src_rank}'
        assert data['tensors'][0][0] == src_rank

        for obj in [self.local_rank, bool(self.local_rank == 1), str(self.local_rank)]:
            data = fastnlp_paddle_broadcast_object(obj, src=0, device=paddle.device.get_device())
            assert int(data) == 0
        dist.barrier()
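The multi-process tests bootstrap their own three-process group inside setup_class: rank 0 launches the extra workers with FleetLauncher, every process then joins the group via dist.fleet.init(is_collective=True), and the world size is read back from PADDLE_TRAINERS_NUM. A hedged sketch of the same bootstrap outside pytest, reusing only the calls that appear in the test above (adjust the device list to the GPUs actually available):

    import os
    import paddle.distributed as dist
    from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher

    launcher = FleetLauncher(devices=[0, 1, 2], output_from_new_proc="only_error")
    if int(os.getenv("PADDLE_RANK_IN_NODE", "0")) == 0:
        launcher.launch()                    # spawn the remaining ranks
    dist.fleet.init(is_collective=True)      # every rank joins the collective group
    dist.barrier()
    print("world size:", os.environ["PADDLE_TRAINERS_NUM"])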