Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-11-30 03:07:59 +08:00)
torch: add dtype to torch.full in test_dist_utils.py for compatibility with different torch versions
This commit is contained in:
parent 49eb1fcc6b
commit 51a4439737
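Background for the change (my illustrative note, not part of the commit): without an explicit dtype, the tensor type that torch.full produces for an integer fill_value has varied across torch releases (older versions returned a float tensor, newer ones warn or infer an integer dtype), so asserts that compare the gathered tensors against integer ranks could behave differently depending on the installed torch. Passing dtype pins the result type. A minimal sketch:

```python
import torch

# Illustrative sketch only: with an explicit dtype the result type of torch.full
# is the same on every torch version, which is what the test relies on.
a = torch.full(size=(2,), fill_value=1, dtype=int)            # always an integer tensor
b = torch.full(size=(2,), fill_value=1, dtype=torch.float32)  # always float32
print(a.dtype, b.dtype)  # torch.int64 torch.float32
```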
test_dist_utils.py

@@ -15,69 +15,20 @@ from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context
 @pytest.mark.torch
 @magic_argv_env_context
 def test_fastnlp_torch_all_gather():
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '29500'
-    if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
-        os.environ['LOCAL_RANK'] = '0'
-        os.environ['RANK'] = '0'
-        os.environ['WORLD_SIZE'] = '2'
-    re_run_current_cmd_for_torch(1, output_from_new_proc='all')
-    torch.distributed.init_process_group(backend='nccl')
-    torch.distributed.barrier()
-    local_rank = int(os.environ['LOCAL_RANK'])
-    torch.cuda.set_device(local_rank)
-    obj = {
-        'tensor': torch.full(size=(2,), fill_value=local_rank).cuda(),
-        'numpy': np.full(shape=(2, ), fill_value=local_rank),
-        'bool': local_rank%2==0,
-        'float': local_rank + 0.1,
-        'int': local_rank,
-        'dict': {
-            'rank': local_rank
-        },
-        'list': [local_rank]*2,
-        'str': f'{local_rank}',
-        'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(),
-                    torch.full(size=(2,), fill_value=local_rank).cuda()]
-    }
-    data = fastnlp_torch_all_gather(obj)
-    world_size = int(os.environ['WORLD_SIZE'])
-    assert len(data) == world_size
-    for i in range(world_size):
-        assert (data[i]['tensor']==i).sum()==world_size
-        assert data[i]['numpy'][0]==i
-        assert data[i]['bool']==(i%2==0)
-        assert np.allclose(data[i]['float'], i+0.1)
-        assert data[i]['int'] == i
-        assert data[i]['dict']['rank'] == i
-        assert data[i]['list'][0] == i
-        assert data[i]['str'] == f'{i}'
-        assert data[i]['tensors'][0][0] == i
-
-    for obj in [1, True, 'xxx']:
-        data = fastnlp_torch_all_gather(obj)
-        assert len(data)==world_size
-        assert data[0]==data[1]
-
-    dist.destroy_process_group()
-
-@pytest.mark.torch
-@magic_argv_env_context
-def test_fastnlp_torch_broadcast_object():
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '29500'
-    if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
-        os.environ['LOCAL_RANK'] = '0'
-        os.environ['RANK'] = '0'
-        os.environ['WORLD_SIZE'] = '2'
-    re_run_current_cmd_for_torch(1, output_from_new_proc='all')
-    torch.distributed.init_process_group(backend='nccl')
-    torch.distributed.barrier()
-    local_rank = int(os.environ['LOCAL_RANK'])
-    torch.cuda.set_device(local_rank)
-    if os.environ['LOCAL_RANK']=="0":
+    try:
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'
+        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
+            os.environ['LOCAL_RANK'] = '0'
+            os.environ['RANK'] = '0'
+            os.environ['WORLD_SIZE'] = '2'
+        re_run_current_cmd_for_torch(1, output_from_new_proc='all')
+        torch.distributed.init_process_group(backend='nccl')
+        torch.distributed.barrier()
+        local_rank = int(os.environ['LOCAL_RANK'])
+        torch.cuda.set_device(local_rank)
         obj = {
-            'tensor': torch.full(size=(2,), fill_value=local_rank).cuda(),
+            'tensor': torch.full(size=(2,), fill_value=local_rank, dtype=int).cuda(),
             'numpy': np.full(shape=(2, ), fill_value=local_rank),
             'bool': local_rank%2==0,
             'float': local_rank + 0.1,
@@ -87,24 +38,77 @@ def test_fastnlp_torch_broadcast_object():
             },
             'list': [local_rank]*2,
             'str': f'{local_rank}',
-            'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(),
-                        torch.full(size=(2,), fill_value=local_rank).cuda()]
+            'tensors': [torch.full(size=(2,), fill_value=local_rank, dtype=int).cuda(),
+                        torch.full(size=(2,), fill_value=local_rank, dtype=int).cuda()]
         }
-    else:
-        obj = None
-    data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device())
-    i = 0
-    assert data['tensor'][0]==0
-    assert data['numpy'][0]==0
-    assert data['bool']==(i%2==0)
-    assert np.allclose(data['float'], i+0.1)
-    assert data['int'] == i
-    assert data['dict']['rank'] == i
-    assert data['list'][0] == i
-    assert data['str'] == f'{i}'
-    assert data['tensors'][0][0] == i
-
-    for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]:
+        data = fastnlp_torch_all_gather(obj)
+        world_size = int(os.environ['WORLD_SIZE'])
+        assert len(data) == world_size
+        for i in range(world_size):
+            assert (data[i]['tensor']==i).sum()==world_size
+            assert data[i]['numpy'][0]==i
+            assert data[i]['bool']==(i%2==0)
+            assert np.allclose(data[i]['float'], i+0.1)
+            assert data[i]['int'] == i
+            assert data[i]['dict']['rank'] == i
+            assert data[i]['list'][0] == i
+            assert data[i]['str'] == f'{i}'
+            assert data[i]['tensors'][0][0] == i
+
+        for obj in [1, True, 'xxx']:
+            data = fastnlp_torch_all_gather(obj)
+            assert len(data)==world_size
+            assert data[0]==data[1]
+
+    finally:
+        dist.destroy_process_group()
+
+@pytest.mark.torch
+@magic_argv_env_context
+def test_fastnlp_torch_broadcast_object():
+    try:
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29500'
+        if 'LOCAL_RANK' not in os.environ and 'RANK' not in os.environ and 'WORLD_SIZE' not in os.environ:
+            os.environ['LOCAL_RANK'] = '0'
+            os.environ['RANK'] = '0'
+            os.environ['WORLD_SIZE'] = '2'
+        re_run_current_cmd_for_torch(1, output_from_new_proc='all')
+        torch.distributed.init_process_group(backend='nccl')
+        torch.distributed.barrier()
+        local_rank = int(os.environ['LOCAL_RANK'])
+        torch.cuda.set_device(local_rank)
+        if os.environ['LOCAL_RANK']=="0":
+            obj = {
+                'tensor': torch.full(size=(2,), fill_value=local_rank, dtype=int).cuda(),
+                'numpy': np.full(shape=(2, ), fill_value=local_rank, dtype=int),
+                'bool': local_rank%2==0,
+                'float': local_rank + 0.1,
+                'int': local_rank,
+                'dict': {
+                    'rank': local_rank
+                },
+                'list': [local_rank]*2,
+                'str': f'{local_rank}',
+                'tensors': [torch.full(size=(2,), fill_value=local_rank, dtype=int).cuda(),
+                            torch.full(size=(2,), fill_value=local_rank, dtype=int).cuda()]
+            }
+        else:
+            obj = None
         data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device())
-        assert int(data)==0
-    dist.destroy_process_group()
+        i = 0
+        assert data['tensor'][0]==0
+        assert data['numpy'][0]==0
+        assert data['bool']==(i%2==0)
+        assert np.allclose(data['float'], i+0.1)
+        assert data['int'] == i
+        assert data['dict']['rank'] == i
+        assert data['list'][0] == i
+        assert data['str'] == f'{i}'
+        assert data['tensors'][0][0] == i
+
+        for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]:
+            data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device())
+            assert int(data)==0
+    finally:
+        dist.destroy_process_group()