mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 11:48:09 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
02e080d239
@ -1,21 +1,35 @@
|
||||
from dataclasses import replace
|
||||
import pytest
|
||||
import os
|
||||
import numpy as np
|
||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle
|
||||
|
||||
set_env_on_import_paddle()
|
||||
os.environ["FASTNLP_BACKEND"] = "paddle"
|
||||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
|
||||
from fastNLP.core.samplers import (
|
||||
RandomSampler,
|
||||
UnrepeatedSampler,
|
||||
BucketedBatchSampler,
|
||||
UnrepeatedRandomSampler,
|
||||
UnrepeatedSequentialSampler,
|
||||
)
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DataLoader, BatchSampler
|
||||
|
||||
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
|
||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
|
||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
|
||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.core import synchronize_safe_rm
|
||||
def generate_driver(num_labels, feature_dimension):
    """
    Build the ``PaddleFleetDriver`` shared by the test classes in this file.

    A ``PaddleNormalModel_Classification_1`` is created from ``num_labels`` and
    ``feature_dimension``, an Adam optimizer (learning rate 0.01) over the
    model's parameters is registered on a driver spanning devices 0 and 1, and
    ``setup()`` is called on the driver before it is returned.
    """
    model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
    optimizer = paddle.optimizer.Adam(
        parameters=model.parameters(),
        learning_rate=0.01,
    )
    fleet_driver = PaddleFleetDriver(model=model, parallel_device=[0, 1])
    fleet_driver.set_optimizers(optimizer)
    fleet_driver.setup()
    return fleet_driver
|
||||
|
||||
############################################################################
|
||||
#
|
||||
@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm
|
||||
#
|
||||
############################################################################
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_move_data_to_device():
|
||||
class TestFleetDriverFunction:
|
||||
"""
|
||||
这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中
|
||||
就不重复测试了
|
||||
测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题
|
||||
"""
|
||||
try:
|
||||
paddle_model = PaddleNormalModel_Classification(10, 784)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=[0,1],
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
# 区分launch和子进程setup的时候
|
||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
driver.setup()
|
||||
assert e.value.code == 0
|
||||
return
|
||||
else:
|
||||
driver.setup()
|
||||
driver.move_data_to_device(paddle.rand((32, 64)))
|
||||
finally:
|
||||
synchronize_safe_rm("log")
|
||||
|
||||
dist.barrier()
|
||||
@classmethod
def setup_class(cls):
    """Create one driver via ``generate_driver(10, 10)`` and store it on the class."""
    num_labels, feature_dimension = 10, 10
    cls.driver = generate_driver(num_labels, feature_dimension)
|
||||
|
||||
@magic_argv_env_context
def test_move_data_to_device(self):
    """
    Smoke-test ``move_data_to_device``.

    The method only calls ``paddle_move_data_to_device``, whose test cases live
    in tests/core/utils/test_paddle_utils.py, so they are not repeated here.
    """
    random_batch = paddle.rand((32, 64))
    self.driver.move_data_to_device(random_batch)
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_is_distributed():
|
||||
print(os.getenv("CUDA_VISIBLE_DEVICES"))
|
||||
print(paddle.device.get_device())
|
||||
try:
|
||||
paddle_model = PaddleNormalModel_Classification(10, 784)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=[0,1],
|
||||
output_from_new_proc='all'
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
# 区分launch和子进程setup的时候
|
||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
driver.setup()
|
||||
assert e.value.code == 0
|
||||
return
|
||||
else:
|
||||
driver.setup()
|
||||
assert driver.is_distributed() == True
|
||||
finally:
|
||||
synchronize_safe_rm("log")
|
||||
dist.barrier()
|
||||
dist.barrier()
|
||||
|
||||
@magic_argv_env_context
def test_is_distributed(self):
    """Test the ``is_distributed`` function."""
    distributed = self.driver.is_distributed()
    assert distributed == True
    dist.barrier()
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_get_no_sync_context():
|
||||
"""
|
||||
测试能否运行
|
||||
"""
|
||||
try:
|
||||
paddle_model = PaddleNormalModel_Classification(10, 784)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=[0,1],
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
# 区分launch和子进程setup的时候
|
||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
driver.setup()
|
||||
assert e.value.code == 0
|
||||
return
|
||||
else:
|
||||
driver.setup()
|
||||
res = driver.get_no_sync_context()
|
||||
finally:
|
||||
synchronize_safe_rm("log")
|
||||
dist.barrier()
|
||||
@magic_argv_env_context
def test_get_no_sync_context(self):
    """Test the ``get_no_sync_context`` function."""
    # The returned value is not inspected; the test only checks the call runs.
    no_sync_context = self.driver.get_no_sync_context()
    dist.barrier()
