Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

YWMditto 2022-04-15 00:18:27 +08:00
commit 02e080d239


@@ -1,21 +1,35 @@
+from dataclasses import replace
 import pytest
 import os
-import numpy as np
 from fastNLP.envs.set_env_on_import import set_env_on_import_paddle
 set_env_on_import_paddle()
 os.environ["FASTNLP_BACKEND"] = "paddle"
+from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
+from fastNLP.core.samplers import (
+    RandomSampler,
+    UnrepeatedSampler,
+    BucketedBatchSampler,
+    UnrepeatedRandomSampler,
+    UnrepeatedSequentialSampler,
+)
+from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
+from tests.helpers.datasets.paddle_data import PaddleNormalDataset
+from tests.helpers.utils import magic_argv_env_context
 import paddle
 import paddle.distributed as dist
-from paddle.io import DataLoader
-from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
-from fastNLP.core.samplers.reproducible_sampler import RandomSampler
-from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
-from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
-from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
-from tests.helpers.utils import magic_argv_env_context
-from fastNLP.core import synchronize_safe_rm
+from paddle.io import DataLoader, BatchSampler
+
+def generate_driver(num_labels, feature_dimension):
+    paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
+    paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
+    driver = PaddleFleetDriver(
+        model=paddle_model,
+        parallel_device=[0,1],
+    )
+    driver.set_optimizers(paddle_opt)
+    driver.setup()
+    return driver
 ############################################################################
 #
@@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm
 #
 ############################################################################
-@magic_argv_env_context
-def test_move_data_to_device():
-    """
-    This function only calls paddle_move_data_to_device; its test cases live in
-    tests/core/utils/test_paddle_utils.py, so they are not repeated here.
-    """
-    try:
-        paddle_model = PaddleNormalModel_Classification(10, 784)
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=[0,1],
-        )
-        driver.set_optimizers(paddle_opt)
-        # Distinguish the launch phase from setup in the spawned subprocesses.
-        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-            with pytest.raises(SystemExit) as e:
-                driver.setup()
-            assert e.value.code == 0
-            return
-        else:
-            driver.setup()
-        driver.move_data_to_device(paddle.rand((32, 64)))
-    finally:
-        synchronize_safe_rm("log")
-    dist.barrier()
+class TestFleetDriverFunction:
+    """
+    Test class for some simple PaddleFleetDriver functions; these mostly check
+    that the calls run at all and that there are no import errors.
+    """
+
+    @classmethod
+    def setup_class(cls):
+        cls.driver = generate_driver(10, 10)
+
+    @magic_argv_env_context
+    def test_move_data_to_device(self):
+        """
+        This function only calls paddle_move_data_to_device; its test cases live in
+        tests/core/utils/test_paddle_utils.py, so they are not repeated here.
+        """
+        self.driver.move_data_to_device(paddle.rand((32, 64)))
+        dist.barrier()
-@magic_argv_env_context
-def test_is_distributed():
-    print(os.getenv("CUDA_VISIBLE_DEVICES"))
-    print(paddle.device.get_device())
-    try:
-        paddle_model = PaddleNormalModel_Classification(10, 784)
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=[0,1],
-            output_from_new_proc='all'
-        )
-        driver.set_optimizers(paddle_opt)
-        # Distinguish the launch phase from setup in the spawned subprocesses.
-        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-            with pytest.raises(SystemExit) as e:
-                driver.setup()
-            assert e.value.code == 0
-            return
-        else:
-            driver.setup()
-        assert driver.is_distributed() == True
-    finally:
-        synchronize_safe_rm("log")
-    dist.barrier()
+    @magic_argv_env_context
+    def test_is_distributed(self):
+        """
+        Test the is_distributed function.
+        """
+        assert self.driver.is_distributed() == True
+        dist.barrier()
-@magic_argv_env_context
-def test_get_no_sync_context():
-    """
-    Test that the call runs.
-    """
-    try:
-        paddle_model = PaddleNormalModel_Classification(10, 784)
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=[0,1],
-        )
-        driver.set_optimizers(paddle_opt)
-        # Distinguish the launch phase from setup in the spawned subprocesses.
-        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-            with pytest.raises(SystemExit) as e:
-                driver.setup()
-            assert e.value.code == 0
-            return
-        else:
-            driver.setup()
-        res = driver.get_no_sync_context()
-    finally:
-        synchronize_safe_rm("log")
-    dist.barrier()
+    @magic_argv_env_context
+    def test_get_no_sync_context(self):
+        """
+        Test the get_no_sync_context function.
+        """
+        res = self.driver.get_no_sync_context()
+        dist.barrier()
+
+    @magic_argv_env_context
+    def test_is_global_zero(self):
+        """
+        Test the is_global_zero function.
+        """
+        self.driver.is_global_zero()
+        dist.barrier()
-@magic_argv_env_context
-def test_is_global_zero():
-    try:
-        paddle_model = PaddleNormalModel_Classification(10, 784)
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=[0,1],
-        )
-        driver.set_optimizers(paddle_opt)
-        # Distinguish the launch phase from setup in the spawned subprocesses.
-        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-            with pytest.raises(SystemExit) as e:
-                driver.setup()
-            assert e.value.code == 0
-            return
-        else:
-            driver.setup()
-        driver.is_global_zero()
-    finally:
-        synchronize_safe_rm("log")
-    dist.barrier()
+    @magic_argv_env_context
+    def test_unwrap_model(self):
+        """
+        Test the unwrap_model function.
+        """
+        self.driver.unwrap_model()
+        dist.barrier()
-@magic_argv_env_context
-def test_unwrap_model():
-    try:
-        paddle_model = PaddleNormalModel_Classification(10, 784)
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=[0,1],
-        )
-        driver.set_optimizers(paddle_opt)
-        # Distinguish the launch phase from setup in the spawned subprocesses.
-        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-            with pytest.raises(SystemExit) as e:
-                driver.setup()
-            assert e.value.code == 0
-            return
-        else:
-            driver.setup()
-        driver.unwrap_model()
-    finally:
-        synchronize_safe_rm("log")
-    dist.barrier()
-
-@magic_argv_env_context
-def test_get_local_rank():
-    try:
-        paddle_model = PaddleNormalModel_Classification(10, 784)
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=[0,1],
-        )
-        driver.set_optimizers(paddle_opt)
-        # Distinguish the launch phase from setup in the spawned subprocesses.
-        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-            with pytest.raises(SystemExit) as e:
-                driver.setup()
-            assert e.value.code == 0
-            return
-        else:
-            driver.setup()
-        driver.get_local_rank()
-    finally:
-        synchronize_safe_rm("log")
-    dist.barrier()
-@magic_argv_env_context
-@pytest.mark.parametrize(
-    "dist_sampler",
-    ["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))]
-)
-@pytest.mark.parametrize(
-    "reproducible",
-    [True, False]
-)
-def test_replace_sampler(dist_sampler, reproducible):
-    """
-    Test replace_sampler.
-    """
-    try:
-        paddle_model = PaddleNormalModel_Classification(10, 784)
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=[0,1],
-        )
-        driver.set_optimizers(paddle_opt)
-        # Distinguish the launch phase from setup in the spawned subprocesses.
-        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
-            with pytest.raises(SystemExit) as e:
-                driver.setup()
-            assert e.value.code == 0
-            return
-        else:
-            driver.setup()
-        dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
-        driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
-    finally:
-        synchronize_safe_rm("log")
-    dist.barrier()
+    @magic_argv_env_context
+    def test_get_local_rank(self):
+        """
+        Test the get_local_rank function.
+        """
+        self.driver.get_local_rank()
+        dist.barrier()
 ############################################################################
 #
