Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
yh_cc 2022-04-16 00:33:32 +08:00
commit 048e409233
9 changed files with 513 additions and 422 deletions

View File

@ -98,6 +98,7 @@ class TorchDataLoader(DataLoader):
def __getattr__(self, item):
"""
为FDataLoader提供dataset的方法和属性实现该方法后用户可以在FDataLoader实例化后使用apply等dataset的方法
:param item:
:return:
"""
@ -119,6 +120,7 @@ class TorchDataLoader(DataLoader):
"""
设置每个field_name的padding值默认为0只有当autocollate存在时该方法有效 若没有则会添加auto_collator函数
当val=None时意味着给定的field_names都不需要尝试padding
:param field_names:
:param val: padding值默认为0
:return:

View File

@ -37,7 +37,7 @@ if _NEED_IMPORT_PADDLE:
import paddle
from paddle import DataParallel
import paddle.distributed.fleet as fleet
import paddle.distributed as dist
import paddle.distributed as paddledist
from paddle.io import BatchSampler
from paddle.optimizer import Optimizer
from paddle.fluid.reader import _DatasetKind
@ -185,8 +185,8 @@ class PaddleFleetDriver(PaddleDriver):
if sorted(pre_gpus) != sorted(self.parallel_device):
raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not"
"allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.")
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
self.world_size = paddledist.get_world_size()
self.global_rank = paddledist.get_rank()
if not self.outside_fleet:
# self.model.to(self.model_device)
@ -197,12 +197,12 @@ class PaddleFleetDriver(PaddleDriver):
# 初始化 self._pids从而使得每一个进程都能接受到 rank0 的 send 操作;
# TODO 不用.to会怎么样
self._pids = []
dist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
paddledist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
# TODO LOCAL_WORLD_SIZE
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None
if local_world_size is None:
local_world_size = paddle.to_tensor(self.local_rank, dtype="int32")
dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
paddledist.all_reduce(local_world_size, op=paddledist.ReduceOp.MAX)
local_world_size = local_world_size.item() + 1
node_rank = self.global_rank // local_world_size
@ -232,11 +232,11 @@ class PaddleFleetDriver(PaddleDriver):
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时我们需要
根据 paddle 设置的环境变量来获得各种属性
"""
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
self.world_size = paddledist.get_world_size()
self.global_rank = paddledist.get_rank()
def barrier(self):
dist.barrier()
paddledist.barrier()
def configure_fleet(self):
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):

View File

