1. Add control over the sampler in DistTrainer; 2. Support other optimizers in Trainer and DistTrainer

yh_cc 2020-11-22 20:14:44 +08:00
parent 60423c2d0d
commit 2ba336945c
4 changed files with 26 additions and 9 deletions

View File

@@ -206,7 +206,7 @@ class Callback(object):
def on_batch_begin(self, batch_x, batch_y, indices):
r"""
Called once each time a batch of data has been collected. Deleting entries from, or adding entries to, batch_x or batch_y here affects what the Trainer sees, so this is the step where
operations such as negative sampling can be performed.
operations such as negative sampling can be performed. The tensors in batch_x and batch_y have already been moved to the device the model is on.
:param dict batch_x: a batch of the fields set as input in the DataSet
:param dict batch_y: a batch of the fields set as target in the DataSet
@@ -1169,11 +1169,12 @@ class EchoCallback(Callback):
class _TesterCallback(Callback):
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None, sampler=None,
use_tqdm=True):
super(_TesterCallback, self).__init__()
self.tester = Tester(data, model,
metrics=metrics, batch_size=batch_size,
num_workers=num_workers, verbose=0)
num_workers=num_workers, verbose=0, sampler=sampler, use_tqdm=use_tqdm)
if metric_key is not None:
self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
else:
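
With the new sampler and use_tqdm parameters, _TesterCallback simply forwards them to the Tester it wraps. A minimal usage sketch, assuming dev_data, model and metrics are placeholders for already-built fastNLP objects and that SequentialSampler is importable from the top-level fastNLP package:

    from fastNLP import SequentialSampler

    # dev_data, model and metrics stand in for existing fastNLP objects
    cb = _TesterCallback(dev_data, model, metrics,
                         batch_size=16, num_workers=2,
                         sampler=SequentialSampler(),  # evaluate in the original data order
                         use_tqdm=False)               # suppress the evaluation progress bar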

View File

@@ -73,7 +73,7 @@ class DistTrainer():
dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1,
save_path=None, device='auto',
fp16='', use_tqdm=True):
fp16='', use_tqdm=True, **kwargs):
r"""
:param train_data: the training set, of type :class:`~fastNLP.DataSet`
@@ -106,6 +106,9 @@ class DistTrainer():
:param str device: which device to use; can be 'auto', 'cuda' or 'cpu'
:param str fp16: optimization level for half-precision training; can be O1, O2 or O3. If an empty string, half precision is not used.
:param bool use_tqdm: whether to use tqdm to display training progress; if False, the loss is printed to the terminal instead
:param kwargs: supported optional parameters:
    bool test_use_tqdm: whether to enable tqdm when validating on the dev set
    Sampler test_sampler: the sampler to use during evaluation
"""
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
if device == 'auto':
@@ -163,16 +166,23 @@ class DistTrainer():
self.model = self.ddp_model.module
self.optimizer = optimizer
self.sampler = DistributedSampler(self.train_data)
if isinstance(self.train_data, DataSet):
self.sampler = DistributedSampler(self.train_data)
self.data_iterator = self._get_data_iter(self.train_data)
self.batch_size = self.world_size * self.batch_size_per_gpu
self.n_steps = self._get_n_steps()
if 'test_use_tqdm' in kwargs:
test_use_tqdm = kwargs.get('test_use_tqdm')
else:
test_use_tqdm = self.use_tqdm
# for evaluation, only run eval on master proc
if dev_data and metrics:
cb = _TesterCallback(
dev_data, model, metrics,
batch_size=batch_size_per_gpu, num_workers=num_workers)
batch_size=batch_size_per_gpu, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
use_tqdm=test_use_tqdm)
self.test_manager.add_callback([cb], master=True)
# Setup logging
@@ -232,8 +242,10 @@ class DistTrainer():
elif optimizer is None:
return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
else:
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
raise TypeError("optimizer must have a callable step() function.")
else:
self.optimizer = optimizer
@property
def is_master(self):
r"""是否是主进程"""

View File

@@ -545,7 +545,10 @@ class Trainer(object):
elif optimizer is None:
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
else:
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
raise TypeError("optimizer must have a callable step() function.")
else:
self.optimizer = optimizer
self.logger = logger
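
The Trainer change mirrors the one in DistTrainer: instead of requiring a torch.optim.Optimizer instance, any object exposing a callable step() is now accepted, which opens the door to optimizers from other libraries or hand-rolled ones. A toy illustration of an object that would pass the relaxed check; the class and its update rule are invented for this sketch, and since the training loop also calls zero_grad(), one is provided as well:

    import torch

    class ManualSGD:
        # Minimal optimizer-like object: the new check only requires a callable step()
        def __init__(self, params, lr=0.1):
            self.params = list(params)
            self.lr = lr

        def step(self):
            # plain SGD update, applied in place without tracking gradients
            with torch.no_grad():
                for p in self.params:
                    if p.grad is not None:
                        p.sub_(self.lr * p.grad)

        def zero_grad(self):
            for p in self.params:
                if p.grad is not None:
                    p.grad.zero_()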

View File

@@ -273,6 +273,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
# yields (batch_size, num_beams), (batch_size, num_beams)
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
# TODO: handle the case where the sequence already ends at the very first position
# reorder according to the index
indices = torch.arange(batch_size, dtype=torch.long).to(device)
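
For reference, a small self-contained illustration of the selection step above: after log_softmax, torch.topk keeps the num_beams highest-scoring tokens per example, producing the (batch_size, num_beams) score and token tensors the comment refers to (the sizes below are made up for the demo):

    import torch
    import torch.nn.functional as F

    batch_size, vocab_size, num_beams = 2, 6, 3
    scores = torch.randn(batch_size, vocab_size)          # stand-in for the decoder logits
    scores = F.log_softmax(scores, dim=-1)                # (batch_size, vocab_size) log-probabilities
    # highest num_beams candidates per example; both outputs are (batch_size, num_beams)
    next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
    print(next_scores.shape, next_tokens.shape)           # torch.Size([2, 3]) torch.Size([2, 3])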