-# Test training on a single machine with multiple GPUs.
+# Test the set_dist_repro_dataloader function.
 #
 ############################################################################
-@magic_argv_env_context
-class SingleMachineMultiGPUTrainingTestCase:
-    """
-    Test training with PaddleFleetDriver on a single machine with multiple GPUs.
-    Running distributed training through pytest can get somewhat messy.
-    """
-
-    def test_case1(self):
-        gpus = [0, 1]
-        lr = 0.0003
-        epochs = 20
+class TestSetDistReproDataloader:
+
+    @classmethod
+    def setup_class(cls):
+        cls.driver = generate_driver(10, 10)
+
+    def setup_method(self):
+        self.dataset = PaddleNormalDataset(20)
+
+    """
+    Cases where the `dist` argument passed in is a concrete ReproducibleSampler or
+    ReproducibleBatchSampler, corresponding to the cases handled in driver.load.
+    """
+
+    @magic_argv_env_context
+    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
+        """
+        Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler.
+        """
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
+        batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
+
+        assert not (replaced_loader is dataloader)
+        assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
+        assert replaced_loader.batch_sampler is batch_sampler
+        self.check_distributed_sampler(replaced_loader.batch_sampler)
+        dist.barrier()
-        paddle_model = PaddleNormalModel_Classification()
-        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr)
-        train_dataset = PaddleDataset_MNIST("train")
-        test_dataset = PaddleDataset_MNIST("test")
-        loss_func = paddle.nn.CrossEntropyLoss()
+    @magic_argv_env_context
+    def test_set_dist_repro_dataloader_with_dist_sampler(self):
+        """
+        Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler.
+        """
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
+        sampler = RandomSampler(self.dataset, shuffle=True)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
+
+        assert not (replaced_loader is dataloader)
+        assert isinstance(replaced_loader.batch_sampler, BatchSampler)
+        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
+        assert replaced_loader.batch_sampler.sampler is sampler
+        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
+        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
+        dist.barrier()
+
+    """
+    Cases where the `dist` argument is None. This happens during trainer and evaluator
+    initialization when the user sets `use_dist_sampler` to False; the function then
+    branches on the `reproducible` setting. When `reproducible` is False, whether the
+    dataloader is re-instantiated depends on whether its batch_sampler or sampler is
+    already Reproducible.
+    """
-        dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)
-        driver = PaddleFleetDriver(
-            model=paddle_model,
-            parallel_device=gpus,
-        )
-        driver.set_optimizers(paddle_opt)
-        dataloader = driver.set_dist_repro_dataloader(dataloader)
-        driver.setup()
-        # Check model_device.
-        assert driver.model_device == f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}"
-        driver.barrier()
-
-        driver.zero_grad()
-        current_epoch_idx = 0
-        while current_epoch_idx < epochs:
-            epoch_loss, batch = 0, 0
-            driver.set_model_mode("train")
-            driver.set_sampler_epoch(dataloader, current_epoch_idx)
-            for batch, (img, label) in enumerate(dataloader):
-                img = paddle.to_tensor(img)
-                out = driver.train_step(img)
-                loss = loss_func(out, label)
-                epoch_loss += loss.item()
+    @magic_argv_env_context
+    def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
+        """
+        Test the behavior of set_dist_repro_dataloader when dist is None and
+        reproducible is True.
+        """
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
+        with pytest.raises(RuntimeError):
+            # A RuntimeError should be raised here.
+            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
+        dist.barrier()
+
+    @magic_argv_env_context
+    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
+        """
+        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible
+        is False and the dataloader uses a BucketedBatchSampler.
+        """
+        dataloader = DataLoader(
+            self.dataset,
+            batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
+        )
+        dataloader.batch_sampler.set_distributed(
+            num_replicas=self.driver.world_size,
+            rank=self.driver.global_rank,
+            pad=True
+        )
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+
+        assert not (replaced_loader is dataloader)
+        assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
+        assert replaced_loader.batch_sampler.batch_size == 4
+        self.check_distributed_sampler(dataloader.batch_sampler)
+        dist.barrier()
-                if batch % 50 == 0:
-                    print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank))
-                driver.backward(loss)
-                driver.step()
-                driver.zero_grad()
-            driver.barrier()
-            current_epoch_idx += 1
-
-        # test
-        correct = 0
-        driver.set_model_mode("eval")
-        for img, label in test_dataset:
-            img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
-            out = driver.test_step(img)
-            res = paddle.nn.functional.softmax(out).argmax().item()
-            label = label.item()
-            if res == label:
-                correct += 1
-        print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset)))
+    @magic_argv_env_context
+    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self):
+        """
+        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible
+        is False and the dataloader uses a RandomSampler.
+        """
+        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
+        batch_sampler.sampler = RandomSampler(self.dataset, True)
+        batch_sampler.sampler.set_distributed(
+            num_replicas=self.driver.world_size,
+            rank=self.driver.global_rank
+        )
+        dataloader = DataLoader(
+            self.dataset,
+            batch_sampler=batch_sampler
+        )
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+
+        assert not (replaced_loader is dataloader)
+        assert isinstance(replaced_loader.batch_sampler, BatchSampler)
+        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
+        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
+        assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
+        assert replaced_loader.batch_sampler.batch_size == 2
+        assert replaced_loader.batch_sampler.drop_last == False
+        self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
+        dist.barrier()
+
+    @magic_argv_env_context
+    def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
+        """
+        Test the behavior of set_dist_repro_dataloader when dist is None, reproducible
+        is False and the dataloader is a plain one.
+        """
+        dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
+        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+
+        assert replaced_loader is dataloader
+        dist.barrier()
"""
传入的参数 `dist` 'dist' 的情况这种情况出现在 trainer 的初始化过程中用户指定了 `use_dist_sampler` 参数
True此时函数会根据 dataloader batch_sampler sampler 是否为 Reproducible 来决定如何重新实例化 dataloader
"""
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
"""
测试 set_dist_repro_dataloader dist 'dist'dataloader.batch_sampler ReproducibleBatchSampler
的表现
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler)
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
"""
测试 set_dist_repro_dataloader dist 'dist'dataloader.batch_sampler.sampler ReproducibleSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
"""
测试 set_dist_repro_dataloader dist 'dist'dataloader 为一般情况的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == True
dist.barrier()
"""
传入的参数 `dist` 'unrepeatdist' 的情况这种情况出现在 evaluator 的初始化过程中用户指定了 `use_dist_sampler` 参数
True此时函数会根据 dataloader sampler 是否为 Unrepeated Reproducible 来决定如何重新实例化 dataloader
"""
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
"""
测试 set_dist_repro_dataloader dist 'unrepeatdist'dataloader.batch_sampler.sampler ReproducibleSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
"""
测试 set_dist_repro_dataloader dist 'unrepeatdist'dataloader.batch_sampler.sampler UnrepeatedSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
"""
测试 set_dist_repro_dataloader dist 'unrepeatdist'dataloader 为一般情况的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()
def check_distributed_sampler(self, sampler):
"""
测试替换得到的 sampler batch_sampler 的分布式设置是否正确
"""
assert sampler.num_replicas == dist.get_world_size()
assert sampler.rank == dist.get_rank()
if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True
print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset)))