@ -28,7 +28,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
"""
if is_in_paddle_launch_dist():
if device is not None:
logger.warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
"up your script. And we will directly get the local device via "
"and `os.environ['CUDA_VISIBLE_DEVICES']``.")
device = [int(g) for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]

View File

@ -255,11 +255,14 @@ class TorchDriver(Driver):
logger.debug("Load model...")
# 3. 加载fp16的状态
if 'grad_scaler_state_dict' in states:
grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
if not isinstance(self.grad_scaler, DummyGradScaler):
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
if "grad_scaler_state_dict" in states:
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
if isinstance(self.grad_scaler, DummyGradScaler):
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
self.grad_scaler = _grad_scaler()
self.fp16 = True
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")

View File

@ -14,11 +14,13 @@ if _NEED_IMPORT_PADDLE:
import paddle.distributed as dist
from paddle.fluid.dygraph import parallel_helper
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List:
    """
    Gather ``result`` from every process of ``group`` into a list.

    One zero-filled buffer shaped like ``result`` is allocated per rank and
    handed to ``dist.all_gather`` to be filled.

    :param result: the local tensor to share
    :param group: the communication group forwarded to ``dist.all_gather``
    :param world_size: number of buffers to allocate
    :return: the list of buffers after the ``all_gather`` call
    """
    buffers = []
    for _ in range(world_size):
        buffers.append(paddle.zeros_like(result))
    dist.all_gather(buffers, result, group)
    return buffers
class PaddleBackend(Backend):
def __init__(self):
super().__init__()
@ -124,4 +126,3 @@ class PaddleBackend(Backend):
# TODO 如果在这里处理的话会不会在别的地方引起bug
device = get_device_from_visible(device)
return paddle_to(tensor, device)

View File

@ -32,7 +32,7 @@ class TorchBackend(Backend):
if dist.is_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = fastnlp_torch_all_gather(tensor)
tensor = self.all_gather_object(tensor)
if isinstance(tensor[0], torch.Tensor):
tensor = torch.stack(tensor)
# 第一步, aggregate结果

View File

@ -21,11 +21,12 @@ class TestFdl:
ds.set_pad_val("x", val=-1)
fdl = TorchDataLoader(ds, batch_size=3)
fdl.set_input("x", "y")
fdl.set_pad_val("x", val=None)
for batch in fdl:
print(batch)
fdl.set_pad_val("x", val=-2)
for batch in fdl:
print(batch)
# fdl.set_pad_val("x", val=-2)
# for batch in fdl:
# print(batch)
def test_add_collator(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
@ -38,6 +39,7 @@ class TestFdl:
fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True)
fdl.set_input("x", "y")
# fdl.set_pad_val("x", val=None)
fdl.add_collator(collate_fn)
for batch in fdl:
print(batch)

View File

@ -1,3 +1,4 @@
from dataclasses import replace
import os
from re import S
os.environ["FASTNLP_BACKEND"] = "paddle"
@ -16,413 +17,12 @@ import paddle
from paddle.io import DataLoader, BatchSampler
import torch
############################################################################
#
# 测试save和load相关的功能
# 测试基类 PaddleDriver 中的一些简单函数
#
############################################################################
def generate_random_driver(features, labels):
    """
    Build a ``PaddleSingleDriver`` on CPU that is ready for use.

    A ``PaddleNormalModel_Classification_1`` is created from ``labels`` and
    ``features``, an Adam optimizer (learning rate 0.01) over its parameters
    is registered on the driver, and ``setup()`` is called before returning.

    :param features: passed as the second positional argument of the model
    :param labels: passed as the first positional argument of the model
    :return: the configured driver
    """
    net = PaddleNormalModel_Classification_1(labels, features)
    adam = paddle.optimizer.Adam(parameters=net.parameters(), learning_rate=0.01)
    single_driver = PaddleSingleDriver(net, device="cpu")
    single_driver.set_optimizers(adam)
    single_driver.setup()

    return single_driver
@pytest.fixture
def prepare_test_save_load():
    """
    Fixture providing two independently built drivers and a dataloader
    (batch size 32) over ``PaddleRandomMaxDataset(320, 10)``.

    :return: ``(driver1, driver2, dataloader)``
    """
    loader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=32)
    first_driver = generate_random_driver(10, 10)
    second_driver = generate_random_driver(10, 10)
    return first_driver, second_driver, loader
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict):
    """
    Round-trip ``save``/``load`` with a dataloader whose ``batch_sampler`` is
    a ``RandomBatchSampler``, reloading into a loader with a different batch
    size.
    """
    path = "model.ckp"
    try:
        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
        dataset = PaddleRandomMaxDataset(80, 10)
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False),
        )
        num_consumed_batches = 2

        # TODO: iterate a few more times here once checkpoint resuming is finished
        seen = set()
        for batch_no, batch in enumerate(dataloader):
            if batch_no >= num_consumed_batches:
                break
            seen.update(batch)

        sampler_states = dataloader.batch_sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        # Saving the whole model (not just the state dict) additionally needs an input spec.
        extra_save_kwargs = {} if only_state_dict else {"input_spec": [paddle.ones((16, 10))]}
        driver1.save(
            Path(path), save_states, dataloader, only_state_dict,
            should_save_model=True, **extra_save_kwargs
        )

        # Load, this time with a different batch_size.
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False),
        )
        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
        replaced_loader = load_states.pop("dataloader")

        # 1. Check the optimizer state
        # TODO: the optimizer's state_dict is always empty

        # 2. Check that the batch_sampler was loaded and replaced correctly
        restored_sampler = replaced_loader.batch_sampler
        assert isinstance(restored_sampler, RandomBatchSampler)
        assert restored_sampler.index_list == sampler_states["index_list"]
        assert restored_sampler.data_idx == sampler_states["data_idx"]

        # 3. Check that the model parameters were loaded correctly
        for batch in dataloader:
            out1 = driver1.model.evaluate_step(**batch)
            out2 = driver2.model.evaluate_step(**batch)

            assert paddle.equal_all(out1["pred"], out2["pred"])

        # 4. Check batch_idx
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches
        remaining = set()
        for batch in replaced_loader:
            remaining.update(batch)

        assert len(remaining) + len(seen) == len(dataset)
        assert len(remaining | seen) == len(dataset)
    finally:
        synchronize_safe_rm(path)
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict):
    """
    Round-trip ``save``/``load`` with a dataloader whose
    ``batch_sampler.sampler`` is a ``RandomSampler``.
    """
    try:
        path = "model.ckp"

        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
        dataset = PaddleRandomMaxDataset(80, 10)
        batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
        batch_sampler.sampler = RandomSampler(dataset, True)
        dataloader = DataLoader(
            dataset,
            batch_sampler=batch_sampler
        )
        num_consumed_batches = 2

        # TODO: iterate a few more times here once checkpoint resuming is finished
        already_seen_set = set()
        for idx, batch in enumerate(dataloader):
            if idx >= num_consumed_batches:
                break
            already_seen_set.update(batch)

        sampler_states = dataloader.batch_sampler.sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        if only_state_dict:
            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
        else:
            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])

        # Load into a freshly built dataloader.
        # NOTE(review): this loader uses a RandomBatchSampler, while the
        # assertions below inspect ``batch_sampler.sampler`` - confirm that
        # this is the intended setup for the RandomSampler case.
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
        )
        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
        # The loader is popped exactly once; looking the key up again
        # afterwards would raise KeyError.
        replaced_loader = load_states.pop("dataloader")

        # 1. Check the optimizer state
        # TODO: the optimizer's state_dict is always empty

        # 2. Check that the sampler was loaded and replaced correctly
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
        assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
        assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"]
        assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
        assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]

        # 3. Check that the model parameters were loaded correctly
        for batch in dataloader:
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])

        # 4. Check batch_idx
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches
        left_batches = set()
        for batch in replaced_loader:
            left_batches.update(batch)

        assert len(left_batches) + len(already_seen_set) == len(dataset)
        assert len(left_batches | already_seen_set) == len(dataset)
    finally:
        synchronize_safe_rm(path)
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
    """
    Test the ``save_model`` and ``load_model`` functions: after driver1 saves
    and driver2 loads, both models must produce equal predictions.
    """
    path = "model"
    try:
        driver1, driver2, dataloader = prepare_test_save_load

        # Saving the whole model (not just the state dict) additionally needs an input spec.
        extra_save_kwargs = {} if only_state_dict else {"input_spec": [paddle.ones((32, 10))]}
        driver1.save_model(path, only_state_dict, **extra_save_kwargs)
        driver2.load_model(path, only_state_dict)

        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            out1 = driver1.model.evaluate_step(**batch)
            out2 = driver2.model.evaluate_step(**batch)

            assert paddle.equal_all(out1["pred"], out2["pred"])
    finally:
        if only_state_dict:
            synchronize_safe_rm(path)
        else:
            for suffix in (".pdiparams", ".pdiparams.info", ".pdmodel"):
                synchronize_safe_rm(path + suffix)
class TestSingleDeviceFunction:
    """
    Test cases for the remaining simple driver functions.
    """

    @classmethod
    def setup_class(cls):
        """Create one CPU ``PaddleSingleDriver`` shared by every test in the class."""
        model = PaddleNormalModel_Classification_1(10, 784)
        cls.driver = PaddleSingleDriver(model, device="cpu")

    def test_unwrap_model(self):
        """
        Check that ``unwrap_model`` runs and returns the driver's own model.
        """
        res = self.driver.unwrap_model()
        assert res is self.driver.model

    def test_is_distributed(self):
        """A single-device driver must not report itself as distributed."""
        assert not self.driver.is_distributed()

    def test_move_data_to_device(self):
        """
        This function only calls ``paddle_move_data_to_device``, whose test
        cases live in tests/core/utils/test_paddle_utils.py, so it is only
        smoke-tested here.
        """
        self.driver.move_data_to_device(paddle.rand((32, 64)))
class TestSetDistReproDataloder:
    """
    Tests dedicated to the ``set_dist_repro_dataloader`` function.
    """

    def setup_method(self):
        """Build a fresh dataset and CPU driver before every test."""
        self.dataset = PaddleNormalDataset(20)
        model = PaddleNormalModel_Classification_1(10, 32)
        self.driver = PaddleSingleDriver(model, device="cpu")

    def test_set_dist_repro_dataloader_with_reproducible_false(self):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``reproducible`` is False.
        When ``dist`` is a string, the original dataloader should be returned.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert replaced_loader is dataloader

    @pytest.mark.parametrize("shuffle", [True, False])
    def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``reproducible`` is True.
        When ``dist`` is a string, a new dataloader should be returned. If the
        original sampler is ``paddle.io.RandomSampler`` (shuffle=True), only the
        sampler is replaced with ``RandomSampler``; otherwise the batch_sampler
        is replaced with ``RandomBatchSampler``.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

        assert not (replaced_loader is dataloader)
        if shuffle:
            # In this case the sampler is replaced
            assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
            assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        else:
            # In this case the batch_sampler is replaced
            assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
            assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``dist`` is not a string
        but a ReproducibleBatchSampler. A new dataloader should be returned whose
        batch_sampler is the sampler passed as ``dist``.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert replaced_loader.batch_sampler is dist

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``dist`` is not a string.
        A new dataloader should be returned whose ``batch_sampler.sampler`` is
        the sampler passed as ``dist``.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
        dist = RandomSampler(self.dataset, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, BatchSampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert replaced_loader.batch_sampler.sampler is dist
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when the dataloader already
        supports checkpoint resuming (via ``RandomBatchSampler``). A new
        dataloader should be returned with all other settings unchanged.
        """
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=RandomBatchSampler(
                BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
                batch_size=4,
                drop_last=False,
            )
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when the dataloader already
        supports checkpoint resuming (via ``RandomSampler``). A new dataloader
        should be returned with all other settings unchanged.
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert not (replaced_loader is dataloader)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
        assert replaced_loader.batch_sampler.batch_size == 2
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
        """
        Check that the result of ``set_dist_repro_dataloader`` on a single
        device is correct: consume two batches, snapshot the sampler state,
        restore it into a rebuilt loader, and verify that the rebuilt loader
        yields exactly the samples not seen yet.

        :param dataloader: the original dataloader (not used by the check itself)
        :param replaced_loader: the loader returned by ``set_dist_repro_dataloader``
        :param shuffle: shuffle flag used when rebuilding the loader
        """
        # Iterate over two batches
        num_consumed_batches = 2
        already_seen_idx = set()
        for idx, batch in enumerate(replaced_loader):
            if idx >= num_consumed_batches:
                break
            already_seen_idx.update(batch)
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            sampler_states = replaced_loader.batch_sampler.state_dict()
        else:
            sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

        # Use num_consumed_samples_array (when present) to set the number of
        # samples that were really consumed
        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

        # After reloading, the remaining content should be produced; for
        # PaddleNormalDataset the sorted result should be a range
        left_idxes = set()
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            batch_size = replaced_loader.batch_sampler.batch_size
            if num_consumed_samples_array is not None:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
            else:
                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
            # Rebuild the dataloader
            new_loader = DataLoader(
                dataset=replaced_loader.dataset,
                batch_sampler=RandomBatchSampler(
                    BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
                    batch_size=batch_size,
                    drop_last=False,
                )
            )
            new_loader.batch_sampler.load_state_dict(sampler_states)
        else:
            batch_size = replaced_loader.batch_sampler.batch_size
            # In this branch the array is indexed by sample count. Keep the
            # count in its own variable so the fallback does not multiply by
            # batch_size a second time.
            num_consumed_samples = num_consumed_batches * batch_size
            if num_consumed_samples_array is not None:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
            else:
                sampler_states["num_consumed_samples"] = num_consumed_samples
            # Rebuild the dataloader
            batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
            batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
            new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
            new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
        for batch in new_loader:
            left_idxes.update(batch)

        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
        assert len(left_idxes | already_seen_idx) == len(self.dataset)