|
||||
|
||||
@magic_argv_env_context
def test_is_global_zero(self):
    """Test the ``is_global_zero`` function."""
    fleet_driver = self.driver
    # Result is discarded: the test only checks that the call does not raise.
    fleet_driver.is_global_zero()
    dist.barrier()
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_is_global_zero():
|
||||
try:
|
||||
paddle_model = PaddleNormalModel_Classification(10, 784)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=[0,1],
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
# 区分launch和子进程setup的时候
|
||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
driver.setup()
|
||||
assert e.value.code == 0
|
||||
return
|
||||
else:
|
||||
driver.setup()
|
||||
driver.is_global_zero()
|
||||
finally:
|
||||
synchronize_safe_rm("log")
|
||||
dist.barrier()
|
||||
@magic_argv_env_context
def test_unwrap_model(self):
    """Test the ``unwrap_model`` function."""
    fleet_driver = self.driver
    # Result is discarded: the test only checks that the call does not raise.
    fleet_driver.unwrap_model()
    dist.barrier()
|
||||
|
||||
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_unwrap_model():
|
||||
try:
|
||||
paddle_model = PaddleNormalModel_Classification(10, 784)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=[0,1],
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
# 区分launch和子进程setup的时候
|
||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
driver.setup()
|
||||
assert e.value.code == 0
|
||||
return
|
||||
else:
|
||||
driver.setup()
|
||||
driver.unwrap_model()
|
||||
finally:
|
||||
synchronize_safe_rm("log")
|
||||
dist.barrier()
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_get_local_rank():
|
||||
try:
|
||||
paddle_model = PaddleNormalModel_Classification(10, 784)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=[0,1],
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
# 区分launch和子进程setup的时候
|
||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
driver.setup()
|
||||
assert e.value.code == 0
|
||||
return
|
||||
else:
|
||||
driver.setup()
|
||||
driver.get_local_rank()
|
||||
finally:
|
||||
synchronize_safe_rm("log")
|
||||
dist.barrier()
|
||||
|
||||
@magic_argv_env_context
|
||||
@pytest.mark.parametrize(
|
||||
"dist_sampler",
|
||||
["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"reproducible",
|
||||
[True, False]
|
||||
)
|
||||
def test_replace_sampler(dist_sampler, reproducible):
|
||||
"""
|
||||
测试replace_sampler
|
||||
"""
|
||||
try:
|
||||
paddle_model = PaddleNormalModel_Classification(10, 784)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=[0,1],
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
# 区分launch和子进程setup的时候
|
||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
driver.setup()
|
||||
assert e.value.code == 0
|
||||
return
|
||||
else:
|
||||
driver.setup()
|
||||
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
|
||||
driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
|
||||
finally:
|
||||
synchronize_safe_rm("log")
|
||||
dist.barrier()
|
||||
@magic_argv_env_context
def test_get_local_rank(self):
    """Test the ``get_local_rank`` function."""
    fleet_driver = self.driver
    # Result is discarded: the test only checks that the call does not raise.
    fleet_driver.get_local_rank()
    dist.barrier()
|
||||
|
||||
############################################################################
|
||||
#
|
||||
# 测试单机多卡的训练情况
|
||||
# 测试 set_dist_repro_dataloader 函数
|
||||
#
|
||||
############################################################################
|
||||
|
||||
@magic_argv_env_context
|
||||
class SingleMachineMultiGPUTrainingTestCase:
|
||||
class TestSetDistReproDataloader:
|
||||
|
||||
@classmethod
def setup_class(cls):
    """Create one driver via ``generate_driver(10, 10)`` and store it on the class."""
    num_labels, feature_dimension = 10, 10
    cls.driver = generate_driver(num_labels, feature_dimension)
|
||||
|
||||
def setup_method(self):
    """Store a newly built ``PaddleNormalDataset(20)`` on ``self.dataset``."""
    fresh_dataset = PaddleNormalDataset(20)
    self.dataset = fresh_dataset
|
||||
|
||||
"""
|
||||
测试在单机多卡上使用PaddleFleetDriver进行训练。
|
||||
分布式训练用pytest会有些混乱
|
||||
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况
|
||||
此时对应 driver.load 中的情况
|
||||
"""
|
||||
|
||||
def test_case1(self):
|
||||
@magic_argv_env_context
|
||||
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
|
||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
|
||||
|
||||
gpus = [0, 1]
|
||||
lr = 0.0003
|
||||
epochs = 20
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
|
||||
assert replaced_loader.batch_sampler is batch_sampler
|
||||
self.check_distributed_sampler(replaced_loader.batch_sampler)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
paddle_model = PaddleNormalModel_Classification()
|
||||
@magic_argv_env_context
|
||||
def test_set_dist_repro_dataloader_with_dist_sampler(self):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
|
||||
sampler = RandomSampler(self.dataset, shuffle=True)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
|
||||
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr)
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
|
||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
|
||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
|
||||
assert replaced_loader.batch_sampler.sampler is sampler
|
||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
|
||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
|
||||
|
||||
train_dataset = PaddleDataset_MNIST("train")
|
||||
test_dataset = PaddleDataset_MNIST("test")
|
||||
loss_func = paddle.nn.CrossEntropyLoss()
|
||||
dist.barrier()
|
||||
|
||||
"""
|
||||
传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler`
|
||||
参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。
|
||||
当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定
|
||||
是否重新实例化 dataloader
|
||||
"""
|
||||
|
||||
dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)
|
||||
@magic_argv_env_context
|
||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
|
||||
with pytest.raises(RuntimeError):
|
||||
# 应当抛出 RuntimeError
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
|
||||
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
parallel_device=gpus,
|
||||
dist.barrier()
|
||||
|
||||
@magic_argv_env_context
|
||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler
|
||||
时的表现
|
||||
"""
|
||||
dataloader = DataLoader(
|
||||
self.dataset,
|
||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
|
||||
)
|
||||
driver.set_optimizers(paddle_opt)
|
||||
dataloader = driver.set_dist_repro_dataloader(dataloader, )
|
||||
driver.setup()
|
||||
# 检查model_device
|
||||
self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}")
|
||||
dataloader.batch_sampler.set_distributed(
|
||||
num_replicas=self.driver.world_size,
|
||||
rank=self.driver.global_rank,
|
||||
pad=True
|
||||
)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
|
||||
|
||||
driver.barrier()
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
|
||||
assert replaced_loader.batch_sampler.batch_size == 4
|
||||
self.check_distributed_sampler(dataloader.batch_sampler)
|
||||
|
||||
driver.zero_grad()
|
||||
current_epoch_idx = 0
|
||||
while current_epoch_idx < epochs:
|
||||
epoch_loss, batch = 0, 0
|
||||
driver.set_model_mode("train")
|
||||
driver.set_sampler_epoch(dataloader, current_epoch_idx)
|
||||
for batch, (img, label) in enumerate(dataloader):
|
||||
dist.barrier()
|
||||
|
||||
img = paddle.to_tensor(img)
|
||||
out = driver.train_step(img)
|
||||
label + 1
|
||||
loss = loss_func(out, label)
|
||||
epoch_loss += loss.item()
|
||||
@magic_argv_env_context
|
||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现
|
||||
"""
|
||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
|
||||
batch_sampler.sampler = RandomSampler(self.dataset, True)
|
||||
batch_sampler.sampler.set_distributed(
|
||||
num_replicas=self.driver.world_size,
|
||||
rank=self.driver.global_rank
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
self.dataset,
|
||||
batch_sampler=batch_sampler
|
||||
)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
|
||||
|
||||
if batch % 50 == 0:
|
||||
print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank))
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
|
||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
|
||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
|
||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
|
||||
assert replaced_loader.batch_sampler.batch_size == 2
|
||||
assert replaced_loader.batch_sampler.drop_last == False
|
||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
|
||||
dist.barrier()
|
||||
|
||||
driver.backward(loss)
|
||||
driver.step()
|
||||
driver.zero_grad()
|
||||
driver.barrier()
|
||||
current_epoch_idx += 1
|
||||
@magic_argv_env_context
|
||||
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
|
||||
"""
|
||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
|
||||
|
||||
# test
|
||||
correct = 0
|
||||
driver.set_model_mode("eval")
|
||||
for img, label in test_dataset:
|
||||
assert replaced_loader is dataloader
|
||||
dist.barrier()
|
||||
|
||||
img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
|
||||
out = driver.test_step(img)
|
||||
res = paddle.nn.functional.softmax(out).argmax().item()
|
||||
label = label.item()
|
||||
if res == label:
|
||||
correct += 1
|
||||
"""
|
||||
传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数
|
||||
为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader
|
||||
"""
|
||||
|
||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
    """
    Test ``set_dist_repro_dataloader`` when ``dist`` is 'dist' and
    ``dataloader.batch_sampler`` is a ReproducibleBatchSampler.
    """
    bucketed = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
    source_loader = DataLoader(dataset=self.dataset, batch_sampler=bucketed)
    new_loader = self.driver.set_dist_repro_dataloader(source_loader, "dist", False)

    # Both the dataloader and its batch sampler must be new objects.
    assert new_loader is not source_loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BucketedBatchSampler)
    assert new_batch_sampler is not source_loader.batch_sampler
    assert new_batch_sampler.batch_size == 4
    assert new_loader.drop_last == source_loader.drop_last
    self.check_distributed_sampler(new_batch_sampler)
    dist.barrier()
|
||||
|
||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
    """
    Test ``set_dist_repro_dataloader`` when ``dist`` is 'dist' and
    ``dataloader.batch_sampler.sampler`` is a ReproducibleSampler.
    """
    source_batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    source_batch_sampler.sampler = RandomSampler(self.dataset, True)
    source_loader = DataLoader(self.dataset, batch_sampler=source_batch_sampler)
    new_loader = self.driver.set_dist_repro_dataloader(source_loader, "dist", False)

    # Loader, batch sampler and inner sampler must all be new objects.
    assert new_loader is not source_loader
    new_batch_sampler = new_loader.batch_sampler
    assert new_batch_sampler is not source_loader.batch_sampler
    new_sampler = new_batch_sampler.sampler
    assert isinstance(new_sampler, RandomSampler)
    assert new_sampler is not source_loader.batch_sampler.sampler
    assert new_batch_sampler.batch_size == 2
    assert new_sampler.shuffle == True
    self.check_distributed_sampler(new_sampler)
    dist.barrier()
|
||||
|
||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
    """
    Test ``set_dist_repro_dataloader`` when ``dist`` is 'dist' and the
    dataloader is an ordinary one.
    """
    source_loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    new_loader = self.driver.set_dist_repro_dataloader(source_loader, "dist", False)

    assert new_loader is not source_loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not source_loader.batch_sampler
    new_sampler = new_batch_sampler.sampler
    assert isinstance(new_sampler, RandomSampler)
    assert new_batch_sampler.batch_size == source_loader.batch_sampler.batch_size
    assert new_sampler.shuffle == True
    dist.barrier()
|
||||
|
||||
"""
|
||||
传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数
|
||||
为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader
|
||||
"""
|
||||
|
||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
    """
    Test ``set_dist_repro_dataloader`` when ``dist`` is 'unrepeatdist' and
    ``dataloader.batch_sampler.sampler`` is a ReproducibleSampler.
    """
    source_batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    source_batch_sampler.sampler = RandomSampler(self.dataset, True)
    source_loader = DataLoader(self.dataset, batch_sampler=source_batch_sampler)
    new_loader = self.driver.set_dist_repro_dataloader(source_loader, "unrepeatdist", False)

    assert new_loader is not source_loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not source_loader.batch_sampler
    # The inner sampler is expected to be an unrepeated random sampler.
    new_sampler = new_batch_sampler.sampler
    assert isinstance(new_sampler, UnrepeatedRandomSampler)
    assert new_batch_sampler.batch_size == 2
    assert new_sampler.shuffle == True
    self.check_distributed_sampler(new_sampler)
    dist.barrier()
|
||||
|
||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
    """
    Test ``set_dist_repro_dataloader`` when ``dist`` is 'unrepeatdist' and
    ``dataloader.batch_sampler.sampler`` is an UnrepeatedSampler.
    """
    source_batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    source_batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
    source_loader = DataLoader(self.dataset, batch_sampler=source_batch_sampler)
    new_loader = self.driver.set_dist_repro_dataloader(source_loader, "unrepeatdist", False)

    # Loader, batch sampler and inner sampler must all be new objects.
    assert new_loader is not source_loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not source_loader.batch_sampler
    new_sampler = new_batch_sampler.sampler
    assert isinstance(new_sampler, UnrepeatedRandomSampler)
    assert new_sampler is not source_loader.batch_sampler.sampler
    assert new_batch_sampler.batch_size == 2
    assert new_loader.drop_last == source_loader.drop_last
    self.check_distributed_sampler(new_sampler)
    dist.barrier()
|
||||
|
||||
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
    """
    Test ``set_dist_repro_dataloader`` when ``dist`` is 'unrepeatdist' and the
    dataloader is an ordinary one.
    """
    source_loader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    new_loader = self.driver.set_dist_repro_dataloader(source_loader, "unrepeatdist", False)

    assert new_loader is not source_loader
    new_batch_sampler = new_loader.batch_sampler
    assert isinstance(new_batch_sampler, BatchSampler)
    assert new_batch_sampler is not source_loader.batch_sampler
    # The inner sampler is expected to be an unrepeated sequential sampler.
    new_sampler = new_batch_sampler.sampler
    assert isinstance(new_sampler, UnrepeatedSequentialSampler)
    assert new_batch_sampler.batch_size == 4
    assert new_loader.drop_last == source_loader.drop_last
    self.check_distributed_sampler(new_sampler)
    dist.barrier()
|
||||
|
||||
def check_distributed_sampler(self, sampler):
    """
    Check that the distributed settings of a replaced sampler or batch sampler
    are correct.

    ``num_replicas`` and ``rank`` must equal the values reported by
    ``paddle.distributed``; ``pad`` must additionally be ``True`` unless the
    sampler is an ``UnrepeatedSampler``.
    """
    assert sampler.num_replicas == dist.get_world_size()
    assert sampler.rank == dist.get_rank()
    if isinstance(sampler, UnrepeatedSampler):
        # ``pad`` is not checked for unrepeated samplers.
        return
    assert sampler.pad == True
|
||||
|
||||
print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset)))
|
||||
|
Loading…
Reference in New Issue
Block a user