1. 完成JittorSingleDriver的功能,并添加测试用例 2.在Sampler中添加属性num_samples 用于动态获取dataset的长度 3.添加便于测试断点重训的数据集 4.修改jittor其它测试的一些bug,统一ArgMaxDataset 的命名

This commit is contained in:
x54-729 2022-05-29 01:32:16 +00:00
parent 49e8ae2daa
commit 75a3278d69
33 changed files with 1381 additions and 451 deletions

View File

@ -24,7 +24,6 @@ from fastNLP.core.dataset import DataSet as FDataSet
class _JittorDataset(Dataset):
"""
对用户传的dataset进行封装以便JittorDataLoader能够支持使用自定义的dataset
"""
def __init__(self, dataset) -> None:
@ -83,7 +82,7 @@ class JittorDataLoader:
# TODO 验证支持replacesampler (以后完成) 增加Sampler
# 将内部dataset批次设置为1
if isinstance(dataset, Dataset):
dataset.set_attrs(batch_size=1)
dataset.set_attrs(batch_size=1, shuffle=False, endless=False)
# FastNLP Dataset, collate_fn not None
if isinstance(dataset, FDataSet) and collate_fn is None:
@ -115,6 +114,12 @@ class JittorDataLoader:
self.cur_batch_indices = None
def __getattr__(self, attr):
    """
    Delegate dataloader-configuration attributes to the wrapped dataset.

    Only invoked when normal attribute lookup fails, so real instance
    attributes (e.g. ``self.dataset``) are unaffected.

    :param attr: attribute name being looked up.
    :return: the corresponding attribute of ``self.dataset``.
    :raises AttributeError: for any name outside the whitelisted set.
    """
    # set membership is O(1); the whitelist mirrors the dataset's set_attrs keys
    if attr in {"batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad",
                "keep_numpy_array", "endless", "sampler"}:
        return getattr(self.dataset, attr)
    # bug fix: message read "has not attribute"
    raise AttributeError(f"{self} has no attribute '{attr}'")
def __iter__(self):
# TODO 第一次迭代后不能设置collate_fn设置是无效的
if self.cur_batch_indices is None:

View File

@ -10,7 +10,7 @@ if _NEED_IMPORT_JITTOR:
__all__ = []
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver:
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver:
r"""
用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去
@ -30,7 +30,7 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo
raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
# TODO 实现更详细的判断
if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]:
if device in ["cpu", "gpu", "cuda", None]:
return JittorSingleDriver(model, device, **kwargs)
elif type(device) is int:
return JittorMPIDriver(model, device, **kwargs)

View File

@ -1,23 +1,31 @@
import os
import random
from pathlib import Path
from typing import Union, Optional
from functools import partial
import numpy as np
from typing import Union, Optional, Dict
from contextlib import nullcontext
from dataclasses import dataclass
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
from fastNLP.core.log import logger
from fastNLP.core.utils import apply_to_collection
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS
from fastNLP.envs import (
FASTNLP_MODEL_FILENAME,
FASTNLP_CHECKPOINT_FILENAME,
)
if _NEED_IMPORT_JITTOR:
import jittor as jt
from jittor import Module
from jittor.optim import Optimizer
from jittor.dataset import Dataset
from jittor.dataset import (
BatchSampler as JittorBatchSampler,
Sampler as JittorSampler,
RandomSampler as JittorRandomSampler,
SequentialSampler as JittorSequentialSampler
)
_reduces = {
'max': jt.max,
@ -56,6 +64,7 @@ class JittorDriver(Driver):
else:
jt.flags.auto_mixed_precision_level = 0
self.fp16 = fp16
self._auto_cast = nullcontext
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
@ -68,7 +77,7 @@ class JittorDriver(Driver):
def _check_optimizer_legality(optimizers):
for each_optimizer in optimizers:
if not isinstance(each_optimizer, Optimizer):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
f"not {type(each_optimizer)}.")
def step(self):
@ -117,30 +126,118 @@ class JittorDriver(Driver):
model = self.unwrap_model()
model.load(filepath)
def save_checkpoint(self):
...
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
    """
    Persist a full training checkpoint into ``folder``.

    Captures the sampler state (so data iteration can resume mid-epoch),
    optionally the model weights, and the optimizer states; everything except
    the model file is written with ``jt.save`` into ``FASTNLP_CHECKPOINT_FILENAME``.

    :param folder: directory receiving the checkpoint files.
    :param states: trainer-provided state dict; must contain ``num_consumed_batches``.
    :param dataloader: dataloader whose sampler state is captured.
    :param only_state_dict: forwarded to ``save_model``.
    :param should_save_model: whether to also dump the model into ``folder``.
    :raises RuntimeError: if the dataloader exposes no sampler, or the sampler
        has no callable ``state_dict``.
    """
    # 1. capture the sampler state for checkpoint-resumed training
    dataloader_args = self.get_dataloader_args(dataloader)
    if dataloader_args.sampler:
        sampler = dataloader_args.sampler
    else:
        raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
    num_consumed_samples = states.pop('num_consumed_batches') if False else None  # placeholder comment removed below
    num_consumed_batches = states.pop('num_consumed_batches')
    if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
        sampler_states = sampler.state_dict()
        # ``num_consumed_samples`` needs special handling: the DataLoader prefetches,
        # so the raw counter inside the sampler overstates what was actually consumed.
        num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
        if num_consumed_samples_array is not None:
            if isinstance(sampler, ReproducibleSampler):  # a per-sample sampler: account for batch_size
                if dataloader_args.batch_size is not None:
                    num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
                else:  # batch_size may be None; then we can only save with reduced precision
                    logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
                                             "it may cause missing some samples when reload.")
                    num_consumed_batches = sampler_states['num_consumed_samples']
            sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
            assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
        else:
            if dataloader_args.batch_size is not None:
                sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
                                                         * num_consumed_batches
            else:
                logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
                                         "it may cause missing some samples when reload.")
        states['sampler_states'] = sampler_states
    else:
        raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training '
                           'state.')
    # 2. save the model weights
    if should_save_model:
        if not os.path.exists(folder):
            os.mkdir(folder)
        model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
        self.save_model(model_path, only_state_dict=only_state_dict)
    # 3. save the optimizers' states
    states["optimizers_state_dict"] = self.get_optimizer_state()
    # 4. fp16 state — nothing is saved here for jittor yet (intentional placeholder)
    logger.debug("Save optimizer state dict")
    jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
def get_optimizer_state(self):
    """
    Collect the ``state_dict`` of every optimizer held by this driver.

    :return: dict mapping ``"optimizer{i}"`` to the i-th optimizer's state dict.
    """
    # removed: commented-out torch reference implementation and a stray `...` statement
    optimizers_state_dict = {}
    for i, optimizer in enumerate(self.optimizers):
        # no deepcopy on purpose — per the original note, tests do not require it
        optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict()
    return optimizers_state_dict
def load_optimizer_state(self, states):
    """
    Restore each optimizer's state from ``states`` as produced by
    :meth:`get_optimizer_state`.

    :param states: dict mapping ``"optimizer{i}"`` to a saved state dict;
        must contain exactly one entry per driver optimizer.
    """
    # removed: commented-out torch reference implementation and a stray `...` statement
    assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
                                                f"checkpoint it is:{len(states)}"
    for i, optimizer in enumerate(self.optimizers):
        optimizer.load_state_dict(states[f"optimizer{i}"])
    logger.debug("Load optimizer state dict.")
def load_checkpoint(self):
...
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
    """
    Restore a checkpoint previously written by :meth:`save_checkpoint`.

    :param folder: directory containing ``FASTNLP_CHECKPOINT_FILENAME`` (and
        optionally ``FASTNLP_MODEL_FILENAME``); str is accepted as well.
    :param dataloader: the dataloader whose sampler state should be restored.
    :param only_state_dict: forwarded to ``load_model``.
    :param should_load_model: whether to also reload the model weights.
    :return: the remaining states dict, extended with ``"dataloader"`` and
        ``"batch_idx_in_epoch"``.
    :raises RuntimeError: if the dataloader's sampler type cannot be resumed.
    """
    folder = Path(folder)  # tolerate str callers, matching save_checkpoint
    states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))

    # 1. restore optimizer states
    optimizers_state_dict = states.pop("optimizers_state_dict")
    self.load_optimizer_state(optimizers_state_dict)

    # 2. restore model weights
    if should_load_model:
        self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)

    # 3. fp16 state — nothing is restored for jittor yet (intentional placeholder)

    # 4. restore the sampler state
    dataloader_args = self.get_dataloader_args(dataloader)
    if dataloader_args.sampler is None:
        # bug fix: original read `dataloader_args.sampler.dataset` here although
        # sampler is None on this branch — use the dataset from the args instead
        sampler = RandomSampler(dataloader_args.dataset, shuffle=dataloader_args.shuffle)
    elif isinstance(dataloader_args.sampler, ReproducibleSampler):
        sampler = dataloader_args.sampler
    elif isinstance(dataloader_args.sampler, JittorRandomSampler):
        sampler = RandomSampler(dataloader_args.sampler.dataset)
        logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
    elif isinstance(dataloader_args.sampler, JittorSequentialSampler):
        sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False)
        logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.")
    elif self.is_distributed():
        raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
                           "`ReproducibleSampler`.")
    else:
        raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.")
    sampler.load_state_dict(states.pop('sampler_states'))
    states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)

    # 5. recompute trainer_state.batch_idx_in_epoch (was duplicated as "# 4." in source)
    # `sampler` is a RandomSampler-like per-sample sampler, not a batch_sampler
    if dataloader_args.drop_last:
        batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size \
                             - sampler.num_left_samples // dataloader_args.batch_size
    else:
        batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
                             (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
    states["batch_idx_in_epoch"] = batch_idx_in_epoch
    return states
def get_evaluate_context(self):
return jt.no_grad
@ -198,26 +295,8 @@ class JittorDriver(Driver):
"""
return batch
@staticmethod
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
process_seed = jt.get_seed()
# back out the base seed so we can use all the bits
base_seed = process_seed - worker_id
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
# use 128 bits (4 x 32-bit words)
np.random.seed(ss.generate_state(4))
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
jittor_ss, stdlib_ss = ss.spawn(2)
jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0])
# use 128 bits expressed as an integer
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
random.seed(stdlib_seed)
def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]):
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(self.worker_init_function,
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)))
...
def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int):
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的;
@ -226,4 +305,45 @@ class JittorDriver(Driver):
@staticmethod
def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]):
    """
    Inspect ``dataloader`` and collect its construction arguments into a record.

    :param dataloader: a ``JittorDataLoader`` or a jittor ``Dataset``.
    :return: a ``Res`` dataclass instance with ``dataset``, ``batch_sampler``,
        ``sampler``, ``batch_size``, ``shuffle`` and ``drop_last`` filled in
        where they can be determined.
    :raises NotImplementedError: if the sampler yields batches (jittor has no
        batch_sampler support yet).
    """
    # removed: a stale no-op `pass` stub left in front of the real body

    @dataclass
    class Res:
        dataset: Optional[Dataset] = None
        batch_sampler: Optional[JittorBatchSampler] = None
        sampler: Optional[JittorSampler] = None
        batch_size: Optional[int] = None
        shuffle: Optional[bool] = None
        drop_last: Optional[bool] = None

    res = Res()

    # local import to avoid a circular import at module load time
    from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset
    if isinstance(dataloader, JittorDataLoader):
        # JittorDataLoader actually iterates over its ``dataset`` member
        dataloader = dataloader.dataset
    if isinstance(dataloader, _JittorDataset):
        # unwrap to the original user dataset
        res.dataset = dataloader.dataset
    else:
        res.dataset = dataloader

    # jittor has no batch_sampler, so everything except shuffle is read directly
    res.batch_size = dataloader.batch_size
    res.drop_last = dataloader.drop_last
    if dataloader.sampler is None:
        # sampler is None -> read shuffle from the Dataset's own attributes
        res.shuffle = dataloader.shuffle
    elif isinstance(next(iter(dataloader.sampler)), (list, tuple)):
        # perf fix: peek at the first element only instead of materializing the
        # whole sampler with list(sampler.__iter__())
        raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, "
                                  "please check if you have set `Dataset.sampler` as `BatchSampler`")
    else:
        # a non-None, per-sample sampler
        res.sampler = dataloader.sampler
        if hasattr(dataloader.sampler, "shuffle"):
            # typically one of fastNLP's reproducible samplers
            res.shuffle = dataloader.sampler.shuffle
        elif isinstance(dataloader.sampler, JittorRandomSampler):
            res.shuffle = True
        else:
            res.shuffle = False
    return res