class TestPaddleDriverFunctions:
"""
使用 PaddleSingleDriver 测试基类的函数
@ -706,4 +306,428 @@ class TestPaddleDriverFunctions:
assert isinstance(res.sampler, RandomSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last
assert res.drop_last == drop_last
############################################################################
#
# 测试 PaddleSingleDriver 中的一些简单函数
#
############################################################################
class TestSingleDeviceFunction:
    """
    Test cases for the remaining simple driver functions.
    """

    @classmethod
    def setup_class(cls):
        """Create one CPU ``PaddleSingleDriver`` shared by every test in the class."""
        model = PaddleNormalModel_Classification_1(10, 784)
        cls.driver = PaddleSingleDriver(model, device="cpu")

    def test_unwrap_model(self):
        """
        Check that ``unwrap_model`` runs and returns the driver's own model.
        """
        res = self.driver.unwrap_model()
        assert res is self.driver.model

    def test_is_distributed(self):
        """A single-device driver must not report itself as distributed."""
        assert not self.driver.is_distributed()

    def test_move_data_to_device(self):
        """
        This function only calls ``paddle_move_data_to_device``, whose test
        cases live in tests/core/utils/test_paddle_utils.py, so it is only
        smoke-tested here.
        """
        self.driver.move_data_to_device(paddle.rand((32, 64)))
############################################################################
#
# 测试 set_dist_repro_dataloader 函数
#
############################################################################
class TestSetDistReproDataloder:
    """
    Tests dedicated to the ``set_dist_repro_dataloader`` function.
    """

    def setup_method(self):
        """Build a fresh dataset and CPU driver before every test."""
        self.dataset = PaddleNormalDataset(20)
        model = PaddleNormalModel_Classification_1(10, 32)
        self.driver = PaddleSingleDriver(model, device="cpu")

    def test_set_dist_repro_dataloader_with_reproducible_false(self):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``reproducible`` is False.
        When ``dist`` is a string, the original dataloader should be returned.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert replaced_loader is dataloader

    @pytest.mark.parametrize("shuffle", [True, False])
    def test_set_dist_repro_dataloader_with_reproducible_true(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``reproducible`` is True.
        When ``dist`` is a string, a new dataloader should be returned. If the
        original sampler is ``paddle.io.RandomSampler`` (shuffle=True), only the
        sampler is replaced with ``RandomSampler``; otherwise the batch_sampler
        is replaced with ``RandomBatchSampler``.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

        assert not (replaced_loader is dataloader)
        if shuffle:
            # In this case the sampler is replaced
            assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler)
            assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
            assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        else:
            # In this case the batch_sampler is replaced
            assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
            assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``dist`` is not a string
        but a ReproducibleBatchSampler. A new dataloader should be returned whose
        batch_sampler is the sampler passed as ``dist``.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert replaced_loader.batch_sampler is dist

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``dist`` is not a string.
        A new dataloader should be returned whose ``batch_sampler.sampler`` is
        the sampler passed as ``dist``.
        """
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
        dist = RandomSampler(self.dataset, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, BatchSampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert replaced_loader.batch_sampler.sampler is dist
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when the dataloader already
        supports checkpoint resuming (via ``RandomBatchSampler``). A new
        dataloader should be returned with all other settings unchanged.
        """
        dataloader = DataLoader(
            dataset=self.dataset,
            batch_sampler=RandomBatchSampler(
                BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
                batch_size=4,
                drop_last=False,
            )
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
        """
        Behaviour of ``set_dist_repro_dataloader`` when the dataloader already
        supports checkpoint resuming (via ``RandomSampler``). A new dataloader
        should be returned with all other settings unchanged.
        """
        batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
        batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
        dataloader = DataLoader(
            self.dataset,
            batch_sampler=batch_sampler
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert not (replaced_loader is dataloader)
        assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
        assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
        assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
        assert replaced_loader.batch_sampler.batch_size == 2
        assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
        """
        Check that the result of ``set_dist_repro_dataloader`` on a single
        device is correct: consume two batches, snapshot the sampler state,
        restore it into a rebuilt loader, and verify that the rebuilt loader
        yields exactly the samples not seen yet.

        :param dataloader: the original dataloader (not used by the check itself)
        :param replaced_loader: the loader returned by ``set_dist_repro_dataloader``
        :param shuffle: shuffle flag used when rebuilding the loader
        """
        # Iterate over two batches
        num_consumed_batches = 2
        already_seen_idx = set()
        for idx, batch in enumerate(replaced_loader):
            if idx >= num_consumed_batches:
                break
            already_seen_idx.update(batch)
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            sampler_states = replaced_loader.batch_sampler.state_dict()
        else:
            sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

        # Use num_consumed_samples_array (when present) to set the number of
        # samples that were really consumed
        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

        # After reloading, the remaining content should be produced; for
        # PaddleNormalDataset the sorted result should be a range
        left_idxes = set()
        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
            batch_size = replaced_loader.batch_sampler.batch_size
            if num_consumed_samples_array is not None:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
            else:
                sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
            # Rebuild the dataloader
            new_loader = DataLoader(
                dataset=replaced_loader.dataset,
                batch_sampler=RandomBatchSampler(
                    BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
                    batch_size=batch_size,
                    drop_last=False,
                )
            )
            new_loader.batch_sampler.load_state_dict(sampler_states)
        else:
            batch_size = replaced_loader.batch_sampler.batch_size
            # In this branch the array is indexed by sample count. Keep the
            # count in its own variable so the fallback does not multiply by
            # batch_size a second time.
            num_consumed_samples = num_consumed_batches * batch_size
            if num_consumed_samples_array is not None:
                sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_samples]
            else:
                sampler_states["num_consumed_samples"] = num_consumed_samples
            # Rebuild the dataloader
            batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
            batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
            new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
            new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
        for batch in new_loader:
            left_idxes.update(batch)

        assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
        assert len(left_idxes | already_seen_idx) == len(self.dataset)
############################################################################
#
# 测试 save 和 load 相关的功能
#
############################################################################
def generate_random_driver(features, labels):
    """
    Build a ``PaddleSingleDriver`` on CPU that is ready for use.

    A ``PaddleNormalModel_Classification_1`` is created from ``labels`` and
    ``features``, an Adam optimizer (learning rate 0.01) over its parameters
    is registered on the driver, and ``setup()`` is called before returning.

    :param features: passed as the second positional argument of the model
    :param labels: passed as the first positional argument of the model
    :return: the configured driver
    """
    net = PaddleNormalModel_Classification_1(labels, features)
    adam = paddle.optimizer.Adam(parameters=net.parameters(), learning_rate=0.01)
    single_driver = PaddleSingleDriver(net, device="cpu")
    single_driver.set_optimizers(adam)
    single_driver.setup()

    return single_driver
@pytest.fixture
def prepare_test_save_load():
    """
    Fixture providing two independently built drivers and a dataloader
    (batch size 32) over ``PaddleRandomMaxDataset(320, 10)``.

    :return: ``(driver1, driver2, dataloader)``
    """
    loader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=32)
    first_driver = generate_random_driver(10, 10)
    second_driver = generate_random_driver(10, 10)
    return first_driver, second_driver, loader
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
    """
    Test the ``save_model`` and ``load_model`` functions: after driver1 saves
    and driver2 loads, both models must produce equal predictions.
    """
    path = "model"
    try:
        driver1, driver2, dataloader = prepare_test_save_load

        # Saving the whole model (not just the state dict) additionally needs an input spec.
        extra_save_kwargs = {} if only_state_dict else {"input_spec": [paddle.ones((32, 10))]}
        driver1.save_model(path, only_state_dict, **extra_save_kwargs)
        driver2.load_model(path, only_state_dict)

        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            out1 = driver1.model.evaluate_step(**batch)
            out2 = driver2.model.evaluate_step(**batch)

            assert paddle.equal_all(out1["pred"], out2["pred"])
    finally:
        if only_state_dict:
            synchronize_safe_rm(path)
        else:
            for suffix in (".pdiparams", ".pdiparams.info", ".pdmodel"):
                synchronize_safe_rm(path + suffix)
@pytest.mark.parametrize("only_state_dict", [True, False])
def test_save_and_load_with_randombatchsampler(only_state_dict):
    """
    Test ``save`` / ``load`` when the dataloader's batch_sampler is a
    ``RandomBatchSampler``.

    Two batches of size 4 are consumed, a checkpoint is saved with ``driver1``,
    and ``driver2`` loads it into a dataloader whose batch size is 2. The test
    then checks the restored batch_sampler state, the reported
    ``batch_idx_in_epoch``, that both models agree on the remaining batches, and
    that seen plus remaining items add up to the dataset size.

    :param only_state_dict: forwarded to ``save`` / ``load``; when it is false an
        ``input_spec`` is additionally passed to ``save``
    """
    try:
        path = "model.ckp"
        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
        dataset = PaddleRandomMaxDataset(40, 10)
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False),
        )

        num_consumed_batches = 2
        seen_x, seen_y = set(), set()
        # NOTE: the loop only breaks after a third batch has been requested from
        # the loader; keep this shape so the sampler is driven exactly as before.
        for idx, batch in enumerate(dataloader):
            if idx >= num_consumed_batches:
                break
            # NOTE(review): elements of batch["x"] / batch["y"] go into the sets
            # as-is; if they are tensors hashed by identity the union checks at
            # the end cannot detect duplicates -- confirm.
            seen_x.update(batch["x"])
            seen_y.update(batch["y"])

        sampler_states = dataloader.batch_sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        # The extra ``input_spec`` argument is only supplied for the full-model save.
        save_kwargs = {} if only_state_dict else {"input_spec": [paddle.ones((16, 10))]}
        driver1.save(
            Path(path), save_states, dataloader, only_state_dict,
            should_save_model=True, **save_kwargs
        )

        # Load into a dataloader that uses a different batch_size.
        dataloader = DataLoader(
            dataset=dataset,
            batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False),
        )
        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
        replaced_loader = load_states.pop("dataloader")

        # 1. Check the optimizer state.
        # TODO the optimizer's state_dict is always empty

        # 2. Check that the batch_sampler was loaded and replaced correctly.
        assert replaced_loader is not dataloader
        restored_sampler = replaced_loader.batch_sampler
        assert restored_sampler is dataloader.batch_sampler
        assert isinstance(restored_sampler, RandomBatchSampler)
        assert restored_sampler.index_list == sampler_states["index_list"]
        assert restored_sampler.num_consumed_samples == num_consumed_batches * 4

        # 3. Check the model parameters (done per batch in the loop below).
        # 4. Check batch_idx.
        # The batch size went from 4 to 2, so presumably the same number of
        # consumed samples maps to twice as many batches.
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches

        left_x, left_y = set(), set()
        for batch in replaced_loader:
            left_x.update(batch["x"])
            left_y.update(batch["y"])
            expected = driver1.model.evaluate_step(**batch)
            actual = driver2.model.evaluate_step(**batch)
            assert paddle.equal_all(expected["pred"], actual["pred"])

        # Seen and remaining items must add up to the whole dataset.
        for remaining, seen in ((left_x, seen_x), (left_y, seen_y)):
            assert len(remaining) + len(seen) == len(dataset)
            assert len(remaining | seen) == len(dataset)
    finally:
        synchronize_safe_rm(path)
@pytest.mark.parametrize("only_state_dict", [True, False])
def test_save_and_load_with_randomsampler(only_state_dict):
    """
    Test ``save`` / ``load`` when the ``sampler`` of the dataloader's
    batch_sampler is a ``RandomSampler``.

    Two batches of size 4 are consumed, a checkpoint is saved with ``driver1``,
    and ``driver2`` loads it into a dataloader whose batch size is 2. The test
    then checks the restored sampler state, the reported ``batch_idx_in_epoch``,
    that both models agree on the remaining batches, and that seen plus
    remaining items add up to the dataset size.

    :param only_state_dict: forwarded to ``save`` / ``load``; when it is false an
        ``input_spec`` is additionally passed to ``save``
    """
    try:
        path = "model.ckp"
        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
        dataset = PaddleRandomMaxDataset(40, 10)

        def build_loader(batch_size):
            # A plain BatchSampler whose sampler is swapped for a RandomSampler.
            wrapper = BatchSampler(dataset=dataset, batch_size=batch_size)
            wrapper.sampler = RandomSampler(dataset, True)
            return DataLoader(dataset, batch_sampler=wrapper)

        dataloader = build_loader(4)

        num_consumed_batches = 2
        seen_x, seen_y = set(), set()
        # NOTE: the loop only breaks after a third batch has been requested from
        # the loader; keep this shape so the sampler is driven exactly as before.
        for idx, batch in enumerate(dataloader):
            if idx >= num_consumed_batches:
                break
            # NOTE(review): elements of batch["x"] / batch["y"] go into the sets
            # as-is; if they are tensors hashed by identity the union checks at
            # the end cannot detect duplicates -- confirm.
            seen_x.update(batch["x"])
            seen_y.update(batch["y"])

        sampler_states = dataloader.batch_sampler.sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        # The extra ``input_spec`` argument is only supplied for the full-model save.
        save_kwargs = {} if only_state_dict else {"input_spec": [paddle.ones((16, 10))]}
        driver1.save(
            Path(path), save_states, dataloader, only_state_dict,
            should_save_model=True, **save_kwargs
        )

        # Load into a dataloader that uses a different batch_size.
        dataloader = build_loader(2)
        load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
        replaced_loader = load_states.pop("dataloader")

        # 1. Check the optimizer state.
        # TODO the optimizer's state_dict is always empty

        # 2. Check that the sampler was loaded and replaced correctly.
        assert replaced_loader is not dataloader
        restored = replaced_loader.batch_sampler.sampler
        assert isinstance(restored, RandomSampler)
        assert restored.seed == sampler_states["seed"]
        assert restored.epoch == sampler_states["epoch"]
        assert restored.num_consumed_samples == 4 * num_consumed_batches
        assert len(restored.dataset) == sampler_states["length"]
        assert restored.shuffle == sampler_states["shuffle"]

        # 3. Check the model parameters (done per batch in the loop below).
        # 4. Check batch_idx.
        # The batch size went from 4 to 2, so presumably the same number of
        # consumed samples maps to twice as many batches.
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches

        left_x, left_y = set(), set()
        for batch in replaced_loader:
            left_x.update(batch["x"])
            left_y.update(batch["y"])
            expected = driver1.model.evaluate_step(**batch)
            actual = driver2.model.evaluate_step(**batch)
            assert paddle.equal_all(expected["pred"], actual["pred"])

        # Seen and remaining items must add up to the whole dataset.
        for remaining, seen in ((left_x, seen_x), (left_y, seen_y)):
            assert len(remaining) + len(seen) == len(dataset)
            assert len(remaining | seen) == len(dataset)
    finally:
        synchronize_safe_rm(path)

View File

@ -0,0 +1,59 @@
import os
import pytest
import paddle
import paddle.distributed
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.distributed.fleet as fleet
from fastNLP.core.metrics import Accuracy
from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher
############################################################################
#
# 测试 单机单卡情况下的Accuracy
#
############################################################################
def test_accuracy_single():
    """
    Test ``Accuracy`` on a single fixed batch (single machine, single device).

    The batch has 10 rows of 6 scores each. The largest score sits at the target
    index in 3 of the 10 rows, which matches the expected ``{'acc': 0.3}``.
    """
    scores = [
        [1.19812393, -0.82041764, -0.53517765, -0.73061031, -1.45006669, 0.46514302],
        [-0.85775983, -2.18273783, -1.07505429, -1.45561373, 0.40011844, 1.02202022],
        [-0.39487389, 0.65682763, -0.62424040, 0.53692561, -0.28390560, -0.02559055],
        [-0.22586937, -0.07676325, -0.95977223, 0.36395910, -0.91758579, -0.83857095],
        [0.25136873, 2.49652624, 1.06251311, 1.60194016, 1.01451588, 0.08403367],
        [0.10844281, 1.19017303, -0.11378096, 1.12686944, -0.08654942, 0.48605862],
        [1.27320433, -1.13902378, 1.47072780, -0.98665696, -0.42589864, 0.64618838],
        [0.83809763, -0.05356205, 0.03042423, -0.28371972, 0.81611472, -0.45802942],
        [0.38535264, 0.09721313, 2.27187467, 0.32045507, -0.20711982, -0.13550705],
        [-0.75228405, -1.34161997, 1.08697927, 0.33218071, -1.19470012, 2.58735061],
    ]
    targets = [1, 2, 1, 3, 5, 4, 4, 2, 1, 5]

    pred = paddle.to_tensor(scores)
    tg = paddle.to_tensor(targets)

    metric = Accuracy()
    metric.update(pred, tg)
    computed = metric.get_metric()

    expected = {'acc': 0.3}
    assert expected == computed
############################################################################
#
# 测试 单机多卡情况下的Accuracy
#
############################################################################
def test_accuracy_ddp():
    """
    Launch a fleet environment on devices 0 and 1 and initialise it.

    After ``FleetLauncher.launch()`` and ``fleet.init`` with a collective
    ``PaddleCloudRoleMaker``, a process that is not a server but is a worker
    prints its ``PADDLE_TRAINER_ID`` environment variable. No assertion is made.
    """
    fleet_launcher = FleetLauncher(devices=[0, 1])
    fleet_launcher.launch()

    cloud_role = role_maker.PaddleCloudRoleMaker(is_collective=True)
    fleet.init(cloud_role)

    # Servers do nothing here; only workers report their trainer id.
    if not fleet.is_server() and fleet.is_worker():
        print(os.environ.get("PADDLE_TRAINER_ID"))