Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-01 19:57:34 +08:00
Add the paddle and oneflow changes for replacing the dataloader
This commit is contained in:
parent 28db704f70
commit 3d8214c783
@@ -245,10 +245,15 @@ class OneflowDDPDriver(OneflowDriver):
        # evaluator
        elif dist == "unrepeatdist":
            args = self.get_dataloader_args(dataloader)
            if type(args.batch_sampler) != BatchSampler:
                # TODO The purpose here is to detect a user-customized batch_sampler; this check may need refinement.
                logger.warning("Note that you are using customized ``batch_sampler`` in evaluate dataloader or" \
                               "train dataloader while testing ``overfit_batches``, which may cause that" \
                               "the data for distributed evaluation is not unrepeated.")
            if isinstance(args.sampler, ReproducibleSampler):
                sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
            elif not isinstance(args.sampler, UnrepeatedSampler):
                _check_dataloader_args_for_distributed(args, controller="Evaluator")
                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                sampler = UnrepeatedSequentialSampler(
                    dataset=args.dataset
                )
@@ -258,6 +263,7 @@ class OneflowDDPDriver(OneflowDriver):
                num_replicas=self.world_size,
                rank=self.global_rank
            )
            # TODO For now this is uniformly replaced with a BatchSampler here.
            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
            return replace_batch_sampler(dataloader, batch_sampler)
        else:
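The two hunks above change how OneflowDDPDriver rebuilds the evaluator dataloader for ``unrepeatdist``: instead of swapping only the sampler, the unrepeated sampler is wrapped in a ``BatchSampler`` and the dataloader's whole batch_sampler is replaced. Below is a minimal, self-contained sketch of that idea, not the fastNLP implementation; the ``_UnrepeatedSequentialSketch`` class and its strided partitioning are simplified stand-ins for ``UnrepeatedSequentialSampler``.

from oneflow.utils.data import BatchSampler, DataLoader, Sampler

class _UnrepeatedSequentialSketch(Sampler):
    """Each rank walks a disjoint, un-padded slice of the dataset (partitioning simplified for illustration)."""
    def __init__(self, dataset, num_replicas=1, rank=0):
        self.dataset, self.num_replicas, self.rank = dataset, num_replicas, rank
    def __iter__(self):
        return iter(range(self.rank, len(self.dataset), self.num_replicas))
    def __len__(self):
        return len(range(self.rank, len(self.dataset), self.num_replicas))

dataset = list(range(10))
sampler = _UnrepeatedSequentialSketch(dataset, num_replicas=2, rank=1)
# Mirrors the new code path: wrap the per-rank sampler, then swap the dataloader's batch_sampler.
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
print(list(batch_sampler))                        # [[1, 3, 5, 7], [9]] -- rank 1 sees every index exactly once
loader = DataLoader(dataset, batch_sampler=batch_sampler)

Because drop_last=False, the trailing partial batch is kept, so no evaluation sample is dropped or duplicated across ranks.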
@@ -43,7 +43,6 @@ def initialize_oneflow_driver(driver: str, device: Optional[Union[str, "oneflow.
                raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
            device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)]
        elif device >= _could_use_device_num:
            print(device, _could_use_device_num)
            raise ValueError("The gpu device that parameter `device` specifies is not existed.")
        else:
            device = oneflow.device(f"cuda:{device}")
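The functional change in this hunk is dropping a leftover debug print; the device-selection rule around it is unchanged. As a condensed restatement of that rule, here is a small sketch, assuming oneflow is importable; the helper name is invented for illustration and is not fastNLP code:

import oneflow

def _select_cuda_devices(device: int, visible_num: int):
    """-1 selects every visible GPU, 0 <= device < visible_num selects a single device, anything else raises."""
    if device < 0:
        if device != -1:
            raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
        return [oneflow.device(f"cuda:{w}") for w in range(visible_num)]
    if device >= visible_num:
        raise ValueError("The gpu device that parameter `device` specifies is not existed.")
    return oneflow.device(f"cuda:{device}")

print(_select_cuda_devices(-1, 2))   # a list with one oneflow.device per visible GPU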
@@ -280,12 +280,23 @@ def optimizer_state_to_device(state, device):


def _check_dataloader_args_for_distributed(args, controller='Trainer'):
    if type(args.batch_sampler) is not oneflowBatchSampler or (type(args.sampler) not in {oneflowRandomSampler,
                                                                                           oneflowSequentialSampler}):
        mode = 'training' if controller == 'Trainer' else 'evaluation'
        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
    """
    Check the dataloader's sampler; if the user has substituted their own customized sampler,
    raise an error to prevent failures during distributed training.
    """
    error_flag = (type(args.sampler) not in {oneflowRandomSampler, oneflowSequentialSampler})
    if controller == 'Trainer':
        mode = 'training'
        substitution = 'fastNLP.RandomSampler'
        error_flag = (type(args.batch_sampler) != oneflowBatchSampler) or error_flag
    else: # Evaluator
        mode = 'evaluation'
        substitution = 'fastNLP.UnrepeatedSequentialSampler'
    if error_flag:
        raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                        f"``{substitution}``. The customized sampler should set for distributed running "
                        f"before initializing ``{controller}`` , and then set the "
                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
                        f"\n Current sampler: {type(args.sampler)}")
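The rewritten check above now distinguishes the two controllers: for a Trainer both the batch_sampler and the sampler must still be oneflow's defaults, while for an Evaluator only the sampler type is inspected, so a customized batch_sampler no longer aborts distributed evaluation. A hypothetical illustration follows; the ``args`` namespace and ``MyBatchSampler`` are invented, and the helper is assumed to be the one defined in the hunk above:

from types import SimpleNamespace
from oneflow.utils.data import BatchSampler, SequentialSampler

class MyBatchSampler(BatchSampler):      # user-defined batch_sampler, hypothetical
    pass

args = SimpleNamespace(
    batch_sampler=MyBatchSampler(SequentialSampler(range(8)), batch_size=2, drop_last=False),
    sampler=SequentialSampler(range(8)),
)
_check_dataloader_args_for_distributed(args, controller='Evaluator')   # passes: the sampler is still a default one
try:
    _check_dataloader_args_for_distributed(args, controller='Trainer')
except TypeError as e:
    print(e)                                                            # raises: the batch_sampler is customized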
@@ -112,6 +112,7 @@ if _NEED_IMPORT_PADDLE:
    from paddle.optimizer import Optimizer
    from paddle.fluid.reader import _DatasetKind
    from paddle.fluid.dygraph import parallel_helper
    from paddle.io import BatchSampler

__all__ = [
    "PaddleFleetDriver",
@@ -471,9 +472,15 @@ class PaddleFleetDriver(PaddleDriver):
        # evaluator
        elif dist == "unrepeatdist":
            args = self.get_dataloader_args(dataloader)
            if type(args.batch_sampler) != BatchSampler:
                # TODO The purpose here is to detect a user-customized batch_sampler; this check may need refinement.
                logger.warning("Note that you are using customized ``batch_sampler`` in evaluate dataloader or" \
                               "train dataloader while testing ``overfit_batches``, which may cause that" \
                               "the data for distributed evaluation is not unrepeated.")
            if isinstance(args.sampler, ReproducibleSampler):
                sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
            elif not isinstance(args.sampler, UnrepeatedSampler):
                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                sampler = UnrepeatedSequentialSampler(
                    dataset=args.dataset
                )
@@ -483,7 +490,9 @@ class PaddleFleetDriver(PaddleDriver):
                num_replicas=self.world_size,
                rank=self.global_rank
            )
            return replace_sampler(dataloader, sampler)
            # TODO For now this is uniformly replaced with a BatchSampler here.
            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
            return replace_batch_sampler(dataloader, batch_sampler)
        else:
            raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
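PaddleFleetDriver gets the same switch: where the evaluator branch used to hand back ``replace_sampler(dataloader, sampler)``, it now wraps the unrepeated sampler in ``paddle.io.BatchSampler`` and goes through ``replace_batch_sampler``. Below is a small, hedged check of the property this aims at, with an invented toy dataset and paddle's own ``SequenceSampler`` standing in for fastNLP's unrepeated sampler:

import paddle
from paddle.io import BatchSampler, DataLoader, Dataset, SequenceSampler

class _ToyDataset(Dataset):                     # hypothetical stand-in for args.dataset
    def __getitem__(self, idx):
        return paddle.to_tensor([idx])
    def __len__(self):
        return 10

dataset = _ToyDataset()
sampler = SequenceSampler(dataset)              # stand-in for the per-rank unrepeated sampler
batch_sampler = BatchSampler(sampler=sampler, batch_size=4, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)
# The rebuilt dataloader carries the new BatchSampler; drop_last=False keeps the tail batch.
assert isinstance(loader.batch_sampler, BatchSampler) and loader.batch_sampler.drop_last is False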
@@ -266,12 +266,23 @@ def optimizer_state_to_device(state, device):
    return new_state


def _check_dataloader_args_for_distributed(args, controller='Trainer'):
    if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler,
                                                                                    SequenceSampler}):
        mode = 'training' if controller == 'Trainer' else 'evaluation'
        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
    """
    Check the dataloader's sampler; if the user has substituted their own customized sampler,
    raise an error to prevent failures during distributed training.
    """
    error_flag = (type(args.sampler) not in {RandomSampler, SequenceSampler})
    if controller == 'Trainer':
        mode = 'training'
        substitution = 'fastNLP.RandomSampler'
        error_flag = (type(args.batch_sampler) != BatchSampler) or error_flag
    else: # Evaluator
        mode = 'evaluation'
        substitution = 'fastNLP.UnrepeatedSequentialSampler'
    if error_flag:
        raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                        f"``{substitution}``. The customized sampler should set for distributed running "
                        f"before initializing ``{controller}`` , and then set the "
                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
                        f"\n Current sampler: {type(args.sampler)}")
@@ -617,7 +617,6 @@ class TorchDDPDriver(TorchDriver):
            if isinstance(args.sampler, ReproducibleSampler):
                sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
            elif not isinstance(args.sampler, UnrepeatedSampler):
                # TODO Avoid the case where a batch_sampler is used.
                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                sampler = UnrepeatedSequentialSampler(
                    dataset=args.dataset