View File

@ -38,6 +38,7 @@ class JittorMPIDriver(JittorDriver):
):
super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs)
raise NotImplementedError("MPI for Jittor is not supported right now.")
self.is_pull_by_jittor_run = is_pull_by_jittor_run
self.parallel_device = parallel_device
@ -100,22 +101,6 @@ class JittorMPIDriver(JittorDriver):
return self._data_device
return self.parallel_device
def step(self):
# for optimizer in self.optimizers:
# self.grad_scaler.step(optimizer)
# self.grad_scaler.update()
for optimizer in self.optimizers:
optimizer.step()
def backward(self, loss):
# self.grad_scaler.scale(loss).backward()
for optimizer in self.optimizers:
optimizer.backward(loss)
def zero_grad(self):
for optimizer in self.optimizers:
optimizer.zero_grad()
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(fn, batch, signature_fn=signature_fn)

View File

@ -1,14 +1,21 @@
from typing import Dict, Union, Tuple, Callable, Optional
from .jittor_driver import JittorDriver
from .utils import replace_batch_sampler, replace_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
ReproduceBatchSampler
from fastNLP.core.samplers import RandomSampler
from fastNLP.core.log import logger
if _NEED_IMPORT_JITTOR:
import jittor as jt
from jittor.dataset import (
RandomSampler as JittorRandomSampler,
SequentialSampler as JittorSequentialSampler,
)
__all__ = [
"JittorSingleDriver",
@ -89,31 +96,46 @@ class JittorSingleDriver(JittorDriver):
"""
return False
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
reproducible: bool = False, sampler_or_batch_sampler=None):
# reproducible 的相关功能暂时没有实现
def set_dist_repro_dataloader(self, dataloader,
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
reproducible: bool = False):
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用;
if isinstance(dist, ReproducibleBatchSampler):
raise NotImplementedError
dataloader.batch_sampler = dist_sample
if isinstance(dist, ReproducibleSampler):
raise NotImplementedError
dataloader.batch_sampler.sampler = dist
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleSampler):
return replace_sampler(dataloader, dist)
# 如果 dist 为 str 或者 None说明是在 trainer 初试化时调用;
args = self.get_dataloader_args(dataloader)
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
batch_sampler = re_instantiate_sampler(args.batch_sampler)
return replace_batch_sampler(dataloader, batch_sampler)
elif isinstance(args.sampler, ReproducibleSampler):
sampler = re_instantiate_sampler(args.sampler)
return replace_sampler(dataloader, sampler)
if reproducible:
raise NotImplementedError
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
return dataloader
elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
return dataloader
else:
# TODO
batch_sampler = RandomBatchSampler(
batch_sampler=dataloader.batch_sampler,
batch_size=dataloader.batch_sampler.batch_size,
drop_last=dataloader.drop_last
)
dataloader.batch_sampler = batch_sampler
return dataloader
if args.sampler is None:
sampler = RandomSampler(args.dataset, args.shuffle)
return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, JittorRandomSampler):
if getattr(args.sampler, '_num_samples', None) is None \
and getattr(args.sampler, 'rep', False) is False:
# 如果本来就是随机的,并且没有定制,直接替换掉吧。
sampler = RandomSampler(args.sampler.dataset, shuffle=True)
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
elif isinstance(args.sampler, JittorSequentialSampler):
# 需要替换为不要 shuffle 的。
sampler = RandomSampler(args.sampler.dataset, shuffle=False)
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last
)
return replace_batch_sampler(dataloader, batch_sampler)
else:
return dataloader

View File

@ -1,6 +1,29 @@
import inspect
from copy import deepcopy
from typing import Union
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
import jittor
from jittor.dataset import Dataset
__all__ = []
def replace_batch_sampler(dataloader, batch_sampler):
    """
    Unsupported for jittor: its ``Dataset`` has no batch_sampler concept.

    :raises NotImplementedError: always.
    """
    # bug fix: the concatenated message pieces were missing a space
    # ("...`BatchSampler`or report...")
    raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, "
                              "please check if you have set `Dataset.sampler` as `BatchSampler` "
                              "or report this bug to us.")
def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler):
    """
    Return a copy of ``dataloader`` whose underlying dataset uses ``sampler``.

    For a ``JittorDataLoader`` the loader is re-constructed from its ``__init__``
    signature, its inner dataset is rebuilt first (keeping that dataset's
    current sampler), and the requested sampler is then installed on the new
    loader's dataset.  For a plain jittor ``Dataset`` a deepcopy is taken so
    the caller's object is not mutated.

    :param dataloader: a jittor ``Dataset`` or a ``JittorDataLoader``.
    :param sampler: the sampler to install on the returned object.
    :return: a new dataloader/dataset; the input is left untouched.
    """
    if isinstance(dataloader, JittorDataLoader):
        # rebuild the loader from its constructor signature, pulling the current
        # attribute values and falling back to each parameter's default
        # NOTE(review): assumes every __init__ parameter is readable as an
        # attribute of the same name — confirm against JittorDataLoader
        init_params = dict(inspect.signature(dataloader.__init__).parameters)
        reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()}
        # recurse on the wrapped dataset first with its *current* sampler ...
        reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler)
        new_dataloader = type(dataloader)(**reconstruct_args)
        # ... then install the requested sampler on the fresh loader's dataset
        new_dataloader.dataset.set_attrs(sampler=sampler)
    else:
        # plain Dataset: copy, then set the sampler on the copy via set_attrs
        new_dataloader = deepcopy(dataloader)
        new_dataloader.set_attrs(sampler=sampler)
    return new_dataloader

View File

@ -31,7 +31,6 @@ if _NEED_IMPORT_PADDLE:
import paddle
from paddle.io import (
DataLoader,
IterableDataset,
Dataset,
Sampler,
BatchSampler,
@ -97,6 +96,9 @@ class PaddleDriver(Driver):
def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
if dataloader.batch_size is None and dataloader.batch_sampler is None:
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"
"is not None")
@staticmethod
def _check_optimizer_legality(optimizers):
@ -107,7 +109,7 @@ class PaddleDriver(Driver):
"""
for each_optimizer in optimizers:
if not isinstance(each_optimizer, Optimizer):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
f"not {type(each_optimizer)}.")
@staticmethod
@ -263,9 +265,7 @@ class PaddleDriver(Driver):
optimizers_state_dict = {}
for i in range(len(self.optimizers)):
optimizer: Optimizer = self.optimizers[i]
optimizer_state = optimizer.state_dict()
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy测试是不需要的
optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu")
return optimizers_state_dict
@ -399,6 +399,8 @@ class PaddleDriver(Driver):
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx):
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
dataloader.batch_sampler.set_epoch(cur_epoch_idx)
elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)):
dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx)
@staticmethod
def get_dataloader_args(dataloader: "DataLoader"):

View File

@ -99,7 +99,7 @@ class TorchDriver(Driver):
def _check_optimizer_legality(optimizers):
for each_optimizer in optimizers:
if not isinstance(each_optimizer, Optimizer):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
f"not {type(each_optimizer)}.")
@staticmethod

View File

@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.num_consumed_samples = 0
self.during_iter = True
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
indices = list(range(self.num_samples))
if self.shuffle:
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的
@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
if len(indices)%self.batch_size!=0:
batches.append(indices[_num_batches*self.batch_size:])
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
if len(batches) > 0:
if len(batches[-1])<self.batch_size:
@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler):
@property
def batch_idx_in_epoch(self):
if self.drop_last:
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else:
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size
@property
@ -313,8 +313,12 @@ class RandomBatchSampler(ReproducibleBatchSampler):
:return:
"""
num_consumed_samples = self.num_consumed_samples
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas))
@property
def num_samples(self):
    """
    Total number of samples in the dataset.

    Prefers the dataset's dynamic ``total_len`` attribute when present
    (presumably set by jittor datasets — confirm), otherwise falls back
    to ``len(self.dataset)``.
    """
    return getattr(self.dataset, 'total_len', len(self.dataset))
def __len__(self)->int:
"""
@ -332,7 +336,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle,
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle,
'batch_size': self.batch_size,
'num_replicas': self.num_replicas}
@ -347,7 +351,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \
"and current dataset."
self.seed = states['seed']
self.epoch = states['epoch']
@ -464,8 +468,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
:return:
"""
num_consumed_samples = self.num_consumed_samples
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas))
@property
def num_samples(self):
    """
    Total number of samples in the dataset.

    Prefers the dataset's dynamic ``total_len`` attribute when present
    (presumably set by jittor datasets — confirm), otherwise falls back
    to ``len(self.dataset)``.
    """
    return getattr(self.dataset, 'total_len', len(self.dataset))
def __len__(self)->int:
"""
@ -515,7 +523,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if len(sorted_indices)%self.batch_size!=0:
batches.append(sorted_indices[_num_batches*self.batch_size:])
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
if len(batches) > 0:
if len(batches[-1])<self.batch_size:
@ -593,7 +601,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle,
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle,
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
'num_replicas': self.num_replicas
}
@ -609,7 +617,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \
"and current dataset."
self.seed = states['seed']
self.epoch = states['epoch']
@ -630,7 +638,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
@property
def batch_idx_in_epoch(self):
if self.drop_last:
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else:
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size

