Mirror of https://gitee.com/fastnlp/fastNLP.git, synced 2024-12-11 10:05:30 +08:00
1. Add control over the sampler in DistTrainer; 2. Support other optimizers in Trainer and DistTrainer
This commit is contained in:
parent 60423c2d0d
commit 2ba336945c
@@ -206,7 +206,7 @@ class Callback(object):
     def on_batch_begin(self, batch_x, batch_y, indices):
         r"""
         Called once each time a batch of data has been fetched. Deleting entries from or adding entries to batch_x or batch_y here does affect what the Trainer sees, so at this step
-        operations such as negative sampling can be performed
+        operations such as negative sampling can be performed. The tensors in batch_x and batch_y have already been placed on the device the model is on.
 
         :param dict batch_x: batch of the fields set as input in the DataSet.
         :param dict batch_y: batch of the fields set as target in the DataSet.
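As a concrete illustration of this hook, here is a minimal callback sketch; the field name 'words' and the shuffling step are placeholders standing in for a real negative-sampling operation, not code from this commit.

    import torch
    from fastNLP import Callback

    class NegativeSamplingCallback(Callback):
        """Hypothetical example: edit batch_x in place inside on_batch_begin."""

        def on_batch_begin(self, batch_x, batch_y, indices):
            # The dicts are the same objects the Trainer feeds to the model, and their
            # tensors already live on the model's device, so in-place edits here are
            # visible to the following forward pass.
            words = batch_x['words']  # assumes an input field named 'words'
            perm = torch.randperm(words.size(0), device=words.device)
            batch_x['words'] = words[perm]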
@@ -1169,11 +1169,12 @@ class EchoCallback(Callback):
 
 
 class _TesterCallback(Callback):
-    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
+    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None, sampler=None,
+                 use_tqdm=True):
         super(_TesterCallback, self).__init__()
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
-                             num_workers=num_workers, verbose=0)
+                             num_workers=num_workers, verbose=0, sampler=sampler, use_tqdm=use_tqdm)
         if metric_key is not None:
             self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
         else:
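For reference, a hedged usage sketch of the extended signature, assuming Tester accepts sampler and use_tqdm as the call above implies; _TesterCallback is an internal helper that DistTrainer normally constructs itself, and dev_ds, my_model, metric are placeholders.

    from fastNLP import SequentialSampler

    cb = _TesterCallback(dev_ds, my_model, metric,
                         batch_size=32, num_workers=2,
                         sampler=SequentialSampler(),  # new: forwarded to the Tester to control evaluation order
                         use_tqdm=False)               # new: forwarded to the Tester to silence its progress bar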
@@ -73,7 +73,7 @@ class DistTrainer():
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16='', use_tqdm=True):
+                 fp16='', use_tqdm=True, **kwargs):
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
@@ -106,6 +106,9 @@ class DistTrainer():
         :param str device: the device to use; may be gpu, cpu or auto.
         :param str fp16: the optimization level for half-precision training, one of O1, O2 or O3; an empty string disables half precision.
         :param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal instead.
+        :param kwargs: optional configuration parameters:
+            bool test_use_tqdm: whether to enable tqdm when evaluating on the dev set
+            Sampler test_sampler: the sampler to use during evaluation
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
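A sketch of how these kwargs reach DistTrainer from user code; the dataset/model names are placeholders and the script is assumed to be launched through torch.distributed.launch as usual for DistTrainer.

    from fastNLP import CrossEntropyLoss, AccuracyMetric, SequentialSampler
    from fastNLP.core.dist_trainer import DistTrainer

    trainer = DistTrainer(
        train_data=train_ds, model=my_model,          # placeholders for a real DataSet / model
        loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
        dev_data=dev_ds, batch_size_per_gpu=16, use_tqdm=True,
        test_use_tqdm=False,                          # new kwarg: no tqdm bar during dev evaluation
        test_sampler=SequentialSampler(),             # new kwarg: sampler handed to the evaluation Tester
    )
    trainer.train()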
@@ -163,16 +166,23 @@ class DistTrainer():
         self.model = self.ddp_model.module
 
         self.optimizer = optimizer
-        self.sampler = DistributedSampler(self.train_data)
+        if isinstance(self.train_data, DataSet):
+            self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()
 
+        if 'test_use_tqdm' in kwargs:
+            test_use_tqdm = kwargs.get('test_use_tqdm')
+        else:
+            test_use_tqdm = self.use_tqdm
+
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
                 dev_data, model, metrics,
-                batch_size=batch_size_per_gpu, num_workers=num_workers)
+                batch_size=batch_size_per_gpu, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
+                use_tqdm=test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
 
         # Setup logging
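The two-branch lookup of test_use_tqdm above behaves the same as a single dict.get with a default; a minimal equivalent line:

    # Fall back to the trainer's own use_tqdm setting when the kwarg was not passed.
    test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)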
@@ -232,8 +242,10 @@ class DistTrainer():
         elif optimizer is None:
             return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
         else:
-            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
-
+            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
+                raise TypeError("optimizer must have a callable step() function.")
+            else:
+                self.optimizer = optimizer
     @property
     def is_master(self):
         r"""Whether this is the master process"""
@@ -545,7 +545,10 @@ class Trainer(object):
         elif optimizer is None:
             self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
         else:
-            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
+            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
+                raise TypeError("optimizer must have a callable step() function.")
+            else:
+                self.optimizer = optimizer
 
         self.logger = logger
 
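With the relaxed check, anything exposing a callable step() is accepted as an optimizer. The wrapper below is a hypothetical illustration (in practice the Trainer also calls zero_grad(), so that is provided too); model, train_ds, dev_ds, loss and metric are placeholders.

    import torch
    from fastNLP import Trainer

    class WrappedSGD:
        """Hypothetical optimizer-like object: not a torch.optim.Optimizer subclass."""

        def __init__(self, params, lr=1e-2):
            self._inner = torch.optim.SGD(params, lr=lr)

        def step(self):
            self._inner.step()

        def zero_grad(self):
            self._inner.zero_grad()

    trainer = Trainer(train_data=train_ds, model=model, loss=loss, metrics=metric,
                      dev_data=dev_ds, optimizer=WrappedSGD(model.parameters()))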
@@ -273,6 +273,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         scores = F.log_softmax(scores, dim=-1)  # (batch_size, vocab_size)
         # get (batch_size, num_beams), (batch_size, num_beams)
         next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
+        # TODO handle the case where generation finishes at the very first position
 
         # reorder according to the indices
         indices = torch.arange(batch_size, dtype=torch.long).to(device)
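The torch.topk call above is what selects the num_beams best continuations per batch element; a standalone sketch with toy sizes (not fastNLP code):

    import torch
    import torch.nn.functional as F

    batch_size, vocab_size, num_beams = 2, 10, 3
    logits = torch.randn(batch_size, vocab_size)

    scores = F.log_softmax(logits, dim=-1)  # (batch_size, vocab_size)
    next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
    # next_tokens[i] holds the num_beams highest-scoring token ids for batch element i,
    # next_scores[i] their log-probabilities in descending order.
    print(next_scores.shape, next_tokens.shape)  # torch.Size([2, 3]) torch.Size([2, 3])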