mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 19:57:34 +08:00
1. 完成JittorSingleDriver的功能,并添加测试用例 2.在Sampler中添加属性num_samplers 用于动态获取dataset的长度 3.添加便于测试断点重训的数据集 4.修改jittor其它测试的一些bug,统一ArgMaxDataset 的命名
This commit is contained in:
parent
49e8ae2daa
commit
75a3278d69
@ -24,7 +24,6 @@ from fastNLP.core.dataset import DataSet as FDataSet
|
||||
class _JittorDataset(Dataset):
|
||||
"""
|
||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dataset) -> None:
|
||||
@ -83,7 +82,7 @@ class JittorDataLoader:
|
||||
# TODO 验证支持replacesampler (以后完成) 增加Sampler
|
||||
# 将内部dataset批次设置为1
|
||||
if isinstance(dataset, Dataset):
|
||||
dataset.set_attrs(batch_size=1)
|
||||
dataset.set_attrs(batch_size=1, shuffle=False, endless=False)
|
||||
|
||||
# FastNLP Datset, collate_fn not None
|
||||
if isinstance(dataset, FDataSet) and collate_fn is None:
|
||||
@ -115,6 +114,12 @@ class JittorDataLoader:
|
||||
|
||||
self.cur_batch_indices = None
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in ["batch_size", "shuffle", "drop_last", "num_workers", "buffer_size", "stop_grad",
|
||||
"keep_numpy_array", "endless", "sampler"]:
|
||||
return getattr(self.dataset, attr)
|
||||
raise AttributeError(f"{self} has not attribute '{attr}'")
|
||||
|
||||
def __iter__(self):
|
||||
# TODO 第一次迭代后不能设置collate_fn,设置是无效的
|
||||
if self.cur_batch_indices is None:
|
||||
|
@ -10,7 +10,7 @@ if _NEED_IMPORT_JITTOR:
|
||||
|
||||
__all__ = []
|
||||
|
||||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver:
|
||||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: "jittor.Module", **kwargs) -> JittorDriver:
|
||||
r"""
|
||||
用来根据参数 ``device`` 来确定并且初始化一个具体的 ``Driver`` 实例然后返回回去。
|
||||
|
||||
@ -30,7 +30,7 @@ def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], mo
|
||||
raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].")
|
||||
|
||||
# TODO 实现更详细的判断
|
||||
if device in ["cpu", "gpu", "cuda", "cuda:0", 0, None]:
|
||||
if device in ["cpu", "gpu", "cuda", None]:
|
||||
return JittorSingleDriver(model, device, **kwargs)
|
||||
elif type(device) is int:
|
||||
return JittorMPIDriver(model, device, **kwargs)
|
||||
|
@ -1,23 +1,31 @@
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Union, Optional
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from typing import Union, Optional, Dict
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
from fastNLP.core.drivers.driver import Driver
|
||||
from fastNLP.core.dataloaders import JittorDataLoader
|
||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler
|
||||
from fastNLP.core.log import logger
|
||||
from fastNLP.core.utils import apply_to_collection
|
||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_SEED_WORKERS
|
||||
from fastNLP.envs import (
|
||||
FASTNLP_MODEL_FILENAME,
|
||||
FASTNLP_CHECKPOINT_FILENAME,
|
||||
)
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor as jt
|
||||
from jittor import Module
|
||||
from jittor.optim import Optimizer
|
||||
from jittor.dataset import Dataset
|
||||
from jittor.dataset import (
|
||||
BatchSampler as JittorBatchSampler,
|
||||
Sampler as JittorSampler,
|
||||
RandomSampler as JittorRandomSampler,
|
||||
SequentialSampler as JittorSequentialSampler
|
||||
)
|
||||
|
||||
_reduces = {
|
||||
'max': jt.max,
|
||||
@ -56,6 +64,7 @@ class JittorDriver(Driver):
|
||||
else:
|
||||
jt.flags.auto_mixed_precision_level = 0
|
||||
self.fp16 = fp16
|
||||
self._auto_cast = nullcontext
|
||||
|
||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
|
||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
|
||||
@ -68,7 +77,7 @@ class JittorDriver(Driver):
|
||||
def _check_optimizer_legality(optimizers):
|
||||
for each_optimizer in optimizers:
|
||||
if not isinstance(each_optimizer, Optimizer):
|
||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
|
||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, "
|
||||
f"not {type(each_optimizer)}.")
|
||||
|
||||
def step(self):
|
||||
@ -117,30 +126,118 @@ class JittorDriver(Driver):
|
||||
model = self.unwrap_model()
|
||||
model.load(filepath)
|
||||
|
||||
def save_checkpoint(self):
|
||||
...
|
||||
def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
if dataloader_args.sampler:
|
||||
sampler = dataloader_args.sampler
|
||||
else:
|
||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
|
||||
|
||||
num_consumed_batches = states.pop('num_consumed_batches')
|
||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
|
||||
sampler_states = sampler.state_dict()
|
||||
# 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples
|
||||
# 会造成多余实际消耗的问题。因为
|
||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
|
||||
if num_consumed_samples_array is not None:
|
||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。
|
||||
if dataloader_args.batch_size is not None:
|
||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
|
||||
else: # 有可能 batch_size 为 None,就只有损失精度了
|
||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
|
||||
"it may cause missing some samples when reload.")
|
||||
num_consumed_batches = sampler_states['num_consumed_samples']
|
||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
|
||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
|
||||
else:
|
||||
if dataloader_args.batch_size is not None:
|
||||
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
|
||||
* num_consumed_batches
|
||||
else:
|
||||
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
|
||||
"it may cause missing some samples when reload.")
|
||||
|
||||
states['sampler_states'] = sampler_states
|
||||
else:
|
||||
raise RuntimeError('The sampler has no `state_dict()` method, fastNLP cannot save the training '
|
||||
'state.')
|
||||
|
||||
# 2. 保存模型的状态;
|
||||
if should_save_model:
|
||||
if not os.path.exists(folder):
|
||||
os.mkdir(folder)
|
||||
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
|
||||
self.save_model(model_path, only_state_dict=only_state_dict)
|
||||
|
||||
# 3. 保存 optimizers 的状态;
|
||||
states["optimizers_state_dict"] = self.get_optimizer_state()
|
||||
|
||||
# 4. 保存fp16的状态
|
||||
|
||||
logger.debug("Save optimizer state dict")
|
||||
jt.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
|
||||
|
||||
def get_optimizer_state(self):
|
||||
# optimizers_state_dict = {}
|
||||
# for i in range(len(self.optimizers)):
|
||||
# optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
# optimizer_state = optimizer.state_dict()
|
||||
# optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
|
||||
# optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
# return optimizers_state_dict
|
||||
...
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer.state_dict() # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
return optimizers_state_dict
|
||||
|
||||
def load_optimizer_state(self, states):
|
||||
# assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
|
||||
# f"checkpoint it is:{len(states)}"
|
||||
# for i in range(len(self.optimizers)):
|
||||
# optimizer: torch.optim.Optimizer = self.optimizers[i]
|
||||
# optimizer.load_state_dict(states[f"optimizer{i}"])
|
||||
# logger.debug("Load optimizer state dict.")
|
||||
...
|
||||
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
|
||||
f"checkpoint it is:{len(states)}"
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer.load_state_dict(states[f"optimizer{i}"])
|
||||
logger.debug("Load optimizer state dict.")
|
||||
|
||||
def load_checkpoint(self):
|
||||
...
|
||||
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
|
||||
|
||||
states = jt.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)))
|
||||
|
||||
# 1. 加载 optimizers 的状态;
|
||||
optimizers_state_dict = states.pop("optimizers_state_dict")
|
||||
self.load_optimizer_state(optimizers_state_dict)
|
||||
|
||||
# 2. 加载模型状态;
|
||||
if should_load_model:
|
||||
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)
|
||||
|
||||
# 3. 加载fp16的状态
|
||||
|
||||
# 4. 恢复 sampler 的状态;
|
||||
dataloader_args = self.get_dataloader_args(dataloader)
|
||||
if dataloader_args.sampler is None:
|
||||
sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=dataloader_args.shuffle)
|
||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler):
|
||||
sampler = dataloader_args.sampler
|
||||
elif isinstance(dataloader_args.sampler, JittorRandomSampler):
|
||||
sampler = RandomSampler(dataloader_args.sampler.dataset)
|
||||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
|
||||
elif isinstance(dataloader_args.sampler, JittorSequentialSampler):
|
||||
sampler = RandomSampler(dataloader_args.sampler.dataset, shuffle=False)
|
||||
logger.debug("Replace jittor Sampler into fastNLP RandomSampler without shuffle.")
|
||||
elif self.is_distributed():
|
||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our"
|
||||
"`ReproducibleSampler`.")
|
||||
else:
|
||||
raise RuntimeError(f"Jittor sampler {type(dataloader_args.sampler)} is not supported now.")
|
||||
sampler.load_state_dict(states.pop('sampler_states'))
|
||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
|
||||
|
||||
# 4. 修改 trainer_state.batch_idx_in_epoch
|
||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
|
||||
if dataloader_args.drop_last:
|
||||
batch_idx_in_epoch = len(
|
||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
|
||||
else:
|
||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
|
||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
|
||||
|
||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch
|
||||
|
||||
return states
|
||||
|
||||
def get_evaluate_context(self):
|
||||
return jt.no_grad
|
||||
@ -198,26 +295,8 @@ class JittorDriver(Driver):
|
||||
"""
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
|
||||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
|
||||
process_seed = jt.get_seed()
|
||||
# back out the base seed so we can use all the bits
|
||||
base_seed = process_seed - worker_id
|
||||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
|
||||
# use 128 bits (4 x 32-bit words)
|
||||
np.random.seed(ss.generate_state(4))
|
||||
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
|
||||
jittor_ss, stdlib_ss = ss.spawn(2)
|
||||
jt.set_global_seed(jittor_ss.generate_state(1, dtype=np.uint64)[0])
|
||||
# use 128 bits expressed as an integer
|
||||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
|
||||
random.seed(stdlib_seed)
|
||||
|
||||
def set_deterministic_dataloader(self, dataloader: Union["JittorDataLoader", "Dataset"]):
|
||||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None:
|
||||
dataloader.worker_init_fn = partial(self.worker_init_function,
|
||||
rank=int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)))
|
||||
...
|
||||
|
||||
def set_sampler_epoch(self, dataloader: Union["JittorDataLoader", "Dataset"], cur_epoch_idx: int):
|
||||
# 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的;
|
||||
@ -226,4 +305,45 @@ class JittorDriver(Driver):
|
||||
|
||||
@staticmethod
|
||||
def get_dataloader_args(dataloader: Union["JittorDataLoader", "Dataset"]):
|
||||
pass
|
||||
@dataclass
|
||||
class Res:
|
||||
dataset: Optional[Dataset] = None
|
||||
batch_sampler: Optional[JittorBatchSampler] = None
|
||||
sampler: Optional[JittorSampler] = None
|
||||
batch_size: Optional[int] = None
|
||||
shuffle: Optional[bool] = None
|
||||
drop_last: Optional[bool] = None
|
||||
|
||||
res = Res()
|
||||
from fastNLP.core.dataloaders.jittor_dataloader.fdl import _JittorDataset
|
||||
if isinstance(dataloader, JittorDataLoader):
|
||||
# JittorDataLoader 实际上是迭代 dataset 成员的
|
||||
dataloader = dataloader.dataset
|
||||
if isinstance(dataloader, _JittorDataset):
|
||||
# 获取最原始的 dataset
|
||||
res.dataset = dataloader.dataset
|
||||
else:
|
||||
res.dataset = dataloader
|
||||
|
||||
# jittor 现在不支持 batch_sampler,所以除了 shuffle 都可以直接获取
|
||||
res.batch_size = dataloader.batch_size
|
||||
res.drop_last = dataloader.drop_last
|
||||
if dataloader.sampler is None:
|
||||
# sampler 是 None,那么就从 Dataset 的属性中获取
|
||||
res.shuffle = dataloader.shuffle
|
||||
elif isinstance(list(dataloader.sampler.__iter__())[0], (list,tuple)):
|
||||
# jittor 目前不支持 batch_sampler
|
||||
raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, "
|
||||
"please check if you have set `Dataset.sampler` as `BatchSampler`")
|
||||
else:
|
||||
# sampler 不为 None
|
||||
res.sampler = dataloader.sampler
|
||||
if hasattr(dataloader.sampler, "shuffle"):
|
||||
# 这种情况一般出现在 fastNLP 的 ReproduceSampler 中
|
||||
res.shuffle = dataloader.sampler.shuffle
|
||||
elif isinstance(dataloader.sampler, JittorRandomSampler):
|
||||
res.shuffle = True
|
||||
else:
|
||||
res.shuffle = False
|
||||
|
||||
return res
|
@ -38,6 +38,7 @@ class JittorMPIDriver(JittorDriver):
|
||||
):
|
||||
|
||||
super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs)
|
||||
raise NotImplementedError("MPI for Jittor is not supported right now.")
|
||||
|
||||
self.is_pull_by_jittor_run = is_pull_by_jittor_run
|
||||
self.parallel_device = parallel_device
|
||||
@ -100,22 +101,6 @@ class JittorMPIDriver(JittorDriver):
|
||||
return self._data_device
|
||||
return self.parallel_device
|
||||
|
||||
def step(self):
|
||||
# for optimizer in self.optimizers:
|
||||
# self.grad_scaler.step(optimizer)
|
||||
# self.grad_scaler.update()
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.step()
|
||||
|
||||
def backward(self, loss):
|
||||
# self.grad_scaler.scale(loss).backward()
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.backward(loss)
|
||||
|
||||
def zero_grad(self):
|
||||
for optimizer in self.optimizers:
|
||||
optimizer.zero_grad()
|
||||
|
||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
|
||||
if isinstance(batch, Dict) and not self.wo_auto_param_call:
|
||||
return auto_param_call(fn, batch, signature_fn=signature_fn)
|
||||
|
@ -1,14 +1,21 @@
|
||||
from typing import Dict, Union, Tuple, Callable, Optional
|
||||
|
||||
from .jittor_driver import JittorDriver
|
||||
from .utils import replace_batch_sampler, replace_sampler
|
||||
from fastNLP.core.utils import auto_param_call
|
||||
from fastNLP.core.utils.utils import _get_fun_msg
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
|
||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
|
||||
ReproduceBatchSampler
|
||||
from fastNLP.core.samplers import RandomSampler
|
||||
from fastNLP.core.log import logger
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor as jt
|
||||
from jittor.dataset import (
|
||||
RandomSampler as JittorRandomSampler,
|
||||
SequentialSampler as JittorSequentialSampler,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JittorSingleDriver",
|
||||
@ -89,31 +96,46 @@ class JittorSingleDriver(JittorDriver):
|
||||
"""
|
||||
return False
|
||||
|
||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
|
||||
reproducible: bool = False, sampler_or_batch_sampler=None):
|
||||
# reproducible 的相关功能暂时没有实现
|
||||
def set_dist_repro_dataloader(self, dataloader,
|
||||
dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
|
||||
reproducible: bool = False):
|
||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load_checkpoint 函数调用;
|
||||
if isinstance(dist, ReproducibleBatchSampler):
|
||||
raise NotImplementedError
|
||||
dataloader.batch_sampler = dist_sample
|
||||
if isinstance(dist, ReproducibleSampler):
|
||||
raise NotImplementedError
|
||||
dataloader.batch_sampler.sampler = dist
|
||||
return replace_batch_sampler(dataloader, dist)
|
||||
elif isinstance(dist, ReproducibleSampler):
|
||||
return replace_sampler(dataloader, dist)
|
||||
|
||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用;
|
||||
args = self.get_dataloader_args(dataloader)
|
||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler):
|
||||
batch_sampler = re_instantiate_sampler(args.batch_sampler)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
elif isinstance(args.sampler, ReproducibleSampler):
|
||||
sampler = re_instantiate_sampler(args.sampler)
|
||||
return replace_sampler(dataloader, sampler)
|
||||
|
||||
if reproducible:
|
||||
raise NotImplementedError
|
||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
|
||||
return dataloader
|
||||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
|
||||
return dataloader
|
||||
else:
|
||||
# TODO
|
||||
batch_sampler = RandomBatchSampler(
|
||||
batch_sampler=dataloader.batch_sampler,
|
||||
batch_size=dataloader.batch_sampler.batch_size,
|
||||
drop_last=dataloader.drop_last
|
||||
)
|
||||
dataloader.batch_sampler = batch_sampler
|
||||
return dataloader
|
||||
if args.sampler is None:
|
||||
sampler = RandomSampler(args.dataset, args.shuffle)
|
||||
return replace_sampler(dataloader, sampler)
|
||||
elif isinstance(args.sampler, JittorRandomSampler):
|
||||
if getattr(args.sampler, '_num_samples', None) is None \
|
||||
and getattr(args.sampler, 'rep', False) is False:
|
||||
# 如果本来就是随机的,并且没有定制,直接替换掉吧。
|
||||
sampler = RandomSampler(args.sampler.dataset, shuffle=True)
|
||||
logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.")
|
||||
return replace_sampler(dataloader, sampler)
|
||||
elif isinstance(args.sampler, JittorSequentialSampler):
|
||||
# 需要替换为不要 shuffle 的。
|
||||
sampler = RandomSampler(args.sampler.dataset, shuffle=False)
|
||||
logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.")
|
||||
return replace_sampler(dataloader, sampler)
|
||||
batch_sampler = ReproduceBatchSampler(
|
||||
batch_sampler=args.batch_sampler,
|
||||
batch_size=args.batch_size,
|
||||
drop_last=args.drop_last
|
||||
)
|
||||
return replace_batch_sampler(dataloader, batch_sampler)
|
||||
else:
|
||||
return dataloader
|
||||
|
||||
|
@ -1,6 +1,29 @@
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from typing import Union
|
||||
|
||||
from fastNLP.core.dataloaders import JittorDataLoader
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor
|
||||
from jittor.dataset import Dataset
|
||||
|
||||
__all__ = []
|
||||
|
||||
def replace_batch_sampler(dataloader, batch_sampler):
|
||||
raise NotImplementedError("Jittor does not support using batch_sampler in `Dataset` now, "
|
||||
"please check if you have set `Dataset.sampler` as `BatchSampler`"
|
||||
"or report this bug to us.")
|
||||
|
||||
def replace_sampler(dataloader: Union["Dataset", "JittorDataLoader"], sampler):
|
||||
if isinstance(dataloader, JittorDataLoader):
|
||||
init_params = dict(inspect.signature(dataloader.__init__).parameters)
|
||||
reconstruct_args = {name: getattr(dataloader, name, p.default) for name, p in init_params.items()}
|
||||
reconstruct_args["dataset"] = replace_sampler(reconstruct_args["dataset"].dataset, reconstruct_args["dataset"].sampler)
|
||||
new_dataloader = type(dataloader)(**reconstruct_args)
|
||||
new_dataloader.dataset.set_attrs(sampler=sampler)
|
||||
else:
|
||||
new_dataloader = deepcopy(dataloader)
|
||||
new_dataloader.set_attrs(sampler=sampler)
|
||||
|
||||
return new_dataloader
|
@ -31,7 +31,6 @@ if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
from paddle.io import (
|
||||
DataLoader,
|
||||
IterableDataset,
|
||||
Dataset,
|
||||
Sampler,
|
||||
BatchSampler,
|
||||
@ -97,6 +96,9 @@ class PaddleDriver(Driver):
|
||||
def check_dataloader_legality(self, dataloader):
|
||||
if not isinstance(dataloader, DataLoader):
|
||||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
|
||||
if dataloader.batch_size is None and dataloader.batch_sampler is None:
|
||||
raise ValueError("Please ensure at least one of your dataloader's batch_size and batch_sampler"
|
||||
"is not None")
|
||||
|
||||
@staticmethod
|
||||
def _check_optimizer_legality(optimizers):
|
||||
@ -107,7 +109,7 @@ class PaddleDriver(Driver):
|
||||
"""
|
||||
for each_optimizer in optimizers:
|
||||
if not isinstance(each_optimizer, Optimizer):
|
||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
|
||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, "
|
||||
f"not {type(each_optimizer)}.")
|
||||
|
||||
@staticmethod
|
||||
@ -263,9 +265,7 @@ class PaddleDriver(Driver):
|
||||
optimizers_state_dict = {}
|
||||
for i in range(len(self.optimizers)):
|
||||
optimizer: Optimizer = self.optimizers[i]
|
||||
optimizer_state = optimizer.state_dict()
|
||||
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu")
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
|
||||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state_to_device(optimizer.state_dict(), "cpu")
|
||||
|
||||
return optimizers_state_dict
|
||||
|
||||
@ -399,6 +399,8 @@ class PaddleDriver(Driver):
|
||||
def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx):
|
||||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
|
||||
dataloader.batch_sampler.set_epoch(cur_epoch_idx)
|
||||
elif callable(getattr(dataloader.batch_sampler.sampler, "set_epoch", None)):
|
||||
dataloader.batch_sampler.sampler.set_epoch(cur_epoch_idx)
|
||||
|
||||
@staticmethod
|
||||
def get_dataloader_args(dataloader: "DataLoader"):
|
||||
|
@ -99,7 +99,7 @@ class TorchDriver(Driver):
|
||||
def _check_optimizer_legality(optimizers):
|
||||
for each_optimizer in optimizers:
|
||||
if not isinstance(each_optimizer, Optimizer):
|
||||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
|
||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, "
|
||||
f"not {type(each_optimizer)}.")
|
||||
|
||||
@staticmethod
|
||||
|
@ -210,7 +210,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
self.num_consumed_samples = 0
|
||||
self.during_iter = True
|
||||
|
||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
|
||||
indices = list(range(self.num_samples))
|
||||
|
||||
if self.shuffle:
|
||||
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的
|
||||
@ -237,7 +237,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
if len(indices)%self.batch_size!=0:
|
||||
batches.append(indices[_num_batches*self.batch_size:])
|
||||
|
||||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
|
||||
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas
|
||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
|
||||
if len(batches) > 0:
|
||||
if len(batches[-1])<self.batch_size:
|
||||
@ -290,9 +290,9 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
@property
|
||||
def batch_idx_in_epoch(self):
|
||||
if self.drop_last:
|
||||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
|
||||
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
|
||||
else:
|
||||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
|
||||
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \
|
||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size
|
||||
|
||||
@property
|
||||
@ -313,8 +313,12 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
:return:
|
||||
"""
|
||||
num_consumed_samples = self.num_consumed_samples
|
||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
|
||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))
|
||||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \
|
||||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas))
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
return getattr(self.dataset, 'total_len', len(self.dataset))
|
||||
|
||||
def __len__(self)->int:
|
||||
"""
|
||||
@ -332,7 +336,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
|
||||
" consumed. ")
|
||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
|
||||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle,
|
||||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle,
|
||||
'batch_size': self.batch_size,
|
||||
'num_replicas': self.num_replicas}
|
||||
|
||||
@ -347,7 +351,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
|
||||
f"we cannot use {self.__class__.__name__} to load it."
|
||||
|
||||
length = states['length']
|
||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
|
||||
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \
|
||||
"and current dataset."
|
||||
self.seed = states['seed']
|
||||
self.epoch = states['epoch']
|
||||
@ -464,8 +468,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
:return:
|
||||
"""
|
||||
num_consumed_samples = self.num_consumed_samples
|
||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
|
||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))
|
||||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \
|
||||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas))
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
return getattr(self.dataset, 'total_len', len(self.dataset))
|
||||
|
||||
def __len__(self)->int:
|
||||
"""
|
||||
@ -515,7 +523,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
if len(sorted_indices)%self.batch_size!=0:
|
||||
batches.append(sorted_indices[_num_batches*self.batch_size:])
|
||||
|
||||
need_pad_num = (getattr(self.dataset, 'total_len', len(self.dataset))-self.num_consumed_samples) % self.num_replicas
|
||||
need_pad_num = (self.num_samples-self.num_consumed_samples) % self.num_replicas
|
||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
|
||||
if len(batches) > 0:
|
||||
if len(batches[-1])<self.batch_size:
|
||||
@ -593,7 +601,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
|
||||
" consumed. ")
|
||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
|
||||
'sampler_type': self.__class__.__name__, 'length': getattr(self.dataset, 'total_len', len(self.dataset)), 'shuffle': self.shuffle,
|
||||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle,
|
||||
'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
|
||||
'num_replicas': self.num_replicas
|
||||
}
|
||||
@ -609,7 +617,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
f"we cannot use {self.__class__.__name__} to load it."
|
||||
|
||||
length = states['length']
|
||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), "The number of samples is different between the checkpoint record " \
|
||||
assert length == self.num_samples, "The number of samples is different between the checkpoint record " \
|
||||
"and current dataset."
|
||||
self.seed = states['seed']
|
||||
self.epoch = states['epoch']
|
||||
@ -630,7 +638,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
|
||||
@property
|
||||
def batch_idx_in_epoch(self):
|
||||
if self.drop_last:
|
||||
return getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
|
||||
return self.num_samples // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
|
||||
else:
|
||||
return (getattr(self.dataset, 'total_len', len(self.dataset)) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
|
||||
return (self.num_samples // self.num_replicas + self.batch_size - 1) // self.batch_size - \
|
||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size
|
@ -48,6 +48,10 @@ class ReproducibleSampler:
|
||||
def num_left_samples(self):
|
||||
raise NotImplementedError("Each specific sampler should implement its own `num_left_samples` method.")
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
raise NotImplementedError("Each specific sampler should implement its own `num_samples` method.")
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
pass
|
||||
|
||||
@ -131,19 +135,19 @@ class RandomSampler(ReproducibleSampler):
|
||||
:return:
|
||||
"""
|
||||
if self.shuffle:
|
||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
|
||||
indices = list(range(self.num_samples))
|
||||
seed = self.seed + self.epoch
|
||||
rng = np.random.default_rng(abs(seed))
|
||||
rng.shuffle(indices)
|
||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。
|
||||
self.epoch -= 1
|
||||
else:
|
||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
|
||||
indices = list(range(self.num_samples))
|
||||
return indices
|
||||
|
||||
def state_dict(self) -> Dict:
|
||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
|
||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle}
|
||||
'sampler_type': self.__class__.__name__, 'length': self.num_samples, 'shuffle': self.shuffle}
|
||||
return states
|
||||
|
||||
def load_state_dict(self, states: Dict):
|
||||
@ -155,8 +159,8 @@ class RandomSampler(ReproducibleSampler):
|
||||
f"we cannot use {self.__class__.__name__} to load it."
|
||||
|
||||
length = states['length']
|
||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
|
||||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
|
||||
assert length == self.num_samples, "The number of samples is different between the checkpoint " \
|
||||
f"record({length}) and current dataset({self.num_samples})."
|
||||
self.seed = states['seed']
|
||||
self.epoch = states['epoch']
|
||||
self.num_consumed_samples = states['num_consumed_samples']
|
||||
@ -208,9 +212,17 @@ class RandomSampler(ReproducibleSampler):
|
||||
:return:
|
||||
"""
|
||||
num_consumed_samples = self.num_consumed_samples
|
||||
return math.ceil((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas) if \
|
||||
self.pad else math.floor(((getattr(self.dataset, 'total_len', len(self.dataset)) - num_consumed_samples) / self.num_replicas))
|
||||
return math.ceil((self.num_samples - num_consumed_samples) / self.num_replicas) if \
|
||||
self.pad else math.floor(((self.num_samples - num_consumed_samples) / self.num_replicas))
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
"""
|
||||
返回样本的总数
|
||||
|
||||
:return:
|
||||
"""
|
||||
return getattr(self.dataset, 'total_len', len(self.dataset))
|
||||
|
||||
class SequentialSampler(RandomSampler):
|
||||
"""
|
||||
@ -258,12 +270,10 @@ class SequentialSampler(RandomSampler):
|
||||
|
||||
:return:
|
||||
"""
|
||||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
|
||||
return list(range(self.num_samples))
|
||||
|
||||
def state_dict(self) -> Dict:
|
||||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
|
||||
'length': getattr(self.dataset, 'total_len', len(self.dataset))
|
||||
}
|
||||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, 'length': self.num_samples}
|
||||
return states
|
||||
|
||||
def load_state_dict(self, states: Dict):
|
||||
@ -275,8 +285,8 @@ class SequentialSampler(RandomSampler):
|
||||
f"we cannot use {self.__class__.__name__} to load it."
|
||||
|
||||
length = states['length']
|
||||
assert length == getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of samples is different between the checkpoint record({length}) " \
|
||||
f"and current dataset({getattr(self.dataset, 'total_len', len(self.dataset))})."
|
||||
assert length == self.num_samples, "The number of samples is different between the checkpoint " \
|
||||
f"record({length}) and current dataset({self.num_samples})."
|
||||
self.num_consumed_samples = states['num_consumed_samples']
|
||||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0
|
||||
self.num_consumed_samples = 0
|
||||
@ -314,9 +324,9 @@ class SortedSampler(SequentialSampler):
|
||||
except BaseException as e:
|
||||
logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")
|
||||
|
||||
assert len(length) == getattr(self.dataset, 'total_len', len(self.dataset)), f"The length of `dataset`({len(dataset)}) and " \
|
||||
f"`length`({getattr(self.dataset, 'total_len', len(self.dataset))}) should be equal."
|
||||
assert len(self.sorted_indices) == getattr(self.dataset, 'total_len', len(self.dataset)), "The indices and dataset should have equal length."
|
||||
assert len(length) == self.num_samples, f"The length of `dataset`({len(dataset)}) and " \
|
||||
f"`length`({self.num_samples}) should be equal."
|
||||
assert len(self.sorted_indices) == self.num_samples, "The indices and dataset should have equal length."
|
||||
|
||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。
|
||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的
|
||||
|
@ -42,8 +42,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
|
||||
返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank;
|
||||
:return:
|
||||
"""
|
||||
num_common = getattr(self.dataset, 'total_len', len(self.dataset))//self.num_replicas
|
||||
num_samples = num_common + int(self.rank < (getattr(self.dataset, 'total_len', len(self.dataset))-num_common*self.num_replicas))
|
||||
num_common = self.num_samples//self.num_replicas
|
||||
num_samples = num_common + int(self.rank < (self.num_samples-num_common*self.num_replicas))
|
||||
return num_samples
|
||||
|
||||
def __iter__(self):
|
||||
@ -63,14 +63,14 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
|
||||
:return:
|
||||
"""
|
||||
if self.shuffle:
|
||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
|
||||
indices = list(range(self.num_samples))
|
||||
seed = self.seed + self.epoch
|
||||
rng = np.random.default_rng(abs(seed))
|
||||
rng.shuffle(indices)
|
||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。
|
||||
self.epoch -= 1
|
||||
else:
|
||||
indices = list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
|
||||
indices = list(range(self.num_samples))
|
||||
return indices
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
@ -84,8 +84,8 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
|
||||
:param rank:
|
||||
:return:
|
||||
"""
|
||||
assert num_replicas<=getattr(self.dataset, 'total_len', len(self.dataset)), f"The number of replicas({num_replicas}) should be lesser than the " \
|
||||
f"number of samples({getattr(self.dataset, 'total_len', len(self.dataset))})."
|
||||
assert num_replicas<=self.num_samples, f"The number of replicas({num_replicas}) should be lesser than the " \
|
||||
f"number of samples({self.num_samples})."
|
||||
assert num_replicas>0 and isinstance(num_replicas, int)
|
||||
assert isinstance(rank, int) and 0<=rank<num_replicas
|
||||
# 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态;
|
||||
@ -94,6 +94,15 @@ class UnrepeatedRandomSampler(UnrepeatedSampler):
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
"""
|
||||
返回样本的总数
|
||||
|
||||
:return:
|
||||
"""
|
||||
return getattr(self.dataset, 'total_len', len(self.dataset))
|
||||
|
||||
|
||||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
|
||||
"""
|
||||
@ -147,5 +156,5 @@ class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
|
||||
yield index
|
||||
|
||||
def generate_indices(self) -> List[int]:
|
||||
return list(range(getattr(self.dataset, 'total_len', len(self.dataset))))
|
||||
return list(range(self.num_samples))
|
||||
|
||||
|
@ -27,7 +27,7 @@ from paddle.optimizer import Adam
|
||||
from paddle.io import DataLoader
|
||||
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
|
||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback
|
||||
|
||||
@dataclass
|
||||
@ -52,12 +52,12 @@ def test_trainer_fleet(
|
||||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
|
||||
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
|
||||
batch_size=MNISTTrainFleetConfig.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
|
||||
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
|
||||
batch_size=MNISTTrainFleetConfig.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
|
@ -24,7 +24,7 @@ from paddle.io import DataLoader
|
||||
import paddle.distributed.fleet as fleet
|
||||
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_2
|
||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
|
||||
from tests.helpers.callbacks.helper_callbacks import RecordMetricCallback
|
||||
|
||||
@dataclass
|
||||
@ -54,12 +54,12 @@ def test_trainer_fleet(
|
||||
optimizers = fleet.distributed_optimizer(optimizers)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
dataset=PaddleRandomMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
|
||||
dataset=PaddleArgMaxDataset(20, MNISTTrainFleetConfig.feature_dimension),
|
||||
batch_size=MNISTTrainFleetConfig.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
dataset=PaddleRandomMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
|
||||
dataset=PaddleArgMaxDataset(12, MNISTTrainFleetConfig.feature_dimension),
|
||||
batch_size=MNISTTrainFleetConfig.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
|
@ -46,8 +46,8 @@ class LSTM(Module):
|
||||
def init_hidden(self, x):
|
||||
# batch_first
|
||||
batch_size = x.shape[0]
|
||||
h0 = jt.randn(1, batch_size, hidden_size)
|
||||
c0 = jt.randn(1, batch_size, hidden_size)
|
||||
h0 = jt.randn(1, batch_size, self.hidden_size)
|
||||
c0 = jt.randn(1, batch_size, self.hidden_size)
|
||||
|
||||
return h0, c0
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
from fastNLP.core.callbacks import callback
|
||||
|
||||
from fastNLP.core.controllers.trainer import Trainer
|
||||
from fastNLP.core.controllers.trainer import Evaluator
|
||||
@ -14,6 +15,7 @@ if _NEED_IMPORT_JITTOR:
|
||||
else:
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as Module
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
|
||||
jt.flags.use_cuda=1
|
||||
|
||||
|
||||
class JittorNormalModel_Classification(Module):
|
||||
@ -68,11 +70,9 @@ class TrainJittorConfig:
|
||||
batch_size: int = 4
|
||||
shuffle: bool = True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("driver", ["jittor"])
|
||||
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda:0"])
|
||||
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None])
|
||||
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
|
||||
@pytest.mark.jittor
|
||||
def test_trainer_jittor(
|
||||
driver,
|
||||
device,
|
||||
|
@ -15,7 +15,7 @@ if _NEED_IMPORT_PADDLE:
|
||||
|
||||
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleArgMaxDataset
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
|
||||
@dataclass
|
||||
@ -44,12 +44,12 @@ def test_trainer_paddle(
|
||||
)
|
||||
optimizers = Adam(parameters=model.parameters(), learning_rate=0.0001)
|
||||
train_dataloader = DataLoader(
|
||||
dataset=PaddleRandomMaxDataset(20, TrainPaddleConfig.feature_dimension),
|
||||
dataset=PaddleArgMaxDataset(20, TrainPaddleConfig.feature_dimension),
|
||||
batch_size=TrainPaddleConfig.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
dataset=PaddleRandomMaxDataset(12, TrainPaddleConfig.feature_dimension),
|
||||
dataset=PaddleArgMaxDataset(12, TrainPaddleConfig.feature_dimension),
|
||||
batch_size=TrainPaddleConfig.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
|
@ -76,7 +76,7 @@ class TestPaddle:
|
||||
from paddle.io import Dataset
|
||||
import paddle
|
||||
|
||||
class PaddleRandomMaxDataset(Dataset):
|
||||
class PaddleArgMaxDataset(Dataset):
|
||||
def __init__(self, num_samples, num_features):
|
||||
self.x = paddle.randn((num_samples, num_features))
|
||||
self.y = self.x.argmax(axis=-1)
|
||||
@ -87,7 +87,7 @@ class TestPaddle:
|
||||
def __getitem__(self, item):
|
||||
return {"x": self.x[item], "y": self.y[item]}
|
||||
|
||||
ds = PaddleRandomMaxDataset(100, 2)
|
||||
ds = PaddleArgMaxDataset(100, 2)
|
||||
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4)
|
||||
for batch in dl:
|
||||
print(batch)
|
@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.drivers import JittorSingleDriver, JittorMPIDriver
|
||||
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
|
||||
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor as jt
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_incorrect_driver():
|
||||
|
||||
model = JittorNormalModel_Classification_1(20, 10)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_jittor_driver("torch", 0, model)
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
["cpu", "gpu", None, "cuda"]
|
||||
)
|
||||
def test_get_single_device(device):
|
||||
"""
|
||||
测试正常情况下初始化 JittorSingleDriver 的情况
|
||||
"""
|
||||
|
||||
model = JittorNormalModel_Classification_1(20, 10)
|
||||
driver = initialize_jittor_driver("jittor", device, model)
|
||||
assert isinstance(driver, JittorSingleDriver)
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.parametrize(
|
||||
"device",
|
||||
[[0, 2, 3], 1, 2]
|
||||
)
|
||||
def test_get_mpi(device):
|
||||
"""
|
||||
测试 jittor 多卡的初始化情况
|
||||
"""
|
||||
|
||||
model = JittorNormalModel_Classification_1(20, 10)
|
||||
with pytest.raises(NotImplementedError):
|
||||
driver = initialize_jittor_driver("jittor", device, model)
|
||||
|
||||
# assert isinstance(driver, JittorMPIDriver)
|
@ -1,99 +1,614 @@
|
||||
import pytest
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
from fastNLP.core.drivers.jittor_driver import JittorSingleDriver
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from fastNLP.core.dataloaders import JittorDataLoader
|
||||
from tests.helpers.models.jittor_model import JittorNormalModel_Classification_1
|
||||
from tests.helpers.datasets.jittor_data import JittorNormalDataset, JittorNormalXYDataset
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor as jt # 将 jittor 引入
|
||||
from jittor import nn, Module # 引入相关的模块
|
||||
from jittor import init
|
||||
from jittor.dataset import MNIST
|
||||
else:
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as Module
|
||||
import jittor as jt
|
||||
from jittor.dataset import (
|
||||
BatchSampler as JittorBatchSampler,
|
||||
RandomSampler as JittorRandomSampler,
|
||||
SequentialSampler as JittorSequentialSampler,
|
||||
SubsetRandomSampler as JittorSubsetRandomSampler
|
||||
)
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
|
||||
def get_dataloader(dataset, use_dataloader, sampler, batch_size, shuffle, drop_last=False):
|
||||
"""
|
||||
:param dataset:
|
||||
:param use_dataloader: 是否使用 JittorDataLoader 包裹
|
||||
:param sampler: 使用 BatchSampler Samlper 还是不使用 Sampler
|
||||
"""
|
||||
if use_dataloader:
|
||||
dataloader = JittorDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
dataloader.dataset.set_attrs(sampler=sampler)
|
||||
else:
|
||||
dataloader = dataset
|
||||
dataloader.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, sampler=sampler)
|
||||
|
||||
return dataloader
|
||||
############################################################################
|
||||
#
|
||||
# 测试基类 JittorDrvier 中的一些简单函数
|
||||
#
|
||||
############################################################################
|
||||
|
||||
class TestJittorDriverFunctions:
|
||||
"""
|
||||
使用 JittorSingleDriver 测试基类的函数
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setup_class(self):
|
||||
model = JittorNormalModel_Classification_1(10, 32)
|
||||
self.driver = JittorSingleDriver(model, device="cpu")
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_check_optimizers_legality(self):
|
||||
"""
|
||||
测试对合法的 optimizers 的检查
|
||||
"""
|
||||
# 单个 optimizer
|
||||
optimizer = jt.optim.Adam(
|
||||
params=self.driver.model.parameters(),
|
||||
lr=0.01
|
||||
)
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
# optimizer 列表
|
||||
optimizers = [
|
||||
jt.optim.Adam(
|
||||
params=self.driver.model.parameters(),
|
||||
lr=0.01
|
||||
) for i in range(10)
|
||||
]
|
||||
self.driver.set_optimizers(optimizers)
|
||||
|
||||
@pytest.mark.torchjittor
|
||||
def test_invalid_optimizers(self):
|
||||
"""
|
||||
测试传入非法的 optimizers
|
||||
"""
|
||||
# 单个 optimizer
|
||||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
optimizers = [
|
||||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
|
||||
]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.set_optimizers(optimizers)
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_check_dataloader_legality(self):
|
||||
"""
|
||||
测试 check_dataloader_legality 函数的表现
|
||||
"""
|
||||
# 使用 JittorDataLoader
|
||||
dataloader = JittorDataLoader(JittorNormalDataset())
|
||||
self.driver.check_dataloader_legality(dataloader)
|
||||
# 使用 jittor.dataset.Dataset
|
||||
self.driver.check_dataloader_legality(JittorNormalDataset())
|
||||
|
||||
@pytest.mark.torchjittor
|
||||
def test_check_dataloader_legality_invalid(self):
|
||||
"""
|
||||
测试 check_dataloader_legality 函数传入其他类型的表现
|
||||
"""
|
||||
# 创建 torch 的 dataloader
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
TorchNormalDataset(),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.check_dataloader_legality(dataloader)
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_tensor_to_numeric(self):
|
||||
"""
|
||||
测试 tensor_to_numeric 函数
|
||||
"""
|
||||
# 单个张量
|
||||
tensor = jt.Var(3)
|
||||
res = JittorSingleDriver.tensor_to_numeric(tensor)
|
||||
assert res == 3
|
||||
|
||||
tensor = jt.rand(3, 4)
|
||||
res = JittorSingleDriver.tensor_to_numeric(tensor)
|
||||
assert res == tensor.tolist()
|
||||
|
||||
# 张量list
|
||||
tensor_list = [jt.rand(6, 4, 2) for i in range(10)]
|
||||
res = JittorSingleDriver.tensor_to_numeric(tensor_list)
|
||||
assert isinstance(res, list)
|
||||
tensor_list = [t.tolist() for t in tensor_list]
|
||||
assert res == tensor_list
|
||||
|
||||
# 张量tuple
|
||||
tensor_tuple = tuple([jt.rand(6, 4, 2) for i in range(10)])
|
||||
res = JittorSingleDriver.tensor_to_numeric(tensor_tuple)
|
||||
assert isinstance(res, tuple)
|
||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
|
||||
assert res == tensor_tuple
|
||||
|
||||
# 张量dict
|
||||
tensor_dict = {
|
||||
"tensor": jt.rand(3, 4),
|
||||
"list": [jt.rand(6, 4, 2) for i in range(10)],
|
||||
"dict":{
|
||||
"list": [jt.rand(6, 4, 2) for i in range(10)],
|
||||
"tensor": jt.rand(3, 4)
|
||||
},
|
||||
"int": 2,
|
||||
"string": "test string"
|
||||
}
|
||||
|
||||
res = JittorSingleDriver.tensor_to_numeric(tensor_dict)
|
||||
assert isinstance(res, dict)
|
||||
assert res["tensor"] == tensor_dict["tensor"].tolist()
|
||||
assert isinstance(res["list"], list)
|
||||
for r, d in zip(res["list"], tensor_dict["list"]):
|
||||
assert r == d.tolist()
|
||||
assert isinstance(res["int"], int)
|
||||
assert isinstance(res["string"], str)
|
||||
assert isinstance(res["dict"], dict)
|
||||
assert isinstance(res["dict"]["list"], list)
|
||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
|
||||
assert r == d.tolist()
|
||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_tensor_to_numeric_reduce(self):
|
||||
tensor = jt.Var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
|
||||
|
||||
res_max = JittorSingleDriver.tensor_to_numeric(tensor, reduce="max")
|
||||
res_min = JittorSingleDriver.tensor_to_numeric(tensor, reduce="min")
|
||||
res_sum = JittorSingleDriver.tensor_to_numeric(tensor, reduce="sum")
|
||||
res_mean = JittorSingleDriver.tensor_to_numeric(tensor, reduce="mean")
|
||||
|
||||
assert res_max == 6
|
||||
assert res_min == 1
|
||||
assert res_sum == 21
|
||||
assert res_mean == 3.5
|
||||
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_set_model_mode(self):
|
||||
"""
|
||||
测试 set_model_mode 函数
|
||||
"""
|
||||
self.driver.set_model_mode("train")
|
||||
assert self.driver.model.is_training()
|
||||
self.driver.set_model_mode("eval")
|
||||
assert not self.driver.model.is_training()
|
||||
# 应该报错
|
||||
with pytest.raises(AssertionError):
|
||||
self.driver.set_model_mode("test")
|
||||
|
||||
class Model(Module):
|
||||
def __init__ (self):
|
||||
super (Model, self).__init__()
|
||||
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding
|
||||
|
||||
self.conv2 = nn.Conv (32, 64, 3, 1)
|
||||
self.bn = nn.BatchNorm(64)
|
||||
@pytest.mark.jittor
|
||||
def test_move_model_to_device_cpu(self):
|
||||
"""
|
||||
测试 move_model_to_device 函数,仅测试能否运行
|
||||
"""
|
||||
JittorSingleDriver.move_model_to_device(self.driver.model, "cpu")
|
||||
|
||||
self.max_pool = nn.Pool (2, 2)
|
||||
self.relu = nn.Relu()
|
||||
self.fc1 = nn.Linear (64 * 12 * 12, 256)
|
||||
self.fc2 = nn.Linear (256, 10)
|
||||
@pytest.mark.jittor
|
||||
def test_move_model_to_device_gpu(self):
|
||||
"""
|
||||
测试 move_model_to_device 函数,仅测试能否运行
|
||||
"""
|
||||
JittorSingleDriver.move_model_to_device(self.driver.model, "gpu")
|
||||
|
||||
def execute(self, x) :
|
||||
# it's simliar to forward function in Pytorch
|
||||
x = self.conv1 (x)
|
||||
x = self.relu (x)
|
||||
|
||||
x = self.conv2 (x)
|
||||
x = self.bn (x)
|
||||
x = self.relu (x)
|
||||
|
||||
x = self.max_pool (x)
|
||||
x = jt.reshape (x, [x.shape[0], -1])
|
||||
x = self.fc1 (x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2 (x)
|
||||
return x
|
||||
@pytest.mark.jittor
|
||||
def test_set_deterministic_dataloader(self):
|
||||
"""
|
||||
测试 set_deterministic_dataloader,仅测试能否运行
|
||||
"""
|
||||
# 先确保不影响运行
|
||||
# TODO:正确性
|
||||
dataloader = JittorDataLoader(JittorNormalDataset())
|
||||
self.driver.set_deterministic_dataloader(dataloader)
|
||||
self.driver.set_deterministic_dataloader(JittorNormalDataset())
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_set_sampler_epoch(self):
|
||||
"""
|
||||
测试 set_sampler_epoch
|
||||
"""
|
||||
# 先确保不影响运行
|
||||
# TODO:正确性
|
||||
dataloader = JittorDataLoader(JittorNormalDataset())
|
||||
self.driver.set_sampler_epoch(dataloader, 0)
|
||||
self.driver.set_sampler_epoch(JittorNormalDataset(), 0)
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.parametrize("batch_size", [16])
|
||||
@pytest.mark.parametrize("shuffle", [True, False])
|
||||
@pytest.mark.parametrize("drop_last", [True, False])
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last, use_dataloader):
|
||||
"""
|
||||
测试正常情况下 get_dataloader_args 的表现
|
||||
"""
|
||||
dataloader = get_dataloader(
|
||||
JittorNormalDataset(),
|
||||
use_dataloader=use_dataloader,
|
||||
sampler=None,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last
|
||||
)
|
||||
res = JittorSingleDriver.get_dataloader_args(dataloader)
|
||||
|
||||
assert isinstance(res.dataset, JittorNormalDataset)
|
||||
assert res.sampler is None
|
||||
assert res.shuffle == shuffle
|
||||
assert res.batch_size == batch_size
|
||||
assert res.drop_last == drop_last
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.parametrize("batch_size", [16])
|
||||
@pytest.mark.parametrize("shuffle", [True, False])
|
||||
@pytest.mark.parametrize("drop_last", [True, False])
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last, use_dataloader):
|
||||
"""
|
||||
测试替换了 sampler 后 get_dataloader_args 的表现
|
||||
"""
|
||||
dataset = JittorNormalDataset()
|
||||
dataloader = get_dataloader(
|
||||
dataset,
|
||||
use_dataloader=use_dataloader,
|
||||
batch_size=batch_size,
|
||||
sampler=RandomSampler(dataset, shuffle=shuffle),
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last
|
||||
)
|
||||
|
||||
res = JittorSingleDriver.get_dataloader_args(dataloader)
|
||||
|
||||
assert isinstance(res.dataset, JittorNormalDataset)
|
||||
assert isinstance(res.sampler, RandomSampler)
|
||||
assert res.shuffle == shuffle
|
||||
assert res.batch_size == batch_size
|
||||
assert res.drop_last == drop_last
|
||||
|
||||
|
||||
############################################################################
|
||||
#
|
||||
# 测试 JittorSingleDrvier 中的一些简单函数
|
||||
#
|
||||
############################################################################
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.skip("Skip jittor tests now.")
|
||||
class TestSingleDevice:
|
||||
class TestSingleDeviceFunction:
|
||||
"""
|
||||
测试其它函数的测试例
|
||||
"""
|
||||
|
||||
def test_on_gpu_without_fp16(self):
|
||||
# TODO get_dataloader
|
||||
batch_size = 64
|
||||
learning_rate = 0.1
|
||||
epochs = 5
|
||||
losses = []
|
||||
losses_idx = []
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
model = JittorNormalModel_Classification_1(10, 784)
|
||||
cls.driver = JittorSingleDriver(model, device="cpu")
|
||||
|
||||
train_loader = MNIST(train=True, batch_size=batch_size, shuffle=True)
|
||||
val_loader = MNIST(train=False, batch_size=1, shuffle=False)
|
||||
def test_unwrap_model(self):
|
||||
"""
|
||||
测试能否运行
|
||||
"""
|
||||
res = self.driver.unwrap_model()
|
||||
assert res is self.driver.model
|
||||
|
||||
model = Model()
|
||||
driver = JittorSingleDriver(model, device=[1])
|
||||
optimizer = nn.SGD(model.parameters(), learning_rate)
|
||||
driver.set_optimizers(optimizer)
|
||||
def test_is_distributed(self):
|
||||
assert self.driver.is_distributed() == False
|
||||
|
||||
for epoch in range(epochs):
|
||||
driver.set_model_mode("train")
|
||||
lens = len(train_loader)
|
||||
for batch_idx, (inputs, targets) in enumerate(train_loader):
|
||||
outputs =driver.train_step(inputs)
|
||||
loss = nn.cross_entropy_loss(outputs, targets)
|
||||
driver.backward(loss)
|
||||
driver.step()
|
||||
driver.zero_grad()
|
||||
losses.append(loss.data[0])
|
||||
losses_idx.append(epoch * lens + batch_idx)
|
||||
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
total_acc = 0
|
||||
total_num = 0
|
||||
driver.set_model_mode("eval")
|
||||
for batch_idx, (inputs, targets) in enumerate(val_loader):
|
||||
batch_size = inputs.shape[0]
|
||||
outputs = driver.test_step(inputs)
|
||||
pred = np.argmax(outputs.data, axis=1)
|
||||
acc = np.sum(targets.data==pred)
|
||||
total_acc += acc
|
||||
total_num += batch_size
|
||||
acc = acc / batch_size
|
||||
assert total_acc / total_num > 0.95
|
||||
def test_move_data_to_device(self):
|
||||
self.driver.move_data_to_device(jt.rand(32, 64))
|
||||
|
||||
|
||||
def test_on_cpu_without_fp16(self):
|
||||
pass
|
||||
############################################################################
|
||||
#
|
||||
# 测试 set_dist_repro_dataloader 函数
|
||||
#
|
||||
############################################################################
|
||||
|
||||
def test_on_gpu_with_fp16(self):
|
||||
pass
|
||||
@pytest.mark.jittor
|
||||
class TestSetDistReproDataloader:
|
||||
"""
|
||||
专门测试 set_dist_repro_dataloader 函数的类
|
||||
"""
|
||||
def setup_method(self):
|
||||
self.dataset = JittorNormalDataset(20)
|
||||
model = JittorNormalModel_Classification_1(10, 32)
|
||||
self.driver = JittorSingleDriver(model, device="cpu")
|
||||
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_with_reproducible_false(self, use_dataloader):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现
|
||||
当dist为字符串时,此时应该返回原来的 dataloader
|
||||
"""
|
||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=True)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
|
||||
|
||||
assert replaced_loader is dataloader
|
||||
|
||||
@pytest.mark.parametrize("shuffle", [True, False])
|
||||
@pytest.mark.parametrize("sampler", [None, "random", "sequential"])
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_with_reproducible_true(self, shuffle, sampler, use_dataloader):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
|
||||
当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler
|
||||
"""
|
||||
if sampler == "random":
|
||||
sampler = JittorRandomSampler(self.dataset)
|
||||
_shuffle = True
|
||||
elif sampler == "sequential":
|
||||
sampler = JittorSequentialSampler(self.dataset)
|
||||
_shuffle = False
|
||||
else:
|
||||
_shuffle = shuffle
|
||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=sampler, batch_size=2, shuffle=shuffle)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.sampler, RandomSampler)
|
||||
assert replaced_loader.sampler.shuffle == _shuffle
|
||||
assert replaced_loader.batch_size == dataloader.batch_size
|
||||
assert replaced_loader.drop_last == dataloader.drop_last
|
||||
|
||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)
|
||||
|
||||
@pytest.mark.parametrize("shuffle", ([True, False]))
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_with_dist_batch_sampler(self, shuffle, use_dataloader):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler
|
||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
|
||||
jittor 暂时不支持这种情况,会报错
|
||||
"""
|
||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle)
|
||||
dist = ReproduceBatchSampler(JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False), 4, False)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
|
||||
|
||||
@pytest.mark.parametrize("shuffle", ([True, False]))
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_with_dist_sampler(self, shuffle, use_dataloader):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现
|
||||
应该返回新的 dataloader,并将 sampler 替换为 dist 对应的 Sampler
|
||||
"""
|
||||
dataloader = get_dataloader(self.dataset, use_dataloader, sampler=None, batch_size=2, shuffle=not shuffle)
|
||||
dist = RandomSampler(self.dataset, shuffle=shuffle)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.sampler, RandomSampler)
|
||||
assert replaced_loader.sampler is dist
|
||||
assert replaced_loader.batch_size == dataloader.batch_size
|
||||
|
||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)
|
||||
|
||||
@pytest.mark.parametrize("shuffle", ([True, False]))
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_with_dataloader_reproducible_batch_sampler(self, shuffle, use_dataloader):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
|
||||
应该返回新的 dataloader,且其余各项设置和原来相同
|
||||
"""
|
||||
dataloader = get_dataloader(
|
||||
self.dataset,
|
||||
use_dataloader=use_dataloader,
|
||||
sampler=ReproduceBatchSampler(
|
||||
JittorBatchSampler(JittorRandomSampler(self.dataset), 4, False),
|
||||
batch_size=4,
|
||||
drop_last=False,
|
||||
),
|
||||
batch_size=4,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
|
||||
|
||||
@pytest.mark.parametrize("shuffle", ([True, False]))
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_with_dataloader_reproducible_sampler(self, shuffle, use_dataloader):
|
||||
"""
|
||||
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
|
||||
应该返回新的 dataloader,且其余各项设置和原来相同
|
||||
"""
|
||||
dataloader = get_dataloader(
|
||||
self.dataset,
|
||||
use_dataloader=use_dataloader,
|
||||
sampler=RandomSampler(self.dataset, shuffle),
|
||||
batch_size=2,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert not (replaced_loader.sampler is dataloader.sampler)
|
||||
assert isinstance(replaced_loader.sampler, RandomSampler)
|
||||
assert replaced_loader.batch_size == 2
|
||||
assert replaced_loader.shuffle == shuffle
|
||||
|
||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle, use_dataloader)
|
||||
|
||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle, use_dataloader):
|
||||
"""
|
||||
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确
|
||||
"""
|
||||
# 迭代两个 batch
|
||||
num_consumed_batches = 2
|
||||
already_seen_idx = set()
|
||||
replaced_loader.sampler.set_epoch(6)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_idx.update(batch.tolist())
|
||||
sampler_states = replaced_loader.sampler.state_dict()
|
||||
|
||||
# 重新加载,应该可以输出剩下的内容,且对于 JittorNormalDataset 来说,排序后应该是一个 range
|
||||
left_idxes = set()
|
||||
batch_size = replaced_loader.batch_size
|
||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
|
||||
# 重新构造 dataloader
|
||||
if use_dataloader:
|
||||
dataset = deepcopy(replaced_loader.dataset.dataset)
|
||||
else:
|
||||
dataset = deepcopy(replaced_loader)
|
||||
new_loader = get_dataloader(
|
||||
dataset=dataset,
|
||||
use_dataloader=use_dataloader,
|
||||
sampler = RandomSampler(dataset, shuffle=shuffle),
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=False
|
||||
)
|
||||
new_loader.sampler.load_state_dict(sampler_states)
|
||||
new_loader.sampler.set_epoch(6)
|
||||
for idx, batch in enumerate(new_loader):
|
||||
left_idxes.update(batch.tolist())
|
||||
|
||||
print(already_seen_idx)
|
||||
print(left_idxes)
|
||||
|
||||
assert len(left_idxes) + len(already_seen_idx) == self.dataset.total_len
|
||||
assert len(left_idxes | already_seen_idx) == self.dataset.total_len
|
||||
|
||||
############################################################################
|
||||
#
|
||||
# 测试 save 和 load 相关的功能
|
||||
#
|
||||
############################################################################
|
||||
|
||||
def generate_random_driver(labels, features, fp16=False, device="cpu", lr=0.01):
|
||||
"""
|
||||
生成driver
|
||||
"""
|
||||
model = JittorNormalModel_Classification_1(labels, features)
|
||||
opt = jt.optim.Adam(params=model.parameters(), lr=lr)
|
||||
driver = JittorSingleDriver(model, device=device, fp16=fp16)
|
||||
driver.set_optimizers(opt)
|
||||
driver.setup()
|
||||
|
||||
return driver
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.parametrize("only_state_dict", ([True, False]))
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_save_and_load_model(only_state_dict, use_dataloader):
|
||||
"""
|
||||
测试 save_model 和 load_model 函数
|
||||
"""
|
||||
try:
|
||||
path = "model"
|
||||
dataset = JittorNormalXYDataset(20)
|
||||
dataloader = get_dataloader(dataset, sampler=None, use_dataloader=use_dataloader, batch_size=4, shuffle=True)
|
||||
driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu")
|
||||
|
||||
driver1.save_model(path, only_state_dict)
|
||||
driver2.load_model(path, only_state_dict)
|
||||
|
||||
for batch in dataloader:
|
||||
batch = driver1.move_data_to_device(batch)
|
||||
res1 = driver1.model.evaluate_step(**batch)
|
||||
res2 = driver2.model.evaluate_step(**batch)
|
||||
|
||||
assert jt.all_(jt.equal(res1["pred"], res2["pred"]))
|
||||
finally:
|
||||
rank_zero_rm(path)
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.parametrize("only_state_dict", ([True, False]))
|
||||
@pytest.mark.parametrize("use_dataloader", [True, False])
|
||||
def test_save_and_load_with_randomsampler(only_state_dict, use_dataloader):
|
||||
"""
|
||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况
|
||||
"""
|
||||
|
||||
try:
|
||||
path = "model.ckp"
|
||||
|
||||
driver1, driver2 = generate_random_driver(20, 1, device="gpu", lr=0.01), \
|
||||
generate_random_driver(20, 1, device="gpu", lr=0.001)
|
||||
dataset = JittorNormalXYDataset(20)
|
||||
dataloader = get_dataloader(
|
||||
dataset, use_dataloader,
|
||||
sampler = RandomSampler(dataset, True),
|
||||
batch_size=4,
|
||||
shuffle=True
|
||||
)
|
||||
num_consumed_batches = 2
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
driver1.set_sampler_epoch(dataloader, 7)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
|
||||
|
||||
sampler_states = dataloader.sampler.state_dict()
|
||||
save_states = {"num_consumed_batches": num_consumed_batches}
|
||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
|
||||
|
||||
# 加载
|
||||
# 更改 batch_size
|
||||
dataloader = get_dataloader(
|
||||
dataset, use_dataloader,
|
||||
sampler=RandomSampler(dataset, True),
|
||||
batch_size=2,
|
||||
shuffle=True
|
||||
)
|
||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
|
||||
replaced_loader = load_states.pop("dataloader")
|
||||
|
||||
# 1. 检查 optimizer 的状态
|
||||
assert driver2.optimizers[0].lr == driver1.optimizers[0].lr
|
||||
|
||||
# 2. 检查 sampler 是否被正确地加载和替换
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert isinstance(replaced_loader.sampler, RandomSampler)
|
||||
assert replaced_loader.sampler.seed == sampler_states["seed"]
|
||||
assert replaced_loader.sampler.epoch == sampler_states["epoch"]
|
||||
assert replaced_loader.sampler.num_consumed_samples == 4 * num_consumed_batches
|
||||
assert replaced_loader.sampler.dataset.total_len == sampler_states["length"]
|
||||
assert replaced_loader.sampler.shuffle == sampler_states["shuffle"]
|
||||
|
||||
# 4. 检查 model 的参数是否正确
|
||||
# 5. 检查 batch_idx
|
||||
start_batch = load_states.pop('batch_idx_in_epoch')
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
driver2.set_sampler_epoch(replaced_loader, 7)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
|
||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
|
||||
res1 = driver1.model.evaluate_step(**batch)
|
||||
res2 = driver2.model.evaluate_step(**batch)
|
||||
assert jt.all_(jt.equal(res1["pred"], res2["pred"]))
|
||||
|
||||
assert len(left_x_batches) + len(already_seen_x_set) == dataset.total_len
|
||||
assert len(left_x_batches | already_seen_x_set) == dataset.total_len
|
||||
assert len(left_y_batches) + len(already_seen_y_set) == dataset.total_len
|
||||
assert len(left_y_batches | already_seen_y_set) == dataset.total_len
|
||||
finally:
|
||||
rank_zero_rm(path)
|
||||
|
@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
|
||||
from fastNLP.core.drivers.jittor_driver.utils import replace_sampler
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from fastNLP.core.dataloaders import JittorDataLoader
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor as jt
|
||||
|
||||
from tests.helpers.datasets.jittor_data import JittorNormalDataset
|
||||
|
||||
@pytest.mark.jittor
|
||||
@pytest.mark.parametrize("dataset", [
|
||||
JittorNormalDataset(20, batch_size=10, shuffle=True),
|
||||
JittorNormalDataset(20, batch_size=5, drop_last=True),
|
||||
JittorNormalDataset(20)
|
||||
])
|
||||
def test_replace_sampler_dataset(dataset):
|
||||
dataset = JittorNormalDataset(20)
|
||||
sampler = RandomSampler(dataset)
|
||||
|
||||
replaced_loader = replace_sampler(dataset, sampler)
|
||||
|
||||
assert not (replaced_loader is dataset)
|
||||
assert isinstance(replaced_loader.sampler, RandomSampler)
|
||||
assert replaced_loader.batch_size == dataset.batch_size
|
||||
assert replaced_loader.drop_last == dataset.drop_last
|
||||
assert replaced_loader.shuffle == dataset.shuffle
|
||||
assert replaced_loader.total_len == dataset.total_len
|
||||
|
||||
@pytest.mark.jittor
|
||||
def test_replace_sampler_jittordataloader():
|
||||
dataset = JittorNormalDataset(20, batch_size=10, shuffle=True)
|
||||
dataloader = JittorDataLoader(dataset, batch_size=8, shuffle=True)
|
||||
sampler = RandomSampler(dataset)
|
||||
|
||||
replaced_loader = replace_sampler(dataloader, sampler)
|
||||
|
||||
assert not (replaced_loader is dataloader)
|
||||
assert not (replaced_loader.dataset.dataset is dataloader.dataset.dataset)
|
||||
assert isinstance(replaced_loader.sampler, RandomSampler)
|
||||
assert replaced_loader.batch_size == 8
|
||||
assert replaced_loader.shuffle == True
|
@ -10,7 +10,7 @@ from fastNLP.core.samplers import (
|
||||
UnrepeatedSequentialSampler,
|
||||
)
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
|
||||
@ -19,8 +19,8 @@ if _NEED_IMPORT_PADDLE:
|
||||
import paddle.distributed as dist
|
||||
from paddle.io import DataLoader, BatchSampler
|
||||
|
||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
|
||||
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
|
||||
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="only_error"):
|
||||
paddle_model = PaddleNormalModel_Classification_1(labels, features)
|
||||
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
|
||||
driver = PaddleFleetDriver(
|
||||
model=paddle_model,
|
||||
@ -465,10 +465,14 @@ class TestSetDistReproDataloader:
|
||||
num_replicas = len(self.device)
|
||||
num_consumed_batches = 2
|
||||
already_seen_idx = set()
|
||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.set_epoch(10)
|
||||
else:
|
||||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(10)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_idx.update(batch)
|
||||
already_seen_idx.update(batch.tolist())
|
||||
dist.barrier()
|
||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.state_dict()
|
||||
@ -496,6 +500,7 @@ class TestSetDistReproDataloader:
|
||||
pad=True
|
||||
)
|
||||
new_loader.batch_sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.set_epoch(10)
|
||||
else:
|
||||
batch_size = replaced_loader.batch_sampler.batch_size
|
||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
|
||||
@ -508,8 +513,9 @@ class TestSetDistReproDataloader:
|
||||
)
|
||||
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
|
||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.sampler.set_epoch(10)
|
||||
for idx, batch in enumerate(new_loader):
|
||||
left_idxes.update(batch)
|
||||
left_idxes.update(batch.tolist())
|
||||
|
||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
|
||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
|
||||
@ -533,7 +539,7 @@ class TestSaveLoad:
|
||||
cls.driver = generate_driver(10, 10, device=[0,1])
|
||||
|
||||
def setup_method(self):
|
||||
self.dataset = PaddleRandomMaxDataset(20, 10)
|
||||
self.dataset = PaddleNormalXYDataset(40)
|
||||
|
||||
@magic_argv_env_context
|
||||
@pytest.mark.parametrize("only_state_dict", ([True, False]))
|
||||
@ -545,12 +551,12 @@ class TestSaveLoad:
|
||||
path = "model"
|
||||
|
||||
dataloader = DataLoader(self.dataset, batch_size=2)
|
||||
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
|
||||
self.driver1, self.driver2 = generate_driver(40, 1), generate_driver(40, 1)
|
||||
|
||||
if only_state_dict:
|
||||
self.driver1.save_model(path, only_state_dict)
|
||||
else:
|
||||
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 10))])
|
||||
self.driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((4, 1))])
|
||||
|
||||
# 同步
|
||||
dist.barrier()
|
||||
@ -594,8 +600,8 @@ class TestSaveLoad:
|
||||
path = "model.ckp"
|
||||
num_replicas = len(device)
|
||||
|
||||
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
|
||||
generate_driver(10, 10, device=device, fp16=False)
|
||||
self.driver1, self.driver2 = generate_driver(40, 1, device=device, fp16=fp16), \
|
||||
generate_driver(40, 1, device=device, fp16=False)
|
||||
dataloader = DataLoader(
|
||||
dataset=self.dataset,
|
||||
batch_sampler=BucketedBatchSampler(
|
||||
@ -613,11 +619,12 @@ class TestSaveLoad:
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
self.driver1.set_sampler_epoch(dataloader, 2)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
|
||||
|
||||
# 同步
|
||||
dist.barrier()
|
||||
@ -669,10 +676,11 @@ class TestSaveLoad:
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
self.driver2.set_sampler_epoch(replaced_loader, 2)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
|
||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
|
||||
res1 = self.driver1.model(
|
||||
batch,
|
||||
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
|
||||
@ -709,8 +717,8 @@ class TestSaveLoad:
|
||||
|
||||
num_replicas = len(device)
|
||||
|
||||
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
|
||||
self.driver2 = generate_driver(10, 10, device=device, fp16=False)
|
||||
self.driver1 = generate_driver(40, 1, device=device, fp16=fp16)
|
||||
self.driver2 = generate_driver(40, 1, device=device, fp16=False)
|
||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
|
||||
batch_sampler.sampler = RandomSampler(self.dataset, True)
|
||||
batch_sampler.sampler.set_distributed(
|
||||
@ -726,11 +734,12 @@ class TestSaveLoad:
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
self.driver1.set_sampler_epoch(dataloader, 2)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
|
||||
|
||||
# 同步
|
||||
dist.barrier()
|
||||
@ -779,10 +788,11 @@ class TestSaveLoad:
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
self.driver2.set_sampler_epoch(replaced_loader, 2)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
|
||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
|
||||
res1 = self.driver1.model(
|
||||
batch,
|
||||
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
|
||||
|
@ -12,7 +12,7 @@ if _NEED_IMPORT_PADDLE:
|
||||
@pytest.mark.paddle
|
||||
def test_incorrect_driver():
|
||||
|
||||
model = PaddleNormalModel_Classification_1(2, 100)
|
||||
model = PaddleNormalModel_Classification_1(20, 10)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_paddle_driver("torch", 0, model)
|
||||
|
||||
@ -26,7 +26,7 @@ def test_get_single_device(device):
|
||||
测试正常情况下初始化 PaddleSingleDriver 的情况
|
||||
"""
|
||||
|
||||
model = PaddleNormalModel_Classification_1(2, 100)
|
||||
model = PaddleNormalModel_Classification_1(20, 10)
|
||||
driver = initialize_paddle_driver("paddle", device, model)
|
||||
assert isinstance(driver, PaddleSingleDriver)
|
||||
|
||||
@ -41,7 +41,7 @@ def test_get_fleet(device):
|
||||
测试 fleet 多卡的初始化情况
|
||||
"""
|
||||
|
||||
model = PaddleNormalModel_Classification_1(64, 10)
|
||||
model = PaddleNormalModel_Classification_1(20, 10)
|
||||
driver = initialize_paddle_driver("paddle", device, model)
|
||||
|
||||
assert isinstance(driver, PaddleFleetDriver)
|
||||
@ -56,6 +56,6 @@ def test_device_out_of_range(device):
|
||||
"""
|
||||
测试传入的device超过范围的情况
|
||||
"""
|
||||
model = PaddleNormalModel_Classification_1(2, 100)
|
||||
model = PaddleNormalModel_Classification_1(20, 10)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_paddle_driver("paddle", device, model)
|
||||
|
@ -4,14 +4,16 @@ from pathlib import Path
|
||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
|
||||
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
from paddle.io import DataLoader, BatchSampler
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
|
||||
@ -31,102 +33,70 @@ class TestPaddleDriverFunctions:
|
||||
model = PaddleNormalModel_Classification_1(10, 32)
|
||||
self.driver = PaddleSingleDriver(model, device="cpu")
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_single_optimizer_legality(self):
|
||||
@pytest.mark.paddle
|
||||
def test_check_optimizers_legality(self):
|
||||
"""
|
||||
测试传入单个 optimizer 时的表现
|
||||
测试对合法的 optimizers 的检查
|
||||
"""
|
||||
# 单个 optimizer
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
parameters=self.driver.model.parameters(),
|
||||
learning_rate=0.01
|
||||
)
|
||||
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
|
||||
# 传入torch的optimizer时,应该报错ValueError
|
||||
with pytest.raises(ValueError):
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_optimizers_legality(self):
|
||||
"""
|
||||
测试传入 optimizer list 的表现
|
||||
"""
|
||||
# optimizer 列表
|
||||
optimizers = [
|
||||
paddle.optimizer.Adam(
|
||||
parameters=self.driver.model.parameters(),
|
||||
learning_rate=0.01
|
||||
) for i in range(10)
|
||||
]
|
||||
|
||||
self.driver.set_optimizers(optimizers)
|
||||
|
||||
optimizers += [
|
||||
@pytest.mark.torchpaddle
|
||||
def test_invalid_optimizers(self):
|
||||
"""
|
||||
测试传入非法的 optimizers
|
||||
"""
|
||||
# 单个 optimizer
|
||||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
optimizers = [
|
||||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.set_optimizers(optimizers)
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_dataloader_legality_in_train(self):
|
||||
@pytest.mark.paddle
|
||||
def test_check_dataloader_legality(self):
|
||||
"""
|
||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现
|
||||
测试 check_dataloader_legality 函数的表现
|
||||
"""
|
||||
dataloader = DataLoader(PaddleNormalDataset())
|
||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
self.driver.check_dataloader_legality(dataloader)
|
||||
|
||||
# batch_size 和 batch_sampler 均为 None 的情形
|
||||
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None)
|
||||
with pytest.raises(ValueError):
|
||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
self.driver.check_dataloader_legality(dataloader)
|
||||
|
||||
# 创建torch的dataloader
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_dataloader_legality_invalid(self):
|
||||
"""
|
||||
测试 check_dataloader_legality 函数传入其他类型的表现
|
||||
"""
|
||||
# 创建 torch 的 dataloader
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
TorchNormalDataset(),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.check_dataloader_legality(dataloader)
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_dataloader_legality_in_test(self):
|
||||
"""
|
||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现
|
||||
"""
|
||||
# 此时传入的应该是dict
|
||||
dataloader = {
|
||||
"train": DataLoader(PaddleNormalDataset()),
|
||||
"test":DataLoader(PaddleNormalDataset())
|
||||
}
|
||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
|
||||
# batch_size 和 batch_sampler 均为 None 的情形
|
||||
dataloader = {
|
||||
"train": DataLoader(PaddleNormalDataset()),
|
||||
"test":DataLoader(PaddleNormalDataset(), batch_size=None)
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
|
||||
# 传入的不是 dict ,应该报错
|
||||
dataloader = DataLoader(PaddleNormalDataset())
|
||||
with pytest.raises(ValueError):
|
||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
|
||||
# 创建 torch 的 dataloader
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
TorchNormalDataset(),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
TorchNormalDataset(),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
dataloader = {"train": train_loader, "test": test_loader}
|
||||
with pytest.raises(ValueError):
|
||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
|
||||
@pytest.mark.paddle
|
||||
def test_tensor_to_numeric(self):
|
||||
@ -505,10 +475,14 @@ class TestSetDistReproDataloader:
|
||||
# 迭代两个 batch
|
||||
num_consumed_batches = 2
|
||||
already_seen_idx = set()
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.set_epoch(5)
|
||||
else:
|
||||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(5)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_idx.update(batch)
|
||||
already_seen_idx.update(batch.tolist())
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.state_dict()
|
||||
else:
|
||||
@ -529,6 +503,7 @@ class TestSetDistReproDataloader:
|
||||
)
|
||||
)
|
||||
new_loader.batch_sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.set_epoch(5)
|
||||
else:
|
||||
batch_size = replaced_loader.batch_sampler.batch_size
|
||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
|
||||
@ -537,8 +512,9 @@ class TestSetDistReproDataloader:
|
||||
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
|
||||
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
|
||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.sampler.set_epoch(5)
|
||||
for idx, batch in enumerate(new_loader):
|
||||
left_idxes.update(batch)
|
||||
left_idxes.update(batch.tolist())
|
||||
|
||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
|
||||
assert len(left_idxes | already_seen_idx) == len(self.dataset)
|
||||
@ -549,7 +525,7 @@ class TestSetDistReproDataloader:
|
||||
#
|
||||
############################################################################
|
||||
|
||||
def generate_random_driver(features, labels, fp16=False, device="cpu"):
|
||||
def generate_random_driver(labels, features, fp16=False, device="cpu"):
|
||||
"""
|
||||
生成driver
|
||||
"""
|
||||
@ -569,9 +545,9 @@ def test_save_and_load_model(only_state_dict):
|
||||
"""
|
||||
try:
|
||||
path = "model"
|
||||
dataset = PaddleRandomMaxDataset(40, 10)
|
||||
dataset = PaddleNormalXYDataset(20)
|
||||
dataloader = DataLoader(dataset, batch_size=4)
|
||||
driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu")
|
||||
driver1, driver2 = generate_random_driver(20, 1, device="gpu"), generate_random_driver(20, 1, device="gpu")
|
||||
|
||||
if only_state_dict:
|
||||
driver1.save_model(path, only_state_dict)
|
||||
@ -580,6 +556,7 @@ def test_save_and_load_model(only_state_dict):
|
||||
driver2.load_model(path, only_state_dict)
|
||||
|
||||
for batch in dataloader:
|
||||
print("?")
|
||||
batch = driver1.move_data_to_device(batch)
|
||||
res1 = driver1.model.evaluate_step(**batch)
|
||||
res2 = driver2.model.evaluate_step(**batch)
|
||||
@ -604,22 +581,23 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
|
||||
try:
|
||||
path = "model.ckp"
|
||||
dataset = PaddleRandomMaxDataset(40, 10)
|
||||
dataset = PaddleNormalXYDataset(40)
|
||||
dataloader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
|
||||
)
|
||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
|
||||
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu")
|
||||
|
||||
num_consumed_batches = 2
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
driver1.set_sampler_epoch(dataloader, 3)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
|
||||
|
||||
sampler_states = dataloader.batch_sampler.state_dict()
|
||||
save_states = {"num_consumed_batches": num_consumed_batches}
|
||||
@ -656,10 +634,11 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
driver2.set_sampler_epoch(replaced_loader, 3)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
|
||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
|
||||
res1 = driver1.model.evaluate_step(**batch)
|
||||
res2 = driver2.model.evaluate_step(**batch)
|
||||
assert paddle.equal_all(res1["pred"], res2["pred"])
|
||||
@ -679,14 +658,14 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
@pytest.mark.parametrize("fp16", ([True, False]))
|
||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16):
|
||||
"""
|
||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
|
||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况
|
||||
"""
|
||||
|
||||
try:
|
||||
path = "model.ckp"
|
||||
|
||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")
|
||||
dataset = PaddleRandomMaxDataset(40, 10)
|
||||
driver1, driver2 = generate_random_driver(40, 1, fp16, "gpu"), generate_random_driver(40, 1, False, "gpu")
|
||||
dataset = PaddleNormalXYDataset(40)
|
||||
batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
|
||||
batch_sampler.sampler = RandomSampler(dataset, True)
|
||||
dataloader = DataLoader(
|
||||
@ -697,11 +676,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
driver1.set_sampler_epoch(dataloader, 3)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape((-1, )).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape((-1, )).tolist())
|
||||
|
||||
sampler_states = dataloader.batch_sampler.sampler.state_dict()
|
||||
save_states = {"num_consumed_batches": num_consumed_batches}
|
||||
@ -743,10 +723,11 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
driver1.set_sampler_epoch(replaced_loader, 3)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape((-1, )).tolist())
|
||||
left_y_batches.update(batch["y"].reshape((-1, )).tolist())
|
||||
res1 = driver1.model.evaluate_step(**batch)
|
||||
res2 = driver2.model.evaluate_step(**batch)
|
||||
assert paddle.equal_all(res1["pred"], res2["pred"])
|
||||
|
@ -10,7 +10,7 @@ from fastNLP.core.samplers import (
|
||||
UnrepeatedSequentialSampler,
|
||||
)
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
|
||||
from tests.helpers.utils import magic_argv_env_context
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
@ -19,8 +19,8 @@ if _NEED_IMPORT_TORCH:
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader, BatchSampler
|
||||
|
||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
|
||||
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
|
||||
def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"):
|
||||
torch_model = TorchNormalModel_Classification_1(labels, features)
|
||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
|
||||
device = [torch.device(i) for i in device]
|
||||
driver = TorchDDPDriver(
|
||||
@ -504,10 +504,14 @@ class TestSetDistReproDataloader:
|
||||
num_replicas = len(self.device)
|
||||
num_consumed_batches = 2
|
||||
already_seen_idx = set()
|
||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.set_epoch(4)
|
||||
else:
|
||||
sampler_states = replaced_loader.batch_sampler.sampler.set_epoch(4)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_idx.update(batch)
|
||||
already_seen_idx.update(batch.tolist())
|
||||
dist.barrier()
|
||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.state_dict()
|
||||
@ -533,6 +537,7 @@ class TestSetDistReproDataloader:
|
||||
pad=True
|
||||
)
|
||||
new_loader.batch_sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.set_epoch(4)
|
||||
else:
|
||||
batch_size = replaced_loader.batch_sampler.batch_size
|
||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas
|
||||
@ -543,8 +548,9 @@ class TestSetDistReproDataloader:
|
||||
rank=driver.global_rank
|
||||
)
|
||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.sampler.set_epoch(4)
|
||||
for idx, batch in enumerate(new_loader):
|
||||
left_idxes.update(batch)
|
||||
left_idxes.update(batch.tolist())
|
||||
|
||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas
|
||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas
|
||||
@ -562,7 +568,7 @@ class TestSaveLoad:
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
self.dataset = TorchArgMaxDataset(10, 20)
|
||||
self.dataset = TorchNormalXYDataset(20)
|
||||
|
||||
@magic_argv_env_context
|
||||
@pytest.mark.parametrize("only_state_dict", ([True, False]))
|
||||
@ -574,7 +580,7 @@ class TestSaveLoad:
|
||||
path = "model"
|
||||
|
||||
dataloader = DataLoader(self.dataset, batch_size=2)
|
||||
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)
|
||||
driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1)
|
||||
|
||||
driver1.save_model(path, only_state_dict)
|
||||
|
||||
@ -618,8 +624,8 @@ class TestSaveLoad:
|
||||
path = "model.ckp"
|
||||
num_replicas = len(device)
|
||||
|
||||
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
|
||||
generate_driver(10, 10, device=device, fp16=False)
|
||||
driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \
|
||||
generate_driver(20, 1, device=device, fp16=False)
|
||||
dataloader = dataloader_with_bucketedbatchsampler(
|
||||
self.dataset,
|
||||
length=[10 for i in range(len(self.dataset))],
|
||||
@ -636,11 +642,12 @@ class TestSaveLoad:
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
driver1.set_sampler_epoch(dataloader, 4)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
|
||||
|
||||
# 同步
|
||||
dist.barrier()
|
||||
@ -665,7 +672,6 @@ class TestSaveLoad:
|
||||
pad=True
|
||||
)
|
||||
dist.barrier()
|
||||
print("========load=======", driver1.global_rank, driver2.global_rank)
|
||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True)
|
||||
dist.barrier()
|
||||
replaced_loader = load_states.pop("dataloader")
|
||||
@ -690,10 +696,11 @@ class TestSaveLoad:
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
driver2.set_sampler_epoch(replaced_loader, 4)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
|
||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
|
||||
res1 = driver1.model(
|
||||
batch,
|
||||
fastnlp_fn=driver1.model.module.model.evaluate_step,
|
||||
@ -716,7 +723,6 @@ class TestSaveLoad:
|
||||
dist.barrier()
|
||||
finally:
|
||||
rank_zero_rm(path)
|
||||
print("=======delete======")
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
@ -735,8 +741,8 @@ class TestSaveLoad:
|
||||
|
||||
num_replicas = len(device)
|
||||
|
||||
driver1 = generate_driver(10, 10, device=device, fp16=fp16)
|
||||
driver2 = generate_driver(10, 10, device=device, fp16=False)
|
||||
driver1 = generate_driver(20, 1, device=device, fp16=fp16)
|
||||
driver2 = generate_driver(20, 1, device=device, fp16=False)
|
||||
|
||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
|
||||
dataloader.batch_sampler.sampler.set_distributed(
|
||||
@ -748,11 +754,12 @@ class TestSaveLoad:
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
driver1.set_sampler_epoch(dataloader, 4)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
|
||||
|
||||
# 同步
|
||||
dist.barrier()
|
||||
@ -797,10 +804,11 @@ class TestSaveLoad:
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
driver2.set_sampler_epoch(replaced_loader, 4)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
|
||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
|
||||
res1 = driver1.model(
|
||||
batch,
|
||||
fastnlp_fn=driver1.model.module.model.evaluate_step,
|
||||
|
@ -14,7 +14,7 @@ else:
|
||||
@pytest.mark.torch
|
||||
def test_incorrect_driver():
|
||||
|
||||
model = TorchNormalModel_Classification_1(2, 100)
|
||||
model = TorchNormalModel_Classification_1(20, 10)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_torch_driver("paddle", 0, model)
|
||||
|
||||
@ -33,7 +33,7 @@ def test_get_single_device(driver, device):
|
||||
测试正常情况下初始化TorchSingleDriver的情况
|
||||
"""
|
||||
|
||||
model = TorchNormalModel_Classification_1(2, 100)
|
||||
model = TorchNormalModel_Classification_1(20, 10)
|
||||
driver = initialize_torch_driver(driver, device, model)
|
||||
assert isinstance(driver, TorchSingleDriver)
|
||||
|
||||
@ -52,7 +52,7 @@ def test_get_ddp(driver, device):
|
||||
测试 ddp 多卡的初始化情况
|
||||
"""
|
||||
|
||||
model = TorchNormalModel_Classification_1(64, 10)
|
||||
model = TorchNormalModel_Classification_1(20, 10)
|
||||
driver = initialize_torch_driver(driver, device, model)
|
||||
|
||||
assert isinstance(driver, TorchDDPDriver)
|
||||
@ -70,6 +70,6 @@ def test_device_out_of_range(driver, device):
|
||||
"""
|
||||
测试传入的device超过范围的情况
|
||||
"""
|
||||
model = TorchNormalModel_Classification_1(2, 100)
|
||||
model = TorchNormalModel_Classification_1(20, 10)
|
||||
with pytest.raises(ValueError):
|
||||
driver = initialize_torch_driver(driver, device, model)
|
@ -6,7 +6,7 @@ from pkg_resources import parse_version
|
||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
|
||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
|
||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
|
||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
|
||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
|
||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
|
||||
from fastNLP.envs.distributed import rank_zero_rm
|
||||
@ -15,6 +15,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
|
||||
if _NEED_IMPORT_TORCH:
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, BatchSampler
|
||||
|
||||
if _NEED_IMPORT_PADDLE:
|
||||
import paddle
|
||||
|
||||
@ -67,95 +68,67 @@ class TestTorchDriverFunctions:
|
||||
model = TorchNormalModel_Classification_1(10, 32)
|
||||
self.driver = TorchSingleDriver(model, device="cpu")
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_single_optimizer_legality(self):
|
||||
@pytest.mark.torch
|
||||
def test_check_optimizers_legality(self):
|
||||
"""
|
||||
测试传入单个 optimizer 时的表现
|
||||
测试对合法 optimizers 的检查
|
||||
"""
|
||||
# 单个 optimizer
|
||||
optimizer = torch.optim.Adam(
|
||||
params=self.driver.model.parameters(),
|
||||
lr=0.01
|
||||
)
|
||||
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
|
||||
learning_rate=0.01,
|
||||
)
|
||||
# 传入 torch 的 optimize r时,应该报错 ValueError
|
||||
with pytest.raises(ValueError):
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_optimizers_legality(self):
|
||||
"""
|
||||
测试传入 optimizer list 的表现
|
||||
"""
|
||||
# 列表
|
||||
optimizers = [
|
||||
torch.optim.Adam(
|
||||
params=self.driver.model.parameters(),
|
||||
lr=0.01
|
||||
) for i in range(10)
|
||||
]
|
||||
|
||||
self.driver.set_optimizers(optimizers)
|
||||
|
||||
optimizers += [
|
||||
@pytest.mark.torchpaddle
|
||||
def test_invalid_optimizers(self):
|
||||
"""
|
||||
测试传入非法的 optimizers
|
||||
"""
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
|
||||
learning_rate=0.01,
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.set_optimizers(optimizer)
|
||||
|
||||
optimizers = [
|
||||
paddle.optimizer.Adam(
|
||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(),
|
||||
learning_rate=0.01,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.set_optimizers(optimizers)
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_dataloader_legality_in_train(self):
|
||||
@pytest.mark.torch
|
||||
def test_check_dataloader_legality(self):
|
||||
"""
|
||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现
|
||||
测试 check_dataloader_legality 函数的表现
|
||||
"""
|
||||
dataloader = DataLoader(TorchNormalDataset())
|
||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
self.driver.check_dataloader_legality(dataloader)
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_dataloader_legality_invalid(self):
|
||||
"""
|
||||
测试 check_dataloader_legality 函数传入其他类型的表现
|
||||
"""
|
||||
# 创建 paddle 的 dataloader
|
||||
dataloader = paddle.io.DataLoader(
|
||||
PaddleNormalDataset(),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
|
||||
@pytest.mark.torchpaddle
|
||||
def test_check_dataloader_legality_in_test(self):
|
||||
"""
|
||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现
|
||||
"""
|
||||
# 此时传入的应该是dict
|
||||
dataloader = {
|
||||
"train": DataLoader(TorchNormalDataset()),
|
||||
"test": DataLoader(TorchNormalDataset())
|
||||
}
|
||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
|
||||
# 传入的不是 dict,应该报错
|
||||
dataloader = DataLoader(TorchNormalDataset())
|
||||
with pytest.raises(ValueError):
|
||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
|
||||
# 创建 paddle 的 dataloader
|
||||
train_loader = paddle.io.DataLoader(
|
||||
PaddleNormalDataset(),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
test_loader = paddle.io.DataLoader(
|
||||
PaddleNormalDataset(),
|
||||
batch_size=32, shuffle=True
|
||||
)
|
||||
dataloader = {"train": train_loader, "test": test_loader}
|
||||
with pytest.raises(ValueError):
|
||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")
|
||||
with pytest.raises(TypeError):
|
||||
self.driver.check_dataloader_legality(dataloader)
|
||||
|
||||
@pytest.mark.torch
|
||||
def test_tensor_to_numeric(self):
|
||||
@ -515,10 +488,14 @@ class TestSetDistReproDataloader:
|
||||
# 迭代两个 batch
|
||||
num_consumed_batches = 2
|
||||
already_seen_idx = set()
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
replaced_loader.batch_sampler.set_epoch(3)
|
||||
else:
|
||||
replaced_loader.batch_sampler.sampler.set_epoch(3)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_idx.update(batch)
|
||||
already_seen_idx.update(batch.tolist())
|
||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
|
||||
sampler_states = replaced_loader.batch_sampler.state_dict()
|
||||
else:
|
||||
@ -532,14 +509,16 @@ class TestSetDistReproDataloader:
|
||||
# 重新改造 dataloader
|
||||
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False)
|
||||
new_loader.batch_sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.set_epoch(3)
|
||||
else:
|
||||
batch_size = replaced_loader.batch_sampler.batch_size
|
||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
|
||||
# 重新构造 dataloader
|
||||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False)
|
||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
|
||||
new_loader.batch_sampler.sampler.set_epoch(3)
|
||||
for idx, batch in enumerate(new_loader):
|
||||
left_idxes.update(batch)
|
||||
left_idxes.update(batch.tolist())
|
||||
|
||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
|
||||
assert len(left_idxes | already_seen_idx) == len(self.dataset)
|
||||
@ -550,7 +529,7 @@ class TestSetDistReproDataloader:
|
||||
#
|
||||
############################################################################
|
||||
|
||||
def generate_random_driver(features, labels, fp16=False, device="cpu"):
|
||||
def generate_random_driver(labels, features, fp16=False, device="cpu"):
|
||||
"""
|
||||
生成driver
|
||||
"""
|
||||
@ -570,9 +549,9 @@ def test_save_and_load_model(only_state_dict):
|
||||
"""
|
||||
try:
|
||||
path = "model"
|
||||
dataset = TorchArgMaxDataset(10, 40)
|
||||
dataset = TorchNormalXYDataset(20)
|
||||
dataloader = DataLoader(dataset, batch_size=4)
|
||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
|
||||
driver1, driver2 = generate_random_driver(20, 1), generate_random_driver(20, 1)
|
||||
|
||||
driver1.save_model(path, only_state_dict)
|
||||
driver2.load_model(path, only_state_dict)
|
||||
@ -596,19 +575,20 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
|
||||
try:
|
||||
path = "model.ckp"
|
||||
dataset = TorchArgMaxDataset(10, 40)
|
||||
dataset = TorchNormalXYDataset(20)
|
||||
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False)
|
||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda")
|
||||
driver1, driver2 = generate_random_driver(20, 1, fp16, "cuda"), generate_random_driver(20, 1, False, "cuda")
|
||||
|
||||
num_consumed_batches = 2
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
driver1.set_sampler_epoch(dataloader, 3)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
|
||||
|
||||
sampler_states = dataloader.batch_sampler.state_dict()
|
||||
save_states = {"num_consumed_batches": num_consumed_batches}
|
||||
@ -639,11 +619,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
driver1.set_sampler_epoch(replaced_loader, 3)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
batch = driver2.move_data_to_device(batch)
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
|
||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
|
||||
res1 = driver1.model.evaluate_step(**batch)
|
||||
res2 = driver2.model.evaluate_step(**batch)
|
||||
assert torch.equal(res1["preds"], res2["preds"])
|
||||
@ -660,24 +641,25 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
|
||||
@pytest.mark.parametrize("fp16", ([True, False]))
|
||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16):
|
||||
"""
|
||||
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
|
||||
测试save和load函数,主要测试 dataloader 被替换了 sampler 的情况
|
||||
"""
|
||||
|
||||
try:
|
||||
path = "model.ckp"
|
||||
|
||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda")
|
||||
dataset = TorchArgMaxDataset(10, 40)
|
||||
driver1, driver2 = generate_random_driver(40, 1, fp16, "cuda"), generate_random_driver(40, 1, False, "cuda")
|
||||
dataset = TorchNormalXYDataset(40)
|
||||
dataloader = dataloader_with_randomsampler(dataset, 4, True, False)
|
||||
num_consumed_batches = 2
|
||||
|
||||
already_seen_x_set = set()
|
||||
already_seen_y_set = set()
|
||||
driver1.set_sampler_epoch(dataloader, 3)
|
||||
for idx, batch in enumerate(dataloader):
|
||||
if idx >= num_consumed_batches:
|
||||
break
|
||||
already_seen_x_set.update(batch["x"])
|
||||
already_seen_y_set.update(batch["y"])
|
||||
already_seen_x_set.update(batch["x"].reshape(-1, ).tolist())
|
||||
already_seen_y_set.update(batch["y"].reshape(-1, ).tolist())
|
||||
|
||||
sampler_states = dataloader.batch_sampler.sampler.state_dict()
|
||||
save_states = {"num_consumed_batches": num_consumed_batches}
|
||||
@ -711,11 +693,13 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
|
||||
assert start_batch == 2 * num_consumed_batches
|
||||
left_x_batches = set()
|
||||
left_y_batches = set()
|
||||
# set epoch
|
||||
driver2.set_sampler_epoch(replaced_loader, 3)
|
||||
for idx, batch in enumerate(replaced_loader):
|
||||
|
||||
batch = driver2.move_data_to_device(batch)
|
||||
left_x_batches.update(batch["x"])
|
||||
left_y_batches.update(batch["y"])
|
||||
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
|
||||
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
|
||||
res1 = driver1.model.evaluate_step(**batch)
|
||||
res2 = driver2.model.evaluate_step(**batch)
|
||||
assert torch.equal(res1["preds"], res2["preds"])
|
||||
|
46
tests/helpers/datasets/jittor_data.py
Normal file
46
tests/helpers/datasets/jittor_data.py
Normal file
@ -0,0 +1,46 @@
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
import jittor as jt
|
||||
from jittor.dataset import Dataset
|
||||
else:
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
|
||||
|
||||
class JittorNormalDataset(Dataset):
|
||||
def __init__(self, num_of_data=100, **kwargs):
|
||||
super(JittorNormalDataset, self).__init__(**kwargs)
|
||||
self._data = list(range(num_of_data))
|
||||
self.set_attrs(total_len=num_of_data)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._data[item]
|
||||
|
||||
class JittorNormalXYDataset(Dataset):
|
||||
"""
|
||||
可以被输入到分类模型中的普通数据集
|
||||
"""
|
||||
def __init__(self, num_of_data=1000, **kwargs):
|
||||
super(JittorNormalXYDataset, self).__init__(**kwargs)
|
||||
self.num_of_data = num_of_data
|
||||
self._data = list(range(num_of_data))
|
||||
self.set_attrs(total_len=num_of_data)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return {
|
||||
"x": jt.Var([self._data[item]]),
|
||||
"y": jt.Var([self._data[item]])
|
||||
}
|
||||
|
||||
class JittorArgMaxDataset(Dataset):
|
||||
def __init__(self, num_samples, num_features, **kwargs):
|
||||
super(JittorArgMaxDataset, self).__init__(**kwargs)
|
||||
self.x = jt.randn(num_samples, num_features)
|
||||
self.y = self.x.argmax(dim=-1)
|
||||
self.set_attrs(total_len=num_samples)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return {"x": self.x[item], "y": self.y[item]}
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = JittorNormalDataset()
|
||||
print(len(dataset))
|
@ -19,8 +19,24 @@ class PaddleNormalDataset(Dataset):
|
||||
def __getitem__(self, item):
|
||||
return self._data[item]
|
||||
|
||||
class PaddleNormalXYDataset(Dataset):
|
||||
"""
|
||||
可以被输入到分类模型中的普通数据集
|
||||
"""
|
||||
def __init__(self, num_of_data=1000):
|
||||
self.num_of_data = num_of_data
|
||||
self._data = list(range(num_of_data))
|
||||
|
||||
class PaddleRandomMaxDataset(Dataset):
|
||||
def __len__(self):
|
||||
return self.num_of_data
|
||||
|
||||
def __getitem__(self, item):
|
||||
return {
|
||||
"x": paddle.to_tensor([self._data[item]], dtype="float32"),
|
||||
"y": paddle.to_tensor([self._data[item]], dtype="float32")
|
||||
}
|
||||
|
||||
class PaddleArgMaxDataset(Dataset):
|
||||
def __init__(self, num_samples, num_features):
|
||||
self.x = paddle.randn((num_samples, num_features))
|
||||
self.y = self.x.argmax(axis=-1)
|
||||
|
@ -1,4 +1,6 @@
|
||||
from functools import reduce
|
||||
|
||||
from numpy import dtype
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
|
||||
|
||||
if _NEED_IMPORT_TORCH:
|
||||
@ -19,6 +21,23 @@ class TorchNormalDataset(Dataset):
|
||||
def __getitem__(self, item):
|
||||
return self._data[item]
|
||||
|
||||
class TorchNormalXYDataset(Dataset):
|
||||
"""
|
||||
可以被输入到分类模型中的普通数据集
|
||||
"""
|
||||
def __init__(self, num_of_data=1000):
|
||||
self.num_of_data = num_of_data
|
||||
self._data = list(range(num_of_data))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_of_data
|
||||
|
||||
def __getitem__(self, item):
|
||||
return {
|
||||
"x": torch.tensor([self._data[item]], dtype=torch.float),
|
||||
"y": torch.tensor([self._data[item]], dtype=torch.float)
|
||||
}
|
||||
|
||||
|
||||
# 该类专门用于为 tests.helpers.models.torch_model.py/ TorchNormalModel_Classification_1 创建数据;
|
||||
class TorchNormalDataset_Classification(Dataset):
|
||||
|
57
tests/helpers/models/jittor_model.py
Normal file
57
tests/helpers/models/jittor_model.py
Normal file
@ -0,0 +1,57 @@
|
||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
|
||||
if _NEED_IMPORT_JITTOR:
|
||||
from jittor import Module, nn
|
||||
else:
|
||||
from fastNLP.core.utils.dummy_class import DummyClass as Module
|
||||
|
||||
class JittorNormalModel_Classification_1(Module):
|
||||
"""
|
||||
基础的 jittor 分类模型
|
||||
"""
|
||||
def __init__(self, num_labels, feature_dimension):
|
||||
super(JittorNormalModel_Classification_1, self).__init__()
|
||||
self.num_labels = num_labels
|
||||
|
||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
|
||||
self.ac1 = nn.ReLU()
|
||||
self.linear2 = nn.Linear(in_features=64, out_features=32)
|
||||
self.ac2 = nn.ReLU()
|
||||
self.output = nn.Linear(in_features=32, out_features=num_labels)
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def execute(self, x):
|
||||
x = self.ac1(self.linear1(x))
|
||||
x = self.ac2(self.linear2(x))
|
||||
x = self.output(x)
|
||||
return x
|
||||
|
||||
def train_step(self, x, y):
|
||||
x = self(x)
|
||||
return {"loss": self.loss_fn(x, y)}
|
||||
|
||||
def evaluate_step(self, x, y):
|
||||
|
||||
x = self(x)
|
||||
return {"pred": x, "target": y.reshape((-1,))}
|
||||
|
||||
|
||||
class JittorNormalModel_Classification_2(Module):
|
||||
"""
|
||||
基础的 jittor 分类模型,只实现 execute 函数测试用户自己初始化了分布式的场景
|
||||
"""
|
||||
def __init__(self, num_labels, feature_dimension):
|
||||
super(JittorNormalModel_Classification_2, self).__init__()
|
||||
self.num_labels = num_labels
|
||||
|
||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
|
||||
self.ac1 = nn.ReLU()
|
||||
self.linear2 = nn.Linear(in_features=64, out_features=32)
|
||||
self.ac2 = nn.ReLU()
|
||||
self.output = nn.Linear(in_features=32, out_features=num_labels)
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def execute(self, x, y):
|
||||
x = self.ac1(self.linear1(x))
|
||||
x = self.ac2(self.linear2(x))
|
||||
x = self.output(x)
|
||||
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))}
|
@ -8,7 +8,7 @@ else:
|
||||
|
||||
class PaddleNormalModel_Classification_1(Layer):
|
||||
"""
|
||||
基础的paddle分类模型
|
||||
基础的 paddle 分类模型
|
||||
"""
|
||||
def __init__(self, num_labels, feature_dimension):
|
||||
super(PaddleNormalModel_Classification_1, self).__init__()
|
||||
@ -39,7 +39,7 @@ class PaddleNormalModel_Classification_1(Layer):
|
||||
|
||||
class PaddleNormalModel_Classification_2(Layer):
|
||||
"""
|
||||
基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景
|
||||
基础的 paddle 分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景
|
||||
"""
|
||||
def __init__(self, num_labels, feature_dimension):
|
||||
super(PaddleNormalModel_Classification_2, self).__init__()
|
||||
@ -56,5 +56,4 @@ class PaddleNormalModel_Classification_2(Layer):
|
||||
x = self.ac1(self.linear1(x))
|
||||
x = self.ac2(self.linear2(x))
|
||||
x = self.output(x)
|
||||
loss = self.loss_fn(x, y)
|
||||
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))}
|
||||
|
@ -33,7 +33,11 @@ class TestPaddle2Torch:
|
||||
"""
|
||||
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
assert tensor.device == torch.device(device)
|
||||
if device == "cpu":
|
||||
assert not tensor.is_cuda
|
||||
else:
|
||||
assert tensor.is_cuda
|
||||
assert tensor.device.index == torch.device(device).index
|
||||
assert tensor.requires_grad == requires_grad
|
||||
|
||||
def test_gradient(self):
|
||||
@ -261,7 +265,8 @@ class TestJittor2Torch:
|
||||
if device == "cpu":
|
||||
assert not tensor.is_cuda
|
||||
else:
|
||||
assert tensor.device == torch.device(device)
|
||||
assert tensor.is_cuda
|
||||
assert tensor.device.index == torch.device(device).index
|
||||
assert tensor.requires_grad == requires_grad
|
||||
|
||||
def test_var_transfer(self):
|
||||
@ -271,7 +276,10 @@ class TestJittor2Torch:
|
||||
|
||||
jittor_var = jittor.rand((3, 4, 5))
|
||||
res = jittor2torch(jittor_var)
|
||||
self.check_torch_tensor(res, "cpu", True)
|
||||
if jittor.flags.use_cuda:
|
||||
self.check_torch_tensor(res, "cuda:0", True)
|
||||
else:
|
||||
self.check_torch_tensor(res, "cpu", True)
|
||||
|
||||
res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None)
|
||||
self.check_torch_tensor(res, "cuda:2", True)
|
||||
@ -291,7 +299,10 @@ class TestJittor2Torch:
|
||||
res = jittor2torch(jittor_list)
|
||||
assert isinstance(res, list)
|
||||
for t in res:
|
||||
self.check_torch_tensor(t, "cpu", True)
|
||||
if jittor.flags.use_cuda:
|
||||
self.check_torch_tensor(t, "cuda:0", True)
|
||||
else:
|
||||
self.check_torch_tensor(t, "cpu", True)
|
||||
|
||||
res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False)
|
||||
assert isinstance(res, list)
|
||||
@ -327,17 +338,29 @@ class TestJittor2Torch:
|
||||
}
|
||||
res = jittor2torch(jittor_dict)
|
||||
assert isinstance(res, dict)
|
||||
self.check_torch_tensor(res["tensor"], "cpu", True)
|
||||
if jittor.flags.use_cuda:
|
||||
self.check_torch_tensor(res["tensor"], "cuda:0", True)
|
||||
else:
|
||||
self.check_torch_tensor(res["tensor"], "cpu", True)
|
||||
assert isinstance(res["list"], list)
|
||||
for t in res["list"]:
|
||||
self.check_torch_tensor(t, "cpu", True)
|
||||
if jittor.flags.use_cuda:
|
||||
self.check_torch_tensor(t, "cuda:0", True)
|
||||
else:
|
||||
self.check_torch_tensor(t, "cpu", True)
|
||||
assert isinstance(res["int"], int)
|
||||
assert isinstance(res["string"], str)
|
||||
assert isinstance(res["dict"], dict)
|
||||
assert isinstance(res["dict"]["list"], list)
|
||||
for t in res["dict"]["list"]:
|
||||
self.check_torch_tensor(t, "cpu", True)
|
||||
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True)
|
||||
if jittor.flags.use_cuda:
|
||||
self.check_torch_tensor(t, "cuda:0", True)
|
||||
else:
|
||||
self.check_torch_tensor(t, "cpu", True)
|
||||
if jittor.flags.use_cuda:
|
||||
self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", True)
|
||||
else:
|
||||
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True)
|
||||
|
||||
|
||||
############################################################################
|
||||
|
Loading…
Reference in New Issue
Block a user