View File

@ -48,6 +48,10 @@ class ReproducibleSampler:
def num_left_samples(self):
raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.")
@property
def num_samples(self):
    """Total number of samples; every concrete sampler must override this."""
    raise NotImplementedError("Each specific sampler should implement its own `num_samples` method.")
def set_epoch(self, epoch):
pass
@ -131,19 +135,19 @@ class RandomSampler(ReproducibleSampler):
:return:
"""
if self.shuffle:
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
indices = list(range(self.num_samples))
seed = self.seed + self.epoch
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
if self.epoch < 0: # 防止用户忘记调用 set_epoch至少这样可以保证每次epoch出来的index顺序不同。
self.epoch -= 1
else:
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
indices = list(range(self.num_samples))
return indices
def state_dict(self) -> Dict:
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle}
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle}
return states
def load_state_dict(self, states: Dict):
@ -155,8 +159,8 @@ class RandomSampler(ReproducibleSampler):
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
assert length == self.num_samples, "The number of samples is different between the checkpoint " \
f"record({length}) and current dataset({self.num_samples})."
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
@ -208,9 +212,17 @@ class RandomSampler(ReproducibleSampler):
:return:
"""
num_consumed_samples = self.num_consumed_samples
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas))
@property
def num_samples(self):
"""
返回样本的总数
:return:
"""
return getattr(self.dataset, 'total_len', len(self.dataset))
class SequentialSampler(RandomSampler):
"""
@ -258,12 +270,10 @@ class SequentialSampler(RandomSampler):
:return:
"""
return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
return list(range(self.num_samples))
def state_dict(self) -> Dict:
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
'length': getattr(self.dataset, 'total_len', len(self.dataset))
}
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples}
return states
def load_state_dict(self, states: Dict):
@ -275,8 +285,8 @@ class SequentialSampler(RandomSampler):
f"we cannot use {self.__class__.__name__} to load it."
length = states['length']
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
assert length == self.num_samples, "The number of samples is different between the checkpoint " \
f"record({length}) and current dataset({self.num_samples})."
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了则直接将结果重置为0
self.num_consumed_samples = 0
@ -314,9 +324,9 @@ class SortedSampler(SequentialSampler):
except BaseException as e:
logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")
assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \
f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal."
assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length."
assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \
f"`length`({self.num_samples}) should be equal."
assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length."
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的

View File

@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
返回 sampler 一次完整的迭代过程会产生多少个index多卡的情况下只考虑当前rank
:return:
"""
num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas
num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas))
num_common = self.num_samples//self.num_replicas
num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas))
return num_samples
def __iter__(self):
@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
:return:
"""
if self.shuffle:
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
indices = list(range(self.num_samples))
seed = self.seed + self.epoch
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
if self.epoch < 0: # 防止用户忘记调用 set_epoch至少这样可以保证每次epoch出来的index顺序不同。
self.epoch -= 1
else:
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
indices = list(range(self.num_samples))
return indices
def set_epoch(self, epoch: int) -> None:
@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
:param rank:
:return:
"""
assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \
f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})."
assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \
f"number of samples({self.num_samples})."
assert num_replicas>0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0<=rank<num_replicas
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
@ -94,6 +94,15 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
return self
@property
def num_samples(self):
    """
    Total number of samples in the dataset.

    Prefers the jittor dataset attribute ``total_len`` when present and
    falls back to ``len(self.dataset)`` otherwise.

    :return: the number of samples
    """
    return getattr(self.dataset, 'total_len', len(self.dataset))
class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
"""
@ -147,5 +156,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
yield index
def generate_indices(self) -> List[int]:
return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
return list(range(self.num_samples))

View File

@ -27,7 +27,7 @@ from paddle.optimizer import Adam
from paddle.io import DataLoader
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback
@dataclass
@ -52,12 +52,12 @@ def test_trainer_fleet(
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
train_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)

View File

@ -24,7 +24,7 @@ from paddle.io import DataLoader
import paddle.distributed.fleet as fleet
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback
@dataclass
@ -54,12 +54,12 @@ def test_trainer_fleet(
optimizers = fleet.distributed_optimizer(optimizers)
train_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
batch_size=MNISTTrainFleetConfig.batch_size,
shuffle=True
)

View File

@ -46,8 +46,8 @@ class LSTM(Module):
def init_hidden(self, x):
# batch_first
batch_size = x.shape[0]
h0 = jt.randn(1, batch_size, hidden_size)
c0 = jt.randn(1, batch_size, hidden_size)
h0 = jt.randn(1, batch_size, self.hidden_size)
c0 = jt.randn(1, batch_size, self.hidden_size)
return h0, c0

View File

@ -1,4 +1,5 @@
import pytest
from fastNLP.core.callbacks import callback
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.trainer import Evaluator
@ -14,6 +15,7 @@ if _NEED_IMPORT_JITTOR:
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
jt.flags.use_cuda=1
class JittorNormalModel_Classification(Module):
@ -68,11 +70,9 @@ class TrainJittorConfig:
batch_size: int = 4
shuffle: bool = True
@pytest.mark.parametrize("driver", ["jittor"])
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None])
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
@pytest.mark.jittor
def test_trainer_jittor(
driver,
device,

View File

@ -15,7 +15,7 @@ if _NEED_IMPORT_PADDLE:
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
from tests.helpers.utils import magic_argv_env_context
@dataclass
@ -44,12 +44,12 @@ def test_trainer_paddle(
)
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
train_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(20, TrainPaddleConfig.feature_dimension),
dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension),
batch_size=TrainPaddleConfig.batch_size,
shuffle=True
)
val_dataloader = DataLoader(
dataset=PaddleRandomMaxDataset(12, TrainPaddleConfig.feature_dimension),
dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension),
batch_size=TrainPaddleConfig.batch_size,
shuffle=True
)

View File

@ -76,7 +76,7 @@ class TestPaddle:
from paddle.io import Dataset
import paddle
class PaddleRandomMaxDataset(Dataset):
class PaddleArgMaxDataset(Dataset):
def __init__(self, num_samples, num_features):
self.x = paddle.randn((num_samples, num_features))
self.y = self.x.argmax(axis=-1)
@ -87,7 +87,7 @@ class TestPaddle:
def __getitem__(self, item):
return {"x": self.x[item], "y": self.y[item]}
ds = PaddleRandomMaxDataset(100, 2)
ds = PaddleArgMaxDataset(100, 2)
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4)
for batch in dl:
print(batch)

View File

@ -0,0 +1,45 @@
import pytest
from fastNLP.core.drivers import JittorSingleDriver, JittorMPIDriver
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
import jittor as jt
@pytest.mark.jittor
def test_incorrect_driver():
    """Passing an unknown driver name must raise ``ValueError``."""
    classifier = JittorNormalModel_Classification_1(20, 10)
    with pytest.raises(ValueError):
        initialize_jittor_driver("torch", 0, classifier)
@pytest.mark.jittor
@pytest.mark.parametrize(
    "device",
    ["cpu", "gpu", None, "cuda"]
)
def test_get_single_device(device):
    """Every single-device specification yields a ``JittorSingleDriver``."""
    net = JittorNormalModel_Classification_1(20, 10)
    single_driver = initialize_jittor_driver("jittor", device, net)
    assert isinstance(single_driver, JittorSingleDriver)
@pytest.mark.jittor
@pytest.mark.parametrize(
    "device",
    [[0, 2, 3], 1, 2]
)
def test_get_mpi(device):
    """Multi-card device specifications are not implemented yet and must raise."""
    net = JittorNormalModel_Classification_1(20, 10)
    with pytest.raises(NotImplementedError):
        initialize_jittor_driver("jittor", device, net)
    # once MPI support lands this should instead be:
    # assert isinstance(driver, JittorMPIDriver)

View File

@ -1,99 +1,614 @@
import pytest
import os
from copy import deepcopy
from pathlib import Path
import numpy as np
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.drivers.jittor_driver import JittorSingleDriver
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from fastNLP.core.dataloaders import JittorDataLoader
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1
from tests.helpers.datasets.jittor_data import JittorNormalDataset, JittorNormalXYDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from fastNLP.envs.distributed import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH
if _NEED_IMPORT_JITTOR:
import jittor as jt # 将 jittor 引入
from jittor import nn, Module # 引入相关的模块
from jittor import init
from jittor.dataset import MNIST
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module
import jittor as jt
from jittor.dataset import (
BatchSampler as JittorBatchSampler,
RandomSampler as JittorRandomSampler,
SequentialSampler as JittorSequentialSampler,
SubsetRandomSampler as JittorSubsetRandomSampler
)
if _NEED_IMPORT_TORCH:
import torch
def get_dataloader(dataset, use_dataloader, sampler, batch_size, shuffle, drop_last=False):
    """Build a test dataloader around *dataset*.

    :param dataset: a ``jittor.dataset.Dataset`` instance
    :param use_dataloader: whether to wrap the dataset in a ``JittorDataLoader``
    :param sampler: the sampler (``BatchSampler``/``Sampler``) to attach, or ``None``
    :param batch_size: batch size applied to the underlying dataset
    :param shuffle: whether batches are shuffled
    :param drop_last: whether an incomplete trailing batch is dropped
    :return: either the configured dataset itself or a ``JittorDataLoader``
    """
    if not use_dataloader:
        # configure and return the raw jittor dataset itself
        dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, sampler=sampler)
        return dataset
    wrapped = JittorDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    # the sampler must be attached to the inner dataset, not the wrapper
    wrapped.dataset.set_attrs(sampler=sampler)
    return wrapped
############################################################################
#
# 测试基类 JittorDrvier 中的一些简单函数
#
############################################################################
class TestJittorDriverFunctions:
"""
使用 JittorSingleDriver 测试基类的函数
"""
@classmethod
def setup_class(self):
    # NOTE(review): pytest passes the class object here, so `self` is really
    # the class; the driver below is shared by every test in this class.
    model = JittorNormalModel_Classification_1(10, 32)
    self.driver = JittorSingleDriver(model, device="cpu")
@pytest.mark.jittor
def test_check_optimizers_legality(self):
    """
    Valid jittor optimizers — a single one or a list — must be accepted
    by ``set_optimizers``.
    """
    # a single optimizer
    optimizer = jt.optim.Adam(
        params=self.driver.model.parameters(),
        lr=0.01
    )
    self.driver.set_optimizers(optimizer)

    # a list of optimizers
    optimizers = [
        jt.optim.Adam(
            params=self.driver.model.parameters(),
            lr=0.01
        ) for i in range(10)
    ]
    self.driver.set_optimizers(optimizers)
