Add code to destroy the communication process group in the torch driver test cases
This commit is contained in:
parent 5520f597ea
commit ebfd0e966c
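The change is mechanical and repo-wide: each DDP test now ends by tearing down the default process group it created, so that one test's communication state cannot leak into the next. A minimal sketch of the teardown idiom the diff adds everywhere (the helper name is illustrative, not part of the diff):

    import torch.distributed as dist

    def cleanup_process_group():
        # destroy_process_group() frees the default group created by init_process_group();
        # the is_initialized() guard keeps the call safe when no group was ever created,
        # e.g. when a test failed before DDP setup or ran in a single process.
        if dist.is_initialized():
            dist.destroy_process_group()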
tests/core/drivers/torch_driver/test_ddp.py
@@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
 from tests.helpers.utils import magic_argv_env_context
 from fastNLP.core import rank_zero_rm
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed as dist
+    from torch.utils.data import DataLoader, BatchSampler
-import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader, BatchSampler
 
-def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
+def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
     torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
     torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
     device = [torch.device(i) for i in device]
@@ -73,107 +74,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
 ############################################################################
 
+@pytest.mark.torch
+@magic_argv_env_context
+def test_multi_drivers():
+    """
+    Test the case where multiple TorchDDPDriver instances are used.
+    """
+    generate_driver(10, 10)
+    generate_driver(20, 10)
+
+    with pytest.raises(RuntimeError):
+        # different device settings should raise an error
+        generate_driver(20, 3, device=[0,1,2])
+        assert False
+    dist.barrier()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 @pytest.mark.torch
+@pytest.mark.torchtemp
 class TestDDPDriverFunction:
     """
     Test class for a few simple TorchDDPDriver functions; it mainly checks that they can run and that no import errors occur.
     """
 
-    @classmethod
-    def setup_class(cls):
-        cls.driver = generate_driver(10, 10)
-
     @magic_argv_env_context
-    def test_multi_drivers(self):
+    def test_simple_functions(self):
         """
-        Test the case where multiple TorchDDPDriver instances are used.
+        Briefly test several functions.
         """
-        driver2 = generate_driver(20, 10)
-
-        with pytest.raises(RuntimeError):
-            # different device settings should raise an error
-            driver3 = generate_driver(20, 3, device=[0,1,2])
-            assert False
+        driver = generate_driver(10, 10)
 
+        """
+        Test the move_data_to_device function. This function only calls torch_move_data_to_device; its test cases are in
+        tests/core/utils/test_torch_utils.py, so they are not repeated here.
+        """
+        driver.move_data_to_device(torch.rand((32, 64)))
+        dist.barrier()
 
-    @magic_argv_env_context
-    def test_move_data_to_device(self):
-        """
-        This function only calls torch_move_data_to_device; its test cases are in tests/core/utils/test_torch_utils.py,
-        so they are not repeated here.
-        """
-        self.driver.move_data_to_device(torch.rand((32, 64)))
-
-        dist.barrier()
 
-    @magic_argv_env_context
-    def test_is_distributed(self):
         """
         Test the is_distributed function
         """
-        assert self.driver.is_distributed() == True
+        assert driver.is_distributed() == True
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_get_no_sync_context(self):
         """
         Test the get_no_sync_context function
         """
-        res = self.driver.get_model_no_sync_context()
+        res = driver.get_model_no_sync_context()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_is_global_zero(self):
         """
         Test the is_global_zero function
         """
-        self.driver.is_global_zero()
+        driver.is_global_zero()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_unwrap_model(self):
         """
         Test the unwrap_model function
         """
-        self.driver.unwrap_model()
+        driver.unwrap_model()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_get_local_rank(self):
         """
         Test the get_local_rank function
         """
-        self.driver.get_local_rank()
+        driver.get_local_rank()
         dist.barrier()
 
-    @magic_argv_env_context
-    def test_all_gather(self):
         """
         Test the all_gather function
         Detailed tests are done in test_dist_utils.py
         """
         obj = {
-            "rank": self.driver.global_rank
+            "rank": driver.global_rank
         }
-        obj_list = self.driver.all_gather(obj, group=None)
+        obj_list = driver.all_gather(obj, group=None)
         for i, res in enumerate(obj_list):
             assert res["rank"] == i
 
-    @magic_argv_env_context
-    @pytest.mark.parametrize("src_rank", ([0, 1]))
-    def test_broadcast_object(self, src_rank):
         """
         Test the broadcast_object function
         Detailed tests are done in test_dist_utils.py
         """
-        if self.driver.global_rank == src_rank:
+        if driver.global_rank == 0:
             obj = {
-                "rank": self.driver.global_rank
+                "rank": driver.global_rank
             }
         else:
             obj = None
-        res = self.driver.broadcast_object(obj, src=src_rank)
-        assert res["rank"] == src_rank
+        res = driver.broadcast_object(obj, src=0)
+        assert res["rank"] == 0
 
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
 ############################################################################
 #
@@ -182,12 +176,12 @@ class TestDDPDriverFunction:
 ############################################################################
 
 @pytest.mark.torch
+@pytest.mark.torchtemp
 class TestSetDistReproDataloader:
 
     @classmethod
     def setup_class(cls):
         cls.device = [0, 1]
-        cls.driver = generate_driver(10, 10, device=cls.device)
 
     def setup_method(self):
         self.dataset = TorchNormalDataset(40)
@@ -204,17 +198,20 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler.
         In this case batch_sampler should be replaced by the BucketedBatchSampler passed as dist.
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler is batch_sampler
         self.check_distributed_sampler(replaced_loader.batch_sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -223,9 +220,10 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler.
         In this case batch_sampler.sampler should be replaced by the RandomSampler passed as dist.
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         sampler = RandomSampler(self.dataset, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -234,9 +232,11 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler is sampler
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     """
     The case where the `dist` argument is None; this happens during trainer and evaluator initialization when the user specifies `use_dist_sampler`
@@ -251,15 +251,17 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is None and reproducible is True.
         When the user has initialized the distributed environment outside the driver, fastNLP does not support checkpoint resumption, so an error should be raised.
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
         with pytest.raises(RuntimeError):
             # a RuntimeError should be raised
-            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
+            replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
-    # @pytest.mark.parametrize("shuffle", ([True, False]))
+    @pytest.mark.parametrize("shuffle", ([True, False]))
     def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
         """
@@ -268,21 +270,24 @@ class TestSetDistReproDataloader:
         In this case the incoming dataloader's batch_sampler should already have had set_distributed called on it, producing a new dataloader whose batch_sampler
         is the same as the original dataloader's
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
         dataloader.batch_sampler.set_distributed(
-            num_replicas=self.driver.world_size,
-            rank=self.driver.global_rank,
+            num_replicas=driver.world_size,
+            rank=driver.global_rank,
             pad=True
         )
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler.batch_size == 4
         self.check_distributed_sampler(dataloader.batch_sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -292,12 +297,13 @@ class TestSetDistReproDataloader:
         In this case the incoming dataloader's batch_sampler.sampler should already have had set_distributed called on it, producing a new dataloader whose
         batch_sampler.sampler is the same as the original dataloader's
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
         dataloader.batch_sampler.sampler.set_distributed(
-            num_replicas=self.driver.world_size,
-            rank=self.driver.global_rank
+            num_replicas=driver.world_size,
+            rank=driver.global_rank
         )
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -307,9 +313,11 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.batch_size == 4
         assert replaced_loader.batch_sampler.drop_last == False
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
 
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -318,11 +326,14 @@ class TestSetDistReproDataloader:
         Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one.
         In this case the original dataloader is returned directly, without any processing.
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)
 
         assert replaced_loader is dataloader
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     """
     The case where the `dist` argument is 'dist'; this happens during trainer initialization when the user specifies the `use_dist_sampler` argument
@@ -337,12 +348,13 @@ class TestSetDistReproDataloader:
         behavior
         In this case a new dataloader should be returned whose batch_sampler is the same as the original dataloader's, with the distributed attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
-        dataloader = DataLoader(
-            dataset=self.dataset,
-            batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
-        )
+        dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
@@ -351,6 +363,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -361,8 +375,9 @@ class TestSetDistReproDataloader:
         In this case a new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, with the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
 
         assert not (replaced_loader is dataloader)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
@@ -372,6 +387,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -381,8 +398,9 @@ class TestSetDistReproDataloader:
         In this case a new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler, and the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -392,6 +410,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     """
    The case where the `dist` argument is 'unrepeatdist'; this happens during evaluator initialization when the user specifies the `use_dist_sampler` argument
@@ -407,8 +427,9 @@ class TestSetDistReproDataloader:
         In this case a new dataloader should be returned, with the original Sampler replaced by an UnrepeatedRandomSampler and the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -418,6 +439,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -427,8 +450,9 @@ class TestSetDistReproDataloader:
         behavior
         In this case a new dataloader should be returned, with the original Sampler re-instantiated
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -439,6 +463,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -448,8 +474,9 @@ class TestSetDistReproDataloader:
         In this case a new dataloader should be returned, with the sampler replaced by an UnrepeatedSequentialSampler and the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
 
         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -459,6 +486,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     def check_distributed_sampler(self, sampler):
         """
@@ -469,7 +498,7 @@ class TestSetDistReproDataloader:
         if not isinstance(sampler, UnrepeatedSampler):
             assert sampler.pad == True
 
-    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
+    def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle):
         """
         Check that set_dist_repro_dataloader produces correct results in the multi-GPU case
         """
@@ -501,8 +530,8 @@ class TestSetDistReproDataloader:
                 drop_last=False,
             )
             new_loader.batch_sampler.set_distributed(
-                num_replicas=self.driver.world_size,
-                rank=self.driver.global_rank,
+                num_replicas=driver.world_size,
+                rank=driver.global_rank,
                 pad=True
             )
             new_loader.batch_sampler.load_state_dict(sampler_states)
@@ -512,8 +541,8 @@ class TestSetDistReproDataloader:
             # rebuild the dataloader
             new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False)
             new_loader.batch_sampler.sampler.set_distributed(
-                num_replicas=self.driver.world_size,
-                rank=self.driver.global_rank
+                num_replicas=driver.world_size,
+                rank=driver.global_rank
             )
             new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
             for idx, batch in enumerate(new_loader):
@@ -534,11 +563,6 @@ class TestSaveLoad:
     Test the behavior of the save- and load-related functions in the multi-GPU case
     """
 
-    @classmethod
-    def setup_class(cls):
-        # without the setup here an error is raised
-        cls.driver = generate_driver(10, 10)
-
     def setup_method(self):
         self.dataset = TorchArgMaxDataset(10, 20)
@@ -552,26 +576,26 @@ class TestSaveLoad:
            path = "model"
 
            dataloader = DataLoader(self.dataset, batch_size=2)
-           self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
+           driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)
 
-           self.driver1.save_model(path, only_state_dict)
+           driver1.save_model(path, only_state_dict)
 
            # synchronize
            dist.barrier()
-           self.driver2.load_model(path, only_state_dict)
+           driver2.load_model(path, only_state_dict)
 
            for idx, batch in enumerate(dataloader):
-               batch = self.driver1.move_data_to_device(batch)
-               res1 = self.driver1.model(
+               batch = driver1.move_data_to_device(batch)
+               res1 = driver1.model(
                    batch,
-                   fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                   fastnlp_fn=driver1.model.module.model.evaluate_step,
                    # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
-               res2 = self.driver2.model(
+               res2 = driver2.model(
                    batch,
-                   fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                   fastnlp_fn=driver2.model.module.model.evaluate_step,
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
@@ -580,6 +604,9 @@ class TestSaveLoad:
        finally:
            rank_zero_rm(path)
 
+       if dist.is_initialized():
+           dist.destroy_process_group()
+
    @magic_argv_env_context
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    @pytest.mark.parametrize("fp16", ([True, False]))
@@ -593,7 +620,7 @@ class TestSaveLoad:
            path = "model.ckp"
            num_replicas = len(device)
 
-           self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
+           driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
                generate_driver(10, 10, device=device, fp16=False)
            dataloader = dataloader_with_bucketedbatchsampler(
                self.dataset,
@@ -603,8 +630,8 @@ class TestSaveLoad:
                drop_last=False
            )
            dataloader.batch_sampler.set_distributed(
-               num_replicas=self.driver1.world_size,
-               rank=self.driver1.global_rank,
+               num_replicas=driver1.world_size,
+               rank=driver1.global_rank,
                pad=True
            )
            num_consumed_batches = 2
@@ -623,7 +650,7 @@ class TestSaveLoad:
            # save the state
            sampler_states = dataloader.batch_sampler.state_dict()
            save_states = {"num_consumed_batches": num_consumed_batches}
-           self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+           driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
            # load
            # change the batch_size
            dataloader = dataloader_with_bucketedbatchsampler(
@@ -634,11 +661,11 @@ class TestSaveLoad:
                drop_last=False
            )
            dataloader.batch_sampler.set_distributed(
-               num_replicas=self.driver2.world_size,
-               rank=self.driver2.global_rank,
+               num_replicas=driver2.world_size,
+               rank=driver2.global_rank,
                pad=True
            )
-           load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+           load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            replaced_loader = load_states.pop("dataloader")
            # 1. check the optimizer state
            # TODO the optimizer's state_dict is always empty
@@ -652,7 +679,7 @@ class TestSaveLoad:
 
            # 3. check whether fp16 was loaded
            if fp16:
-               assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
+               assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
 
            # 4. check that the model parameters are correct
            # 5. check batch_idx
@@ -664,16 +691,16 @@ class TestSaveLoad:
 
                left_x_batches.update(batch["x"])
                left_y_batches.update(batch["y"])
-               res1 = self.driver1.model(
+               res1 = driver1.model(
                    batch,
-                   fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                   fastnlp_fn=driver1.model.module.model.evaluate_step,
                    # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
-               res2 = self.driver2.model(
+               res2 = driver2.model(
                    batch,
-                   fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                   fastnlp_fn=driver2.model.module.model.evaluate_step,
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
@@ -686,6 +713,9 @@ class TestSaveLoad:
        finally:
            rank_zero_rm(path)
 
+       if dist.is_initialized():
+           dist.destroy_process_group()
+
    @magic_argv_env_context
    @pytest.mark.parametrize("only_state_dict", ([True, False]))
    @pytest.mark.parametrize("fp16", ([True, False]))
@@ -700,13 +730,13 @@ class TestSaveLoad:
 
            num_replicas = len(device)
 
-           self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
-           self.driver2 = generate_driver(10, 10, device=device, fp16=False)
+           driver1 = generate_driver(10, 10, device=device, fp16=fp16)
+           driver2 = generate_driver(10, 10, device=device, fp16=False)
 
            dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
            dataloader.batch_sampler.sampler.set_distributed(
-               num_replicas=self.driver1.world_size,
-               rank=self.driver1.global_rank,
+               num_replicas=driver1.world_size,
+               rank=driver1.global_rank,
                pad=True
            )
            num_consumed_batches = 2
@@ -726,18 +756,18 @@ class TestSaveLoad:
            sampler_states = dataloader.batch_sampler.sampler.state_dict()
            save_states = {"num_consumed_batches": num_consumed_batches}
            if only_state_dict:
-               self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+               driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
            else:
-               self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
+               driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
            # load
            # change the batch_size
            dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
            dataloader.batch_sampler.sampler.set_distributed(
-               num_replicas=self.driver2.world_size,
-               rank=self.driver2.global_rank,
+               num_replicas=driver2.world_size,
+               rank=driver2.global_rank,
                pad=True
            )
-           load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+           load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
            replaced_loader = load_states.pop("dataloader")
 
            # 1. check the optimizer state
@@ -753,7 +783,7 @@ class TestSaveLoad:
            assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
            # 3. check whether fp16 was loaded
            if fp16:
-               assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
+               assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
 
            # 4. check that the model parameters are correct
            # 5. check batch_idx
@@ -765,16 +795,16 @@ class TestSaveLoad:
 
                left_x_batches.update(batch["x"])
                left_y_batches.update(batch["y"])
-               res1 = self.driver1.model(
+               res1 = driver1.model(
                    batch,
-                   fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                   fastnlp_fn=driver1.model.module.model.evaluate_step,
                    # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
-               res2 = self.driver2.model(
+               res2 = driver2.model(
                    batch,
-                   fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                   fastnlp_fn=driver2.model.module.model.evaluate_step,
                    fastnlp_signature_fn=None,
                    wo_auto_param_call=False,
                )
@@ -786,4 +816,7 @@ class TestSaveLoad:
            assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
 
        finally:
            rank_zero_rm(path)
+
+       if dist.is_initialized():
+           dist.destroy_process_group()
tests/core/drivers/torch_driver/test_initialize_torch_driver.py
@@ -2,12 +2,12 @@ import pytest
 
 from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
 from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
 from fastNLP.envs import get_gpu_count
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.utils import magic_argv_env_context
 
-import torch
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed as dist
 
 @pytest.mark.torch
 def test_incorrect_driver():
@@ -55,6 +55,9 @@ def test_get_ddp_2(driver, device):
     driver = initialize_torch_driver(driver, device, model)
 
     assert isinstance(driver, TorchDDPDriver)
+    dist.barrier()
+    if dist.is_initialized():
+        dist.destroy_process_group()
 
 
 @pytest.mark.torch
@@ -76,6 +79,9 @@ def test_get_ddp(driver, device):
     driver = initialize_torch_driver(driver, device, model)
 
     assert isinstance(driver, TorchDDPDriver)
+    dist.barrier()
+    if dist.is_initialized():
+        dist.destroy_process_group()
 
 
 @pytest.mark.torch
@@ -83,7 +89,6 @@ def test_get_ddp(driver, device):
     ("driver", "device"),
     [("torch_ddp", "cpu")]
 )
-@magic_argv_env_context
 def test_get_ddp_cpu(driver, device):
     """
     Test attempting to initialize distributed training on cpu
@@ -102,7 +107,6 @@ def test_get_ddp_cpu(driver, device):
     "driver",
     ["torch", "torch_ddp"]
 )
-@magic_argv_env_context
 def test_device_out_of_range(driver, device):
     """
     Test the case where the given device is out of range