@pytest.mark.torchjittor
def test_invalid_optimizers(self):
    """
    Optimizers from another framework (torch) must be rejected with a
    ``TypeError``.
    """
    # a single optimizer
    optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
    with pytest.raises(TypeError):
        self.driver.set_optimizers(optimizer)

    optimizers = [
        torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
    ]

    with pytest.raises(TypeError):
        self.driver.set_optimizers(optimizers)
@pytest.mark.jittor
def test_check_dataloader_legality(self):
    """
    ``check_dataloader_legality`` must accept both supported loader kinds.
    """
    # a JittorDataLoader wrapper
    dataloader = JittorDataLoader(JittorNormalDataset())
    self.driver.check_dataloader_legality(dataloader)

    # a raw jittor.dataset.Dataset
    self.driver.check_dataloader_legality(JittorNormalDataset())
@pytest.mark.torchjittor
def test_check_dataloader_legality_invalid(self):
    """
    ``check_dataloader_legality`` must reject loaders from other frameworks
    with a ``TypeError``.
    """
    # build a torch dataloader
    dataloader = torch.utils.data.DataLoader(
        TorchNormalDataset(),
        batch_size=32, shuffle=True
    )

    with pytest.raises(TypeError):
        self.driver.check_dataloader_legality(dataloader)
@pytest.mark.jittor
def test_tensor_to_numeric(self):
    """
    ``tensor_to_numeric`` must convert single tensors, lists, tuples and
    nested dicts of tensors to plain Python values, preserving container
    types and leaving non-tensor values (int, str) untouched.
    """
    # a single tensor
    tensor = jt.Var(3)
    res = JittorSingleDriver.tensor_to_numeric(tensor)
    assert res == 3

    tensor = jt.rand(3, 4)
    res = JittorSingleDriver.tensor_to_numeric(tensor)
    assert res == tensor.tolist()

    # a list of tensors
    tensor_list = [jt.rand(6, 4, 2) for i in range(10)]
    res = JittorSingleDriver.tensor_to_numeric(tensor_list)
    assert isinstance(res, list)
    tensor_list = [t.tolist() for t in tensor_list]
    assert res == tensor_list

    # a tuple of tensors
    tensor_tuple = tuple([jt.rand(6, 4, 2) for i in range(10)])
    res = JittorSingleDriver.tensor_to_numeric(tensor_tuple)
    assert isinstance(res, tuple)
    tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
    assert res == tensor_tuple

    # a dict mixing tensors, lists, a nested dict and plain python values
    tensor_dict = {
        "tensor": jt.rand(3, 4),
        "list": [jt.rand(6, 4, 2) for i in range(10)],
        "dict":{
            "list": [jt.rand(6, 4, 2) for i in range(10)],
            "tensor": jt.rand(3, 4)
        },
        "int": 2,
        "string": "test string"
    }

    res = JittorSingleDriver.tensor_to_numeric(tensor_dict)
    assert isinstance(res, dict)
    assert res["tensor"] == tensor_dict["tensor"].tolist()
    assert isinstance(res["list"], list)
    for r, d in zip(res["list"], tensor_dict["list"]):
        assert r == d.tolist()
    assert isinstance(res["int"], int)
    assert isinstance(res["string"], str)
    assert isinstance(res["dict"], dict)
    assert isinstance(res["dict"]["list"], list)
    for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
        assert r == d.tolist()
    assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()
@pytest.mark.jittor
def test_tensor_to_numeric_reduce(self):
    """With a ``reduce`` argument, ``tensor_to_numeric`` returns the reduced scalar."""
    var = jt.Var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    # expected scalar per reduction mode (insertion order mirrors the calls)
    expected = {"max": 6, "min": 1, "sum": 21, "mean": 3.5}
    for mode, value in expected.items():
        assert JittorSingleDriver.tensor_to_numeric(var, reduce=mode) == value
@pytest.mark.jittor
def test_set_model_mode(self):
    """``set_model_mode`` toggles the training flag and rejects unknown modes."""
    driver = self.driver
    driver.set_model_mode("train")
    assert driver.model.is_training()
    driver.set_model_mode("eval")
    assert not driver.model.is_training()
    # any mode other than train/eval must trigger an assertion error
    with pytest.raises(AssertionError):
        driver.set_model_mode("test")
class Model(Module):
def __init__ (self):
super (Model, self).__init__()
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding
self.conv2 = nn.Conv (32, 64, 3, 1)
self.bn = nn.BatchNorm(64)
@pytest.mark.jittor
def test_move_model_to_device_cpu(self):
    """
    Smoke test for ``move_model_to_device`` with "cpu"; only checks it runs.
    """
    JittorSingleDriver.move_model_to_device(self.driver.model, "cpu")
self.max_pool = nn.Pool (2, 2)
self.relu = nn.Relu()
self.fc1 = nn.Linear (64 * 12 * 12, 256)
self.fc2 = nn.Linear (256, 10)
@pytest.mark.jittor
def test_move_model_to_device_gpu(self):
    """
    Smoke test for ``move_model_to_device`` with "gpu"; only checks it runs.
    """
    JittorSingleDriver.move_model_to_device(self.driver.model, "gpu")
def execute(self, x) :
# it's simliar to forward function in Pytorch
x = self.conv1 (x)
x = self.relu (x)
x = self.conv2 (x)
x = self.bn (x)
x = self.relu (x)
x = self.max_pool (x)
x = jt.reshape (x, [x.shape[0], -1])
x = self.fc1 (x)
x = self.relu(x)
x = self.fc2 (x)
return x
@pytest.mark.jittor
def test_set_deterministic_dataloader(self):
    """
    Smoke test for ``set_deterministic_dataloader``; only checks it runs.
    """
    # first make sure the call does not break anything
    # TODO: verify correctness of the resulting determinism
    dataloader = JittorDataLoader(JittorNormalDataset())
    self.driver.set_deterministic_dataloader(dataloader)

    self.driver.set_deterministic_dataloader(JittorNormalDataset())
@pytest.mark.jittor
def test_set_sampler_epoch(self):
    """
    Smoke test for ``set_sampler_epoch``; only checks it runs.
    """
    # first make sure the call does not break anything
    # TODO: verify correctness of the epoch propagation
    dataloader = JittorDataLoader(JittorNormalDataset())
    self.driver.set_sampler_epoch(dataloader, 0)

    self.driver.set_sampler_epoch(JittorNormalDataset(), 0)
@pytest.mark.jittor
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("use_dataloader", [True, False])
def test_get_dataloader_args(self, batch_size, shuffle, drop_last, use_dataloader):
    """``get_dataloader_args`` reports the settings of an unmodified dataloader."""
    loader = get_dataloader(
        JittorNormalDataset(), use_dataloader=use_dataloader, sampler=None,
        batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
    )
    args = JittorSingleDriver.get_dataloader_args(loader)
    assert isinstance(args.dataset, JittorNormalDataset)
    assert args.sampler is None
    assert args.shuffle == shuffle
    assert args.batch_size == batch_size
    assert args.drop_last == drop_last
@pytest.mark.jittor
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("use_dataloader", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last, use_dataloader):
    """``get_dataloader_args`` still works after the sampler has been replaced."""
    source = JittorNormalDataset()
    sampler = RandomSampler(source, shuffle=shuffle)
    loader = get_dataloader(
        source, use_dataloader=use_dataloader, sampler=sampler,
        batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
    )
    args = JittorSingleDriver.get_dataloader_args(loader)
    assert isinstance(args.dataset, JittorNormalDataset)
    assert isinstance(args.sampler, RandomSampler)
    assert args.shuffle == shuffle
    assert args.batch_size == batch_size
    assert args.drop_last == drop_last
############################################################################
#
# 测试 JittorSingleDrvier 中的一些简单函数
#
############################################################################
@pytest.mark.jittor
@pytest.mark.skip("Skip jittor tests now.")
class TestSingleDevice:
class TestSingleDeviceFunction:
"""
测试其它函数的测试例
"""
def test_on_gpu_without_fp16(self):
# TODO get_dataloader
batch_size = 64
learning_rate = 0.1
epochs = 5
losses = []
losses_idx = []
@classmethod
def setup_class(cls):
    # one shared cpu driver for the simple function tests in this class
    model = JittorNormalModel_Classification_1(10, 784)
    cls.driver = JittorSingleDriver(model, device="cpu")
train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True)
val_loader = MNIST(train=False, batch_size=1, shuffle=False)
def test_unwrap_model(self):
    """On a single-device driver, ``unwrap_model`` returns the model itself."""
    unwrapped = self.driver.unwrap_model()
    assert unwrapped is self.driver.model
model = Model()
driver = JittorSingleDriver(model, device=[1])
optimizer = nn.SGD(model.parameters(), learning_rate)
driver.set_optimizers(optimizer)
def test_is_distributed(self):
    """A single-device driver never reports being distributed."""
    distributed = self.driver.is_distributed()
    assert distributed == False
for epoch in range(epochs):
driver.set_model_mode("train")
lens = len(train_loader)
for batch_idx, (inputs, targets) in enumerate(train_loader):
outputs =driver.train_step(inputs)
loss = nn.cross_entropy_loss(outputs, targets)
driver.backward(loss)
driver.step()
driver.zero_grad()
losses.append(loss.data[0])
losses_idx.append(epoch * lens + batch_idx)
test_loss = 0
correct = 0
total_acc = 0
total_num = 0
driver.set_model_mode("eval")
for batch_idx, (inputs, targets) in enumerate(val_loader):
batch_size = inputs.shape[0]
outputs = driver.test_step(inputs)
pred = np.argmax(outputs.data, axis=1)
acc = np.sum(targets.data==pred)
total_acc += acc
total_num += batch_size
acc = acc / batch_size
assert total_acc / total_num > 0.95
def test_move_data_to_device(self):
    # smoke test: moving a tensor must not raise (return value unchecked)
    self.driver.move_data_to_device(jt.rand(32, 64))
def test_on_cpu_without_fp16(self):
pass
############################################################################
#
# 测试 set_dist_repro_dataloader 函数
#
############################################################################
def test_on_gpu_with_fp16(self):
pass
@pytest.mark.jittor
class TestSetDistReproDataloader:
    """
    Tests dedicated to the ``set_dist_repro_dataloader`` function.
    """
    def setup_method(self):
        # fresh 20-sample dataset and cpu single-device driver per test
        self.dataset = JittorNormalDataset(20)
        model = JittorNormalModel_Classification_1(10, 32)
        self.driver = JittorSingleDriver(model, device="cpu")

    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_reproducible_false(self, use_dataloader):
        """
        Behaviour of ``set_dist_repro_dataloader`` when ``reproducible`` is
        False: with a string ``dist`` the original dataloader must be
        returned unchanged.
        """
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=True)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert replaced_loader is dataloader

    @pytest.mark.parametrize("shuffle", [True, False])
    @pytest.mark.parametrize("sampler", [None, "random", "sequential"])
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_reproducible_true(self, shuffle, sampler, use_dataloader):
        """
        Behaviour when ``reproducible`` is True: with a string ``dist`` a new
        dataloader must be returned whose sampler was replaced by a
        ``RandomSampler``.
        """
        if sampler == "random":
            sampler = JittorRandomSampler(self.dataset)
            _shuffle = True
        elif sampler == "sequential":
            sampler = JittorSequentialSampler(self.dataset)
            _shuffle = False
        else:
            _shuffle = shuffle
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=sampler, batch_size=2, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.sampler.shuffle == _shuffle
        assert replaced_loader.batch_size == dataloader.batch_size
        assert replaced_loader.drop_last == dataloader.drop_last

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dist_batch_sampler(self, shuffle, use_dataloader):
        """
        Behaviour when ``dist`` is a ``ReproducibleBatchSampler`` object:
        jittor does not support this case yet and must raise.
        """
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle)
        dist = ReproduceBatchSampler(JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), 4, False)
        with pytest.raises(RuntimeError):
            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dist_sampler(self, shuffle, use_dataloader):
        """
        Behaviour when ``dist`` is a reproducible sampler object: a new
        dataloader must be returned whose sampler is exactly ``dist``.
        """
        dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle)
        dist = RandomSampler(self.dataset, shuffle=shuffle)
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.sampler is dist
        assert replaced_loader.batch_size == dataloader.batch_size

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dataloader_reproducible_batch_sampler(self, shuffle, use_dataloader):
        """
        Behaviour when the dataloader already supports resumed training via a
        reproducible *batch* sampler: not supported for jittor, must raise.
        """
        dataloader = get_dataloader(
            self.dataset,
            use_dataloader=use_dataloader,
            sampler=ReproduceBatchSampler(
                JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False),
                batch_size=4,
                drop_last=False,
            ),
            batch_size=4,
            shuffle=shuffle,
        )
        with pytest.raises(RuntimeError):
            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

    @pytest.mark.parametrize("shuffle", ([True, False]))
    @pytest.mark.parametrize("use_dataloader", [True, False])
    def test_with_dataloader_reproducible_sampler(self, shuffle, use_dataloader):
        """
        Behaviour when the dataloader already carries a reproducible sampler:
        a new dataloader must be returned with all other settings unchanged.
        """
        dataloader = get_dataloader(
            self.dataset,
            use_dataloader=use_dataloader,
            sampler=RandomSampler(self.dataset, shuffle),
            batch_size=2,
            shuffle=shuffle,
        )
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

        assert not (replaced_loader is dataloader)
        assert not (replaced_loader.sampler is dataloader.sampler)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.batch_size == 2
        assert replaced_loader.shuffle == shuffle

        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)

    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle, use_dataloader):
        """
        Check that, on a single device, the loader returned by
        ``set_dist_repro_dataloader`` can be interrupted and resumed correctly.
        """
        # consume two batches
        num_consumed_batches = 2
        already_seen_idx = set()
        replaced_loader.sampler.set_epoch(6)
        for idx, batch in enumerate(replaced_loader):
            if idx >= num_consumed_batches:
                break
            already_seen_idx.update(batch.tolist())
        sampler_states = replaced_loader.sampler.state_dict()

        # after reloading, the remaining samples should come out; for a
        # JittorNormalDataset the union of both sets should form a full range
        left_idxes = set()
        batch_size = replaced_loader.batch_size
        sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
        # rebuild the dataloader
        if use_dataloader:
            dataset = deepcopy(replaced_loader.dataset.dataset)
        else:
            dataset = deepcopy(replaced_loader)
        new_loader = get_dataloader(
            dataset=dataset,
            use_dataloader=use_dataloader,
            sampler = RandomSampler(dataset, shuffle=shuffle),
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=False
        )
        new_loader.sampler.load_state_dict(sampler_states)
        new_loader.sampler.set_epoch(6)
        for idx, batch in enumerate(new_loader):
            left_idxes.update(batch.tolist())

        print(already_seen_idx)
        print(left_idxes)
        assert len(left_idxes) + len(already_seen_idx) == self.dataset.total_len
        assert len(left_idxes | already_seen_idx) == self.dataset.total_len
############################################################################
#
# 测试 save 和 load 相关的功能
#
############################################################################
def generate_random_driver(labels, features, fp16=False, device="cpu", lr=0.01):
    """Create a fully initialised ``JittorSingleDriver`` around a fresh model.

    :param labels: number of output labels of the classification model
    :param features: input feature dimension of the model
    :param fp16: whether the driver runs in half precision
    :param device: device string passed to the driver
    :param lr: learning rate of the attached Adam optimizer
    :return: a set-up ``JittorSingleDriver``
    """
    classifier = JittorNormalModel_Classification_1(labels, features)
    new_driver = JittorSingleDriver(classifier, device=device, fp16=fp16)
    new_driver.set_optimizers(jt.optim.Adam(params=classifier.parameters(), lr=lr))
    new_driver.setup()
    return new_driver
@pytest.mark.jittor
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("use_dataloader", [True, False])
def test_save_and_load_model(only_state_dict, use_dataloader):
    """
    Test ``save_model`` and ``load_model``: after loading, both drivers
    must produce identical predictions on every batch.
    """
    try:
        path = "model"
        dataset = JittorNormalXYDataset(20)
        dataloader = get_dataloader(dataset, sampler=None, use_dataloader=use_dataloader, batch_size=4, shuffle=True)
        driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu")

        driver1.save_model(path, only_state_dict)
        driver2.load_model(path, only_state_dict)

        for batch in dataloader:
            batch = driver1.move_data_to_device(batch)
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)

            # NOTE(review): confirm `jt.all_` is the intended jittor reduction helper
            assert jt.all_(jt.equal(res1["pred"], res2["pred"]))
    finally:
        # remove the checkpoint file even when the test fails
        rank_zero_rm(path)
@pytest.mark.jittor
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("use_dataloader", [True, False])
def test_save_and_load_with_randomsampler(only_state_dict, use_dataloader):
    """
    Test ``save_checkpoint`` and ``load_checkpoint``, focusing on a dataloader
    whose sampler was replaced by a reproducible ``RandomSampler``.
    """
    try:
        path = "model.ckp"

        driver1, driver2 = generate_random_driver(20, 1, device="gpu", lr=0.01), \
                           generate_random_driver(20, 1, device="gpu", lr=0.001)
        dataset = JittorNormalXYDataset(20)
        dataloader = get_dataloader(
            dataset, use_dataloader,
            sampler = RandomSampler(dataset, True),
            batch_size=4,
            shuffle=True
        )
        num_consumed_batches = 2

        already_seen_x_set = set()
        already_seen_y_set = set()
        driver1.set_sampler_epoch(dataloader, 7)
        for idx, batch in enumerate(dataloader):
            if idx >= num_consumed_batches:
                break
            already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
            already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())

        sampler_states = dataloader.sampler.state_dict()
        save_states = {"num_consumed_batches": num_consumed_batches}
        driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)

        # reload
        # with a different batch_size
        dataloader = get_dataloader(
            dataset, use_dataloader,
            sampler=RandomSampler(dataset, True),
            batch_size=2,
            shuffle=True
        )
        load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
        replaced_loader = load_states.pop("dataloader")

        # 1. check the optimizer state was restored
        assert driver2.optimizers[0].lr == driver1.optimizers[0].lr

        # 2. check the sampler was correctly loaded and replaced
        assert not (replaced_loader is dataloader)
        assert isinstance(replaced_loader.sampler, RandomSampler)
        assert replaced_loader.sampler.seed == sampler_states["seed"]
        assert replaced_loader.sampler.epoch == sampler_states["epoch"]
        assert replaced_loader.sampler.num_consumed_samples == 4 * num_consumed_batches
        assert replaced_loader.sampler.dataset.total_len == sampler_states["length"]
        assert replaced_loader.sampler.shuffle == sampler_states["shuffle"]

        # 4. check the model parameters are correct (verified below via evaluate_step)

        # 5. check batch_idx: halving batch_size doubles the resume index
        start_batch = load_states.pop('batch_idx_in_epoch')
        assert start_batch == 2 * num_consumed_batches
        left_x_batches = set()
        left_y_batches = set()
        driver2.set_sampler_epoch(replaced_loader, 7)
        for idx, batch in enumerate(replaced_loader):
            left_x_batches.update(batch["x"].reshape(-1, ).tolist())
            left_y_batches.update(batch["y"].reshape(-1, ).tolist())
            res1 = driver1.model.evaluate_step(**batch)
            res2 = driver2.model.evaluate_step(**batch)
            assert jt.all_(jt.equal(res1["pred"], res2["pred"]))

        assert len(left_x_batches) + len(already_seen_x_set) == dataset.total_len
        assert len(left_x_batches | already_seen_x_set) == dataset.total_len
        assert len(left_y_batches) + len(already_seen_y_set) == dataset.total_len
        assert len(left_y_batches | already_seen_y_set) == dataset.total_len
    finally:
        rank_zero_rm(path)

View File

@ -0,0 +1,43 @@
import pytest
from fastNLP.core.drivers.jittor_driver.utils import replace_sampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from fastNLP.core.dataloaders import JittorDataLoader
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
import jittor as jt
from tests.helpers.datasets.jittor_data import JittorNormalDataset
@pytest.mark.jittor
@pytest.mark.parametrize("dataset", [
    JittorNormalDataset(20, batch_size=10, shuffle=True),
    JittorNormalDataset(20, batch_size=5, drop_last=True),
    JittorNormalDataset(20)
])
def test_replace_sampler_dataset(dataset):
    """
    ``replace_sampler`` on a raw jittor dataset must return a *new* dataset
    carrying the new sampler while preserving every loader attribute.

    Bug fix: the parametrized ``dataset`` argument used to be shadowed by a
    local ``dataset = JittorNormalDataset(20)``, so all three parametrized
    configurations (custom batch_size / shuffle / drop_last) actually tested
    the same default dataset. The shadowing assignment has been removed.
    """
    sampler = RandomSampler(dataset)
    replaced_loader = replace_sampler(dataset, sampler)

    assert not (replaced_loader is dataset)
    assert isinstance(replaced_loader.sampler, RandomSampler)
    # every original attribute must be preserved on the replacement
    assert replaced_loader.batch_size == dataset.batch_size
    assert replaced_loader.drop_last == dataset.drop_last
    assert replaced_loader.shuffle == dataset.shuffle
    assert replaced_loader.total_len == dataset.total_len
@pytest.mark.jittor
def test_replace_sampler_jittordataloader():
    # replace_sampler on a JittorDataLoader must return a new loader whose
    # underlying dataset is a different object (see identity asserts below),
    # with the new sampler attached and batch_size/shuffle settings kept
    dataset = JittorNormalDataset(20, batch_size=10, shuffle=True)
    dataloader = JittorDataLoader(dataset, batch_size=8, shuffle=True)
    sampler = RandomSampler(dataset)
    replaced_loader = replace_sampler(dataloader, sampler)

    assert not (replaced_loader is dataloader)
    assert not (replaced_loader.dataset.dataset is dataloader.dataset.dataset)
    assert isinstance(replaced_loader.sampler, RandomSampler)
    assert replaced_loader.batch_size == 8
    assert replaced_loader.shuffle == True

View File

@ -10,7 +10,7 @@ from fastNLP.core.samplers import (
UnrepeatedSequentialSampler,
)
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.distributed import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@ -19,8 +19,8 @@ if _NEED_IMPORT_PADDLE:
import paddle.distributed as dist
from paddle.io import DataLoader, BatchSampler
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="only_error"):
paddle_model = PaddleNormalModel_Classification_1(labels, features)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
@ -465,10 +465,14 @@ class TestSetDistReproDataloader:
num_replicas = len(self.device)
num_consumed_batches = 2
already_seen_idx = set()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
sampler_states = replaced_loader.batch_sampler.set_epoch(10)
else:
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(10)
for idx, batch in enumerate(replaced_loader):
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
already_seen_idx.update(batch.tolist())
dist.barrier()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
@ -496,6 +500,7 @@ class TestSetDistReproDataloader:
pad=True
)
new_loader.batch_sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.set_epoch(10)
else:
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
@ -508,8 +513,9 @@ class TestSetDistReproDataloader:
)
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.sampler.set_epoch(10)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)
left_idxes.update(batch.tolist())
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
@ -533,7 +539,7 @@ class TestSaveLoad:
cls.driver = generate_driver(10, 10, device=[0,1])
def setup_method(self):
self.dataset = PaddleRandomMaxDataset(20, 10)
self.dataset = PaddleNormalXYDataset(40)
@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@ -545,12 +551,12 @@ class TestSaveLoad:
path = "model"
dataloader = DataLoader(self.dataset, batch_size=2)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
self.driver1, self.driver2 = generate_driver(40, 1), generate_driver(40, 1)
if only_state_dict:
self.driver1.save_model(path, only_state_dict)
else:
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))])
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 1))])
# 同步
dist.barrier()
@ -594,8 +600,8 @@ class TestSaveLoad:
path = "model.ckp"
num_replicas = len(device)
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False)
self.driver1, self.driver2 = generate_driver(40, 1, device=device, fp16=fp16), \
generate_driver(40, 1, device=device, fp16=False)
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(
@ -613,11 +619,12 @@ class TestSaveLoad:
already_seen_x_set = set()
already_seen_y_set = set()
self.driver1.set_sampler_epoch(dataloader, 2)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
# 同步
dist.barrier()
@ -669,10 +676,11 @@ class TestSaveLoad:
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
self.driver2.set_sampler_epoch(replaced_loader, 2)
for idx, batch in enumerate(replaced_loader):
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
@ -709,8 +717,8 @@ class TestSaveLoad:
num_replicas = len(device)
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
self.driver2 = generate_driver(10, 10, device=device, fp16=False)
self.driver1 = generate_driver(40, 1, device=device, fp16=fp16)
self.driver2 = generate_driver(40, 1, device=device, fp16=False)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler.set_distributed(
@ -726,11 +734,12 @@ class TestSaveLoad:
already_seen_x_set = set()
already_seen_y_set = set()
self.driver1.set_sampler_epoch(dataloader, 2)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
# 同步
dist.barrier()
@ -779,10 +788,11 @@ class TestSaveLoad:
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
self.driver2.set_sampler_epoch(replaced_loader, 2)
for idx, batch in enumerate(replaced_loader):
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,

View File

@ -12,7 +12,7 @@ if _NEED_IMPORT_PADDLE:
@pytest.mark.paddle
def test_incorrect_driver():
model = PaddleNormalModel_Classification_1(2, 100)
model = PaddleNormalModel_Classification_1(20, 10)
with pytest.raises(ValueError):
driver = initialize_paddle_driver("torch", 0, model)
@ -26,7 +26,7 @@ def test_get_single_device(device):
测试正常情况下初始化 PaddleSingleDriver 的情况
"""
model = PaddleNormalModel_Classification_1(2, 100)
model = PaddleNormalModel_Classification_1(20, 10)
driver = initialize_paddle_driver("paddle", device, model)
assert isinstance(driver, PaddleSingleDriver)
@ -41,7 +41,7 @@ def test_get_fleet(device):
测试 fleet 多卡的初始化情况
"""
model = PaddleNormalModel_Classification_1(64, 10)
model = PaddleNormalModel_Classification_1(20, 10)
driver = initialize_paddle_driver("paddle", device, model)
assert isinstance(driver, PaddleFleetDriver)
@ -56,6 +56,6 @@ def test_device_out_of_range(device):
"""
测试传入的device超过范围的情况
"""
model = PaddleNormalModel_Classification_1(2, 100)
model = PaddleNormalModel_Classification_1(20, 10)
with pytest.raises(ValueError):
driver = initialize_paddle_driver("paddle", device, model)

View File

@ -4,14 +4,16 @@ from pathlib import Path
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from fastNLP.envs.distributed import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
if _NEED_IMPORT_PADDLE:
import paddle
from paddle.io import DataLoader, BatchSampler
if _NEED_IMPORT_TORCH:
import torch
@ -31,102 +33,70 @@ class TestPaddleDriverFunctions:
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")
@pytest.mark.torchpaddle
def test_check_single_optimizer_legality(self):
@pytest.mark.paddle
def test_check_optimizers_legality(self):
"""
测试传入单个 optimizer 时的表现
测试对合法的 optimizers 的检查
"""
# 单个 optimizer
optimizer = paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
)
self.driver.set_optimizers(optimizer)
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
# 传入torch的optimizer时应该报错ValueError
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizer)
@pytest.mark.torchpaddle
def test_check_optimizers_legality(self):
"""
测试传入 optimizer list 的表现
"""
# optimizer 列表
optimizers = [
paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
) for i in range(10)
]
self.driver.set_optimizers(optimizers)
optimizers += [
@pytest.mark.torchpaddle
def test_invalid_optimizers(self):
"""
测试传入非法的 optimizers
"""
# 单个 optimizer
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
with pytest.raises(TypeError):
self.driver.set_optimizers(optimizer)
optimizers = [
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
]
with pytest.raises(ValueError):
with pytest.raises(TypeError):
self.driver.set_optimizers(optimizers)
@pytest.mark.torchpaddle
def test_check_dataloader_legality_in_train(self):
@pytest.mark.paddle
def test_check_dataloader_legality(self):
"""
测试 `is_train` 参数为 True _check_dataloader_legality 函数的表现
测试 check_dataloader_legality 函数的表现
"""
dataloader = DataLoader(PaddleNormalDataset())
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
self.driver.check_dataloader_legality(dataloader)
# batch_size 和 batch_sampler 均为 None 的情形
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
self.driver.check_dataloader_legality(dataloader)
# 创建torch的dataloader
@pytest.mark.torchpaddle
def test_check_dataloader_legality_invalid(self):
"""
测试 check_dataloader_legality 函数传入其他类型的表现
"""
# 创建 torch 的 dataloader
dataloader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
with pytest.raises(TypeError):
self.driver.check_dataloader_legality(dataloader)
@pytest.mark.torchpaddle
def test_check_dataloader_legality_in_test(self):
"""
测试 `is_train` 参数为 False _check_dataloader_legality 函数的表现
"""
# 此时传入的应该是dict
dataloader = {
"train": DataLoader(PaddleNormalDataset()),
"test":DataLoader(PaddleNormalDataset())
}
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
# batch_size 和 batch_sampler 均为 None 的情形
dataloader = {
"train": DataLoader(PaddleNormalDataset()),
"test":DataLoader(PaddleNormalDataset(), batch_size=None)
}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
# 传入的不是 dict ,应该报错
dataloader = DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
# 创建 torch 的 dataloader
train_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
@pytest.mark.paddle
def test_tensor_to_numeric(self):
@ -505,10 +475,14 @@ class TestSetDistReproDataloader:
# 迭代两个 batch
num_consumed_batches = 2
already_seen_idx = set()
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.set_epoch(5)
else:
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(5)
for idx, batch in enumerate(replaced_loader):
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
already_seen_idx.update(batch.tolist())
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
@ -529,6 +503,7 @@ class TestSetDistReproDataloader:
)
)
new_loader.batch_sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.set_epoch(5)
else:
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
@ -537,8 +512,9 @@ class TestSetDistReproDataloader:
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.sampler.set_epoch(5)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)
left_idxes.update(batch.tolist())
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
assert len(left_idxes | already_seen_idx) == len(self.dataset)
@ -549,7 +525,7 @@ class TestSetDistReproDataloader:
#
############################################################################
def generate_random_driver(features, labels, fp16=False, device="cpu"):
def generate_random_driver(labels, features, fp16=False, device="cpu"):
"""
生成driver
"""
@ -569,9 +545,9 @@ def test_save_and_load_model(only_state_dict):
"""
try:
path = "model"
dataset = PaddleRandomMaxDataset(40, 10)
dataset = PaddleNormalXYDataset(20)
dataloader = DataLoader(dataset, batch_size=4)
driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu")
driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu")
if only_state_dict:
driver1.save_model(path, only_state_dict)
@ -580,6 +556,7 @@ def test_save_and_load_model(only_state_dict):
driver2.load_model(path, only_state_dict)
for batch in dataloader:
print("?")
batch = driver1.move_data_to_device(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
@ -604,22 +581,23 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
try:
path = "model.ckp"
dataset = PaddleRandomMaxDataset(40, 10)
dataset = PaddleNormalXYDataset(40)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu")
num_consumed_batches = 2
already_seen_x_set = set()
already_seen_y_set = set()
driver1.set_sampler_epoch(dataloader, 3)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
@ -656,10 +634,11 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
driver2.set_sampler_epoch(replaced_loader, 3)
for idx, batch in enumerate(replaced_loader):
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert paddle.equal_all(res1["pred"], res2["pred"])
@ -679,14 +658,14 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict, fp16):
"""
测试save和load函数主要测试 dataloader 被替换了 batch_sampler 的情况
测试save和load函数主要测试 dataloader 被替换了 sampler 的情况
"""
try:
path = "model.ckp"
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
dataset = PaddleRandomMaxDataset(40, 10)
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu")
dataset = PaddleNormalXYDataset(40)
batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(dataset, True)
dataloader = DataLoader(
@ -697,11 +676,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
already_seen_x_set = set()
already_seen_y_set = set()
driver1.set_sampler_epoch(dataloader, 3)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
@ -743,10 +723,11 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
driver1.set_sampler_epoch(replaced_loader, 3)
for idx, batch in enumerate(replaced_loader):
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert paddle.equal_all(res1["pred"], res2["pred"])

View File

@ -10,7 +10,7 @@ from fastNLP.core.samplers import (
UnrepeatedSequentialSampler,
)
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.distributed import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
@ -19,8 +19,8 @@ if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"):
torch_model = TorchNormalModel_Classification_1(labels, features)
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
device = [torch.device(i) for i in device]
driver = TorchDDPDriver(
@ -504,10 +504,14 @@ class TestSetDistReproDataloader:
num_replicas = len(self.device)
num_consumed_batches = 2
already_seen_idx = set()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
sampler_states = replaced_loader.batch_sampler.set_epoch(4)
else:
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4)
for idx, batch in enumerate(replaced_loader):
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
already_seen_idx.update(batch.tolist())
dist.barrier()
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
@ -533,6 +537,7 @@ class TestSetDistReproDataloader:
pad=True
)
new_loader.batch_sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.set_epoch(4)
else:
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
@ -543,8 +548,9 @@ class TestSetDistReproDataloader:
rank=driver.global_rank
)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.sampler.set_epoch(4)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)
left_idxes.update(batch.tolist())
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
@ -562,7 +568,7 @@ class TestSaveLoad:
"""
def setup_method(self):
self.dataset = TorchArgMaxDataset(10, 20)
self.dataset = TorchNormalXYDataset(20)
@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@ -574,7 +580,7 @@ class TestSaveLoad:
path = "model"
dataloader = DataLoader(self.dataset, batch_size=2)
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)
driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1)
driver1.save_model(path, only_state_dict)
@ -618,8 +624,8 @@ class TestSaveLoad:
path = "model.ckp"
num_replicas = len(device)
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False)
driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \
generate_driver(20, 1, device=device, fp16=False)
dataloader = dataloader_with_bucketedbatchsampler(
self.dataset,
length=[10 for i in range(len(self.dataset))],
@ -636,11 +642,12 @@ class TestSaveLoad:
already_seen_x_set = set()
already_seen_y_set = set()
driver1.set_sampler_epoch(dataloader, 4)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
# 同步
dist.barrier()
@ -665,7 +672,6 @@ class TestSaveLoad:
pad=True
)
dist.barrier()
print("========load=======", driver1.global_rank, driver2.global_rank)
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
dist.barrier()
replaced_loader = load_states.pop("dataloader")
@ -690,10 +696,11 @@ class TestSaveLoad:
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
driver2.set_sampler_epoch(replaced_loader, 4)
for idx, batch in enumerate(replaced_loader):
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
res1 = driver1.model(
batch,
fastnlp_fn=driver1.model.module.model.evaluate_step,
@ -716,7 +723,6 @@ class TestSaveLoad:
dist.barrier()
finally:
rank_zero_rm(path)
print("=======delete======")
if dist.is_initialized():
dist.destroy_process_group()
@ -735,8 +741,8 @@ class TestSaveLoad:
num_replicas = len(device)
driver1 = generate_driver(10, 10, device=device, fp16=fp16)
driver2 = generate_driver(10, 10, device=device, fp16=False)
driver1 = generate_driver(20, 1, device=device, fp16=fp16)
driver2 = generate_driver(20, 1, device=device, fp16=False)
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
dataloader.batch_sampler.sampler.set_distributed(
@ -748,11 +754,12 @@ class TestSaveLoad:
already_seen_x_set = set()
already_seen_y_set = set()
driver1.set_sampler_epoch(dataloader, 4)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
# 同步
dist.barrier()
@ -797,10 +804,11 @@ class TestSaveLoad:
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
driver2.set_sampler_epoch(replaced_loader, 4)
for idx, batch in enumerate(replaced_loader):
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
res1 = driver1.model(
batch,
fastnlp_fn=driver1.model.module.model.evaluate_step,

View File

@ -14,7 +14,7 @@ else:
@pytest.mark.torch
def test_incorrect_driver():
model = TorchNormalModel_Classification_1(2, 100)
model = TorchNormalModel_Classification_1(20, 10)
with pytest.raises(ValueError):
driver = initialize_torch_driver("paddle", 0, model)
@ -33,7 +33,7 @@ def test_get_single_device(driver, device):
测试正常情况下初始化TorchSingleDriver的情况
"""
model = TorchNormalModel_Classification_1(2, 100)
model = TorchNormalModel_Classification_1(20, 10)
driver = initialize_torch_driver(driver, device, model)
assert isinstance(driver, TorchSingleDriver)
@ -52,7 +52,7 @@ def test_get_ddp(driver, device):
测试 ddp 多卡的初始化情况
"""
model = TorchNormalModel_Classification_1(64, 10)
model = TorchNormalModel_Classification_1(20, 10)
driver = initialize_torch_driver(driver, device, model)
assert isinstance(driver, TorchDDPDriver)
@ -70,6 +70,6 @@ def test_device_out_of_range(driver, device):
"""
测试传入的device超过范围的情况
"""
model = TorchNormalModel_Classification_1(2, 100)
model = TorchNormalModel_Classification_1(20, 10)
with pytest.raises(ValueError):
driver = initialize_torch_driver(driver, device, model)

View File

@ -6,7 +6,7 @@ from pkg_resources import parse_version
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from fastNLP.envs.distributed import rank_zero_rm
@ -15,6 +15,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch.utils.data import DataLoader, BatchSampler
if _NEED_IMPORT_PADDLE:
import paddle
@ -67,95 +68,67 @@ class TestTorchDriverFunctions:
model = TorchNormalModel_Classification_1(10, 32)
self.driver = TorchSingleDriver(model, device="cpu")
@pytest.mark.torchpaddle
def test_check_single_optimizer_legality(self):
@pytest.mark.torch
def test_check_optimizers_legality(self):
"""
测试传入单个 optimizer 时的表现
测试对合法 optimizers 的检查
"""
# 单个 optimizer
optimizer = torch.optim.Adam(
params=self.driver.model.parameters(),
lr=0.01
)
self.driver.set_optimizers(optimizer)
optimizer = paddle.optimizer.Adam(
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
learning_rate=0.01,
)
# 传入 torch 的 optimize r时应该报错 ValueError
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizer)
@pytest.mark.torchpaddle
def test_check_optimizers_legality(self):
"""
测试传入 optimizer list 的表现
"""
# 列表
optimizers = [
torch.optim.Adam(
params=self.driver.model.parameters(),
lr=0.01
) for i in range(10)
]
self.driver.set_optimizers(optimizers)
optimizers += [
@pytest.mark.torchpaddle
def test_invalid_optimizers(self):
"""
测试传入非法的 optimizers
"""
optimizer = paddle.optimizer.Adam(
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
learning_rate=0.01,
)
with pytest.raises(TypeError):
self.driver.set_optimizers(optimizer)
optimizers = [
paddle.optimizer.Adam(
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
learning_rate=0.01,
)
]
with pytest.raises(ValueError):
with pytest.raises(TypeError):
self.driver.set_optimizers(optimizers)
@pytest.mark.torchpaddle
def test_check_dataloader_legality_in_train(self):
@pytest.mark.torch
def test_check_dataloader_legality(self):
"""
测试 `is_train` 参数为 True _check_dataloader_legality 函数的表现
测试 check_dataloader_legality 函数的表现
"""
dataloader = DataLoader(TorchNormalDataset())
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
self.driver.check_dataloader_legality(dataloader)
@pytest.mark.torchpaddle
def test_check_dataloader_legality_invalid(self):
"""
测试 check_dataloader_legality 函数传入其他类型的表现
"""
# 创建 paddle 的 dataloader
dataloader = paddle.io.DataLoader(
PaddleNormalDataset(),
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
@pytest.mark.torchpaddle
def test_check_dataloader_legality_in_test(self):
"""
测试 `is_train` 参数为 False _check_dataloader_legality 函数的表现
"""
# 此时传入的应该是dict
dataloader = {
"train": DataLoader(TorchNormalDataset()),
"test": DataLoader(TorchNormalDataset())
}
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
# 传入的不是 dict应该报错
dataloader = DataLoader(TorchNormalDataset())
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
# 创建 paddle 的 dataloader
train_loader = paddle.io.DataLoader(
PaddleNormalDataset(),
batch_size=32, shuffle=True
)
test_loader = paddle.io.DataLoader(
PaddleNormalDataset(),
batch_size=32, shuffle=True
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
with pytest.raises(TypeError):
self.driver.check_dataloader_legality(dataloader)
@pytest.mark.torch
def test_tensor_to_numeric(self):
@ -515,10 +488,14 @@ class TestSetDistReproDataloader:
# 迭代两个 batch
num_consumed_batches = 2
already_seen_idx = set()
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
replaced_loader.batch_sampler.set_epoch(3)
else:
replaced_loader.batch_sampler.sampler.set_epoch(3)
for idx, batch in enumerate(replaced_loader):
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
already_seen_idx.update(batch.tolist())
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
@ -532,14 +509,16 @@ class TestSetDistReproDataloader:
# 重新改造 dataloader
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False)
new_loader.batch_sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.set_epoch(3)
else:
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# 重新构造 dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
new_loader.batch_sampler.sampler.set_epoch(3)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)
left_idxes.update(batch.tolist())
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
assert len(left_idxes | already_seen_idx) == len(self.dataset)
@ -550,7 +529,7 @@ class TestSetDistReproDataloader:
#
############################################################################
def generate_random_driver(features, labels, fp16=False, device="cpu"):
def generate_random_driver(labels, features, fp16=False, device="cpu"):
"""
生成driver
"""
@ -570,9 +549,9 @@ def test_save_and_load_model(only_state_dict):
"""
try:
path = "model"
dataset = TorchArgMaxDataset(10, 40)
dataset = TorchNormalXYDataset(20)
dataloader = DataLoader(dataset, batch_size=4)
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1)
driver1.save_model(path, only_state_dict)
driver2.load_model(path, only_state_dict)
@ -596,19 +575,20 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
try:
path = "model.ckp"
dataset = TorchArgMaxDataset(10, 40)
dataset = TorchNormalXYDataset(20)
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False)
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda")
driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda")
num_consumed_batches = 2
already_seen_x_set = set()
already_seen_y_set = set()
driver1.set_sampler_epoch(dataloader, 3)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
@ -639,11 +619,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
driver1.set_sampler_epoch(replaced_loader, 3)
for idx, batch in enumerate(replaced_loader):
batch = driver2.move_data_to_device(batch)
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert torch.equal(res1["preds"], res2["preds"])
@ -660,24 +641,25 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict, fp16):
"""
测试save和load函数主要测试 dataloader 被替换了 batch_sampler 的情况
测试save和load函数主要测试 dataloader 被替换了 sampler 的情况
"""
try:
path = "model.ckp"
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda")
dataset = TorchArgMaxDataset(10, 40)
driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda")
dataset = TorchNormalXYDataset(40)
dataloader = dataloader_with_randomsampler(dataset, 4, True, False)
num_consumed_batches = 2
already_seen_x_set = set()
already_seen_y_set = set()
driver1.set_sampler_epoch(dataloader, 3)
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
@ -711,11 +693,13 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
# set epoch
driver2.set_sampler_epoch(replaced_loader, 3)
for idx, batch in enumerate(replaced_loader):
batch = driver2.move_data_to_device(batch)
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert torch.equal(res1["preds"], res2["preds"])

View File

@ -0,0 +1,46 @@
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
import jittor as jt
from jittor.dataset import Dataset
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
class JittorNormalDataset(Dataset):
    """
    Minimal jittor dataset yielding the plain integers ``0 .. num_of_data-1``.
    """
    def __init__(self, num_of_data=100, **kwargs):
        super(JittorNormalDataset, self).__init__(**kwargs)
        # Backing store: consecutive integers starting at 0.
        self._data = [i for i in range(num_of_data)]
        # jittor datasets report their length through the total_len attribute.
        self.set_attrs(total_len=num_of_data)
    def __getitem__(self, item):
        # Return the raw integer stored at the requested index.
        return self._data[item]
class JittorNormalXYDataset(Dataset):
    """
    可以被输入到分类模型中的普通数据集

    Produces paired single-element jittor Vars where x == y, so model
    behaviour is easy to verify in tests.
    """
    def __init__(self, num_of_data=1000, **kwargs):
        super(JittorNormalXYDataset, self).__init__(**kwargs)
        self.num_of_data = num_of_data
        # Consecutive integers; each sample wraps one of them.
        self._data = [i for i in range(num_of_data)]
        # jittor datasets report their length through the total_len attribute.
        self.set_attrs(total_len=num_of_data)
    def __getitem__(self, item):
        value = self._data[item]
        # x and y carry the same value on purpose.
        return {"x": jt.Var([value]), "y": jt.Var([value])}
class JittorArgMaxDataset(Dataset):
    """
    Random-feature dataset whose label is the argmax over the feature axis.
    """
    def __init__(self, num_samples, num_features, **kwargs):
        super(JittorArgMaxDataset, self).__init__(**kwargs)
        # Random features; the target is the index of the largest feature.
        self.x = jt.randn(num_samples, num_features)
        self.y = self.x.argmax(dim=-1)
        # jittor datasets report their length through the total_len attribute.
        self.set_attrs(total_len=num_samples)
    def __getitem__(self, item):
        sample = {"x": self.x[item], "y": self.y[item]}
        return sample
if __name__ == "__main__":
    # Quick smoke check: the default dataset should report 100 samples.
    demo = JittorNormalDataset()
    print(len(demo))

View File

@ -19,8 +19,24 @@ class PaddleNormalDataset(Dataset):
def __getitem__(self, item):
return self._data[item]
class PaddleNormalXYDataset(Dataset):
"""
可以被输入到分类模型中的普通数据集
"""
def __init__(self, num_of_data=1000):
self.num_of_data = num_of_data
self._data = list(range(num_of_data))
class PaddleRandomMaxDataset(Dataset):
def __len__(self):
return self.num_of_data
def __getitem__(self, item):
return {
"x": paddle.to_tensor([self._data[item]], dtype="float32"),
"y": paddle.to_tensor([self._data[item]], dtype="float32")
}
class PaddleArgMaxDataset(Dataset):
def __init__(self, num_samples, num_features):
self.x = paddle.randn((num_samples, num_features))
self.y = self.x.argmax(axis=-1)

View File

@ -1,4 +1,6 @@
from functools import reduce
from numpy import dtype
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
@ -19,6 +21,23 @@ class TorchNormalDataset(Dataset):
def __getitem__(self, item):
return self._data[item]
class TorchNormalXYDataset(Dataset):
    """
    可以被输入到分类模型中的普通数据集

    Produces paired single-element float tensors where x == y, so model
    behaviour is easy to verify in tests.
    """
    def __init__(self, num_of_data=1000):
        self.num_of_data = num_of_data
        # Consecutive integers; each sample wraps one of them.
        self._data = [i for i in range(num_of_data)]
    def __len__(self):
        return self.num_of_data
    def __getitem__(self, item):
        value = self._data[item]
        # x and y carry the same value on purpose.
        return {
            "x": torch.tensor([value], dtype=torch.float),
            "y": torch.tensor([value], dtype=torch.float),
        }
# 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据;
class TorchNormalDataset_Classification(Dataset):

View File

@ -0,0 +1,57 @@
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
from jittor import Module, nn
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module
class JittorNormalModel_Classification_1(Module):
    """
    基础的 jittor 分类模型

    Two ReLU-activated hidden layers (64 then 32 units) followed by a linear
    output layer; cross-entropy loss is used for training.
    """
    def __init__(self, num_labels, feature_dimension):
        super(JittorNormalModel_Classification_1, self).__init__()
        self.num_labels = num_labels
        # feature_dimension -> 64 -> 32 -> num_labels
        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
        self.ac1 = nn.ReLU()
        self.linear2 = nn.Linear(in_features=64, out_features=32)
        self.ac2 = nn.ReLU()
        self.output = nn.Linear(in_features=32, out_features=num_labels)
        self.loss_fn = nn.CrossEntropyLoss()
    def execute(self, x):
        # Forward pass: two ReLU-activated hidden layers, then raw logits.
        hidden = self.ac1(self.linear1(x))
        hidden = self.ac2(self.linear2(hidden))
        return self.output(hidden)
    def train_step(self, x, y):
        # fastNLP training hook: must return the loss under the key "loss".
        logits = self(x)
        return {"loss": self.loss_fn(logits, y)}
    def evaluate_step(self, x, y):
        # fastNLP evaluation hook: raw logits plus flattened targets.
        logits = self(x)
        return {"pred": logits, "target": y.reshape((-1,))}
class JittorNormalModel_Classification_2(Module):
    """
    基础的 jittor 分类模型只实现 execute 函数测试用户自己初始化了分布式的场景

    Same architecture as JittorNormalModel_Classification_1, but everything
    (loss, predictions, targets) is returned from a single ``execute`` call.
    """
    def __init__(self, num_labels, feature_dimension):
        super(JittorNormalModel_Classification_2, self).__init__()
        self.num_labels = num_labels
        # feature_dimension -> 64 -> 32 -> num_labels
        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
        self.ac1 = nn.ReLU()
        self.linear2 = nn.Linear(in_features=64, out_features=32)
        self.ac2 = nn.ReLU()
        self.output = nn.Linear(in_features=32, out_features=num_labels)
        self.loss_fn = nn.CrossEntropyLoss()
    def execute(self, x, y):
        # Single entry point returning loss, predictions and targets at once.
        hidden = self.ac1(self.linear1(x))
        hidden = self.ac2(self.linear2(hidden))
        logits = self.output(hidden)
        return {"loss": self.loss_fn(logits, y), "pred": logits,
                "target": y.reshape((-1,))}

View File

@ -8,7 +8,7 @@ else:
class PaddleNormalModel_Classification_1(Layer):
"""
基础的paddle分类模型
基础的 paddle 分类模型
"""
def __init__(self, num_labels, feature_dimension):
super(PaddleNormalModel_Classification_1, self).__init__()
@ -39,7 +39,7 @@ class PaddleNormalModel_Classification_1(Layer):
class PaddleNormalModel_Classification_2(Layer):
"""
基础的paddle分类模型只实现 forward 函数测试用户自己初始化了分布式的场景
基础的 paddle 分类模型只实现 forward 函数测试用户自己初始化了分布式的场景
"""
def __init__(self, num_labels, feature_dimension):
super(PaddleNormalModel_Classification_2, self).__init__()
@ -56,5 +56,4 @@ class PaddleNormalModel_Classification_2(Layer):
x = self.ac1(self.linear1(x))
x = self.ac2(self.linear2(x))
x = self.output(x)
loss = self.loss_fn(x, y)
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))}

View File

@ -33,7 +33,11 @@ class TestPaddle2Torch:
"""
assert isinstance(tensor, torch.Tensor)
assert tensor.device == torch.device(device)
if device == "cpu":
assert not tensor.is_cuda
else:
assert tensor.is_cuda
assert tensor.device.index == torch.device(device).index
assert tensor.requires_grad == requires_grad
def test_gradient(self):
@ -261,7 +265,8 @@ class TestJittor2Torch:
if device == "cpu":
assert not tensor.is_cuda
else:
assert tensor.device == torch.device(device)
assert tensor.is_cuda
assert tensor.device.index == torch.device(device).index
assert tensor.requires_grad == requires_grad
def test_var_transfer(self):
@ -271,7 +276,10 @@ class TestJittor2Torch:
jittor_var = jittor.rand((3, 4, 5))
res = jittor2torch(jittor_var)
self.check_torch_tensor(res, "cpu", True)
if jittor.flags.use_cuda:
self.check_torch_tensor(res, "cuda:0", True)
else:
self.check_torch_tensor(res, "cpu", True)
res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None)
self.check_torch_tensor(res, "cuda:2", True)
@ -291,7 +299,10 @@ class TestJittor2Torch:
res = jittor2torch(jittor_list)
assert isinstance(res, list)
for t in res:
self.check_torch_tensor(t, "cpu", True)
if jittor.flags.use_cuda:
self.check_torch_tensor(t, "cuda:0", True)
else:
self.check_torch_tensor(t, "cpu", True)
res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False)
assert isinstance(res, list)
@ -327,17 +338,29 @@ class TestJittor2Torch:
}
res = jittor2torch(jittor_dict)
assert isinstance(res, dict)
self.check_torch_tensor(res["tensor"], "cpu", True)
if jittor.flags.use_cuda:
self.check_torch_tensor(res["tensor"], "cuda:0", True)
else:
self.check_torch_tensor(res["tensor"], "cpu", True)
assert isinstance(res["list"], list)
for t in res["list"]:
self.check_torch_tensor(t, "cpu", True)
if jittor.flags.use_cuda:
self.check_torch_tensor(t, "cuda:0", True)
else:
self.check_torch_tensor(t, "cpu", True)
assert isinstance(res["int"], int)
assert isinstance(res["string"], str)
assert isinstance(res["dict"], dict)
assert isinstance(res["dict"]["list"], list)
for t in res["dict"]["list"]:
self.check_torch_tensor(t, "cpu", True)
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True)
if jittor.flags.use_cuda:
self.check_torch_tensor(t, "cuda:0", True)
else:
self.check_torch_tensor(t, "cpu", True)
if jittor.flags.use_cuda:
self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", True)
else:
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True)
############################################################################