mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
1. 修复Trainer check_code中检查evaluate时使用train_data的bug
This commit is contained in:
parent
daa1671230
commit
56e7641eb8
@ -138,20 +138,30 @@ class Trainer(object):
|
||||
|
||||
开始训练过程。主要有以下几个步骤::
|
||||
|
||||
对于每次循环
|
||||
1. 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为float, int的fields进行padding。并转换为Tensor。
|
||||
for epoch in range(num_epochs):
|
||||
# 使用Batch从DataSet中按批取出数据,并自动对DataSet中dtype为(float, int)的fields进行padding。并转换为Tensor。
|
||||
非float,int类型的参数将不会被转换为Tensor,且不进行padding。
|
||||
for batch_x, batch_y in Batch(DataSet)
|
||||
# batch_x中为设置为input的field
|
||||
# batch_y中为设置为target的field
|
||||
2. 将batch_x的数据送入到model.forward函数中,并获取结果
|
||||
3. 将batch_y与model.forward的结果一并送入loss中计算loss
|
||||
# batch_x是一个dict, 被设为input的field会出现在这个dict中,
|
||||
key为DataSet中的field_name, value为该field的value
|
||||
# batch_y也是一个dict,被设为target的field会出现在这个dict中,
|
||||
key为DataSet中的field_name, value为该field的value
|
||||
2. 将batch_x的数据送入到model.forward函数中,并获取结果。这里我们就是通过匹配batch_x中的key与forward函数的形
|
||||
参完成参数传递。例如,
|
||||
forward(self, x, seq_lens) # fastNLP会在batch_x中找到key为"x"的value传递给x,key为"seq_lens"的
|
||||
value传递给seq_lens。若在batch_x中没有找到所有必须要传递的参数,就会报错。如果forward存在默认参数
|
||||
而且默认参数这个key没有在batch_x中,则使用默认参数。
|
||||
3. 将batch_y与model.forward的结果一并送入loss中计算loss。loss计算时一般都涉及到pred与target。但是在不同情况
|
||||
中,可能pred称为output或prediction, target称为y或label。fastNLP通过初始化loss时传入的映射找到pred或
|
||||
target。比如在初始化Trainer时初始化loss为CrossEntropyLoss(pred='output', target='y'), 那么fastNLP计
|
||||
算loss时,就会使用"output"在batch_y与forward的结果中找到pred;使用"y"在batch_y与forward的结果中找target
|
||||
, 并完成loss的计算。
|
||||
4. 获取到loss之后,进行反向求导并更新梯度
|
||||
如果测试集不为空
|
||||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型
|
||||
根据需要适时进行验证机测试
|
||||
根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型
|
||||
|
||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现最好的
|
||||
模型参数。
|
||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现
|
||||
最好的模型参数。
|
||||
:return results: 返回一个字典类型的数据, 内含以下内容::
|
||||
|
||||
seconds: float, 表示训练时长
|
||||
@ -196,8 +206,11 @@ class Trainer(object):
|
||||
results['best_step'] = self.best_dev_step
|
||||
if load_best_model:
|
||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
|
||||
# self._load_model(self.model, model_name)
|
||||
print("Reloaded the best model.")
|
||||
load_succeed = self._load_model(self.model, model_name)
|
||||
if load_succeed:
|
||||
print("Reloaded the best model.")
|
||||
else:
|
||||
print("Fail to reload best model.")
|
||||
finally:
|
||||
self._summary_writer.close()
|
||||
del self._summary_writer
|
||||
@ -208,7 +221,7 @@ class Trainer(object):
|
||||
def _tqdm_train(self):
|
||||
self.step = 0
|
||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
|
||||
as_numpy=False)
|
||||
as_numpy=False)
|
||||
total_steps = data_iterator.num_batches*self.n_epochs
|
||||
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
|
||||
avg_loss = 0
|
||||
@ -297,7 +310,8 @@ class Trainer(object):
|
||||
if self.save_path is not None:
|
||||
self._save_model(self.model,
|
||||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
|
||||
|
||||
else:
|
||||
self._best_model_states = {name:param.cpu().clone() for name, param in self.model.named_parameters()}
|
||||
self.best_dev_perf = res
|
||||
self.best_dev_epoch = epoch
|
||||
self.best_dev_step = step
|
||||
@ -356,7 +370,7 @@ class Trainer(object):
|
||||
torch.save(model, model_name)
|
||||
|
||||
def _load_model(self, model, model_name, only_param=False):
|
||||
# TODO: 这个是不是有问题?
|
||||
# 返回bool值指示是否成功reload模型
|
||||
if self.save_path is not None:
|
||||
model_path = os.path.join(self.save_path, model_name)
|
||||
if only_param:
|
||||
@ -364,6 +378,11 @@ class Trainer(object):
|
||||
else:
|
||||
states = torch.load(model_path).state_dict()
|
||||
model.load_state_dict(states)
|
||||
elif hasattr(self, "_best_model_states"):
|
||||
model.load_state_dict(self._best_model_states)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _better_eval_result(self, metrics):
|
||||
"""Check if the current epoch yields better validation results.
|
||||
@ -469,7 +488,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
|
||||
break
|
||||
|
||||
if dev_data is not None:
|
||||
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
|
||||
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
|
||||
batch_size=batch_size, verbose=-1)
|
||||
evaluate_results = tester.test()
|
||||
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)
|
||||
|
@ -448,4 +448,33 @@ class BMES2OutputProcessor(Processor):
|
||||
words.append(''.join(chars[start_idx:idx+1]))
|
||||
start_idx = idx + 1
|
||||
return ' '.join(words)
|
||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)
|
||||
dataset.apply(func=inner_proc, new_field_name=self.new_added_field_name)
|
||||
|
||||
|
||||
class InputTargetProcessor(Processor):
|
||||
def __init__(self, input_fields, target_fields):
|
||||
"""
|
||||
对DataSet操作,将input_fields中的field设置为input,target_fields的中field设置为target
|
||||
|
||||
:param input_fields: List[str], 设置为input_field的field_name。如果为None,则不将任何field设置为target。
|
||||
:param target_fields: List[str], 设置为target_field的field_name。 如果为None,则不将任何field设置为target。
|
||||
"""
|
||||
super(InputTargetProcessor, self).__init__(None, None)
|
||||
|
||||
if input_fields is not None and not isinstance(input_fields, list):
|
||||
raise TypeError("input_fields should be List[str], not {}.".format(type(input_fields)))
|
||||
else:
|
||||
self.input_fields = input_fields
|
||||
if target_fields is not None and not isinstance(target_fields, list):
|
||||
raise TypeError("target_fiels should be List[str], not{}.".format(type(target_fields)))
|
||||
else:
|
||||
self.target_fields = target_fields
|
||||
|
||||
def process(self, dataset):
|
||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
|
||||
if self.input_fields is not None:
|
||||
for field in self.input_fields:
|
||||
dataset.set_input(field)
|
||||
if self.target_fields is not None:
|
||||
for field in self.target_fields:
|
||||
dataset.set_target(field)
|
@ -6,7 +6,7 @@ from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegPr
|
||||
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
|
||||
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
|
||||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
|
||||
|
||||
from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor
|
||||
|
||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
|
||||
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF
|
||||
@ -39,6 +39,8 @@ bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='b
|
||||
|
||||
seq_len_proc = SeqLenProcessor('chars')
|
||||
|
||||
input_target_proc = InputTargetProcessor(input_fields=['chars', 'bigrams', 'seq_lens', "target"],
|
||||
target_fields=['target', 'seq_lens'])
|
||||
# 2. 使用processor
|
||||
fs2hs_proc(tr_dataset)
|
||||
|
||||
@ -61,14 +63,11 @@ char_vocab_proc(dev_dataset)
|
||||
bigram_vocab_proc(dev_dataset)
|
||||
seq_len_proc(dev_dataset)
|
||||
|
||||
dev_dataset.set_input('chars', 'bigrams', 'target')
|
||||
tr_dataset.set_input('chars', 'bigrams', 'target')
|
||||
dev_dataset.set_target('seq_lens')
|
||||
tr_dataset.set_target('seq_lens')
|
||||
input_target_proc(tr_dataset)
|
||||
input_target_proc(dev_dataset)
|
||||
|
||||
print("Finish preparing data.")
|
||||
|
||||
|
||||
# 3. 得到数据集可以用于训练了
|
||||
# TODO pretrain的embedding是怎么解决的?
|
||||
|
||||
@ -86,80 +85,18 @@ cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
|
||||
cws_model.cuda()
|
||||
|
||||
num_epochs = 5
|
||||
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)
|
||||
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.005)
|
||||
|
||||
from fastNLP.core.trainer import Trainer
|
||||
from fastNLP.core.sampler import BucketSampler
|
||||
from fastNLP.core.metrics import BMESF1PreRecMetric
|
||||
|
||||
metric = BMESF1PreRecMetric(target='tags')
|
||||
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=3,
|
||||
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=num_epochs,
|
||||
batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
|
||||
optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True)
|
||||
|
||||
trainer.train()
|
||||
exit(0)
|
||||
|
||||
#
|
||||
# print_every = 50
|
||||
# batch_size = 32
|
||||
# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
|
||||
# dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
|
||||
# num_batch_per_epoch = len(tr_dataset) // batch_size
|
||||
# best_f1 = 0
|
||||
# best_epoch = 0
|
||||
# for num_epoch in range(num_epochs):
|
||||
# print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
|
||||
# sys.stdout.flush()
|
||||
# avg_loss = 0
|
||||
# with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
|
||||
# pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
|
||||
# cws_model.train()
|
||||
# for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
|
||||
# optimizer.zero_grad()
|
||||
#
|
||||
# tags = batch_y['tags'].long()
|
||||
# pred_dict = cws_model(**batch_x, tags=tags) # B x L x tag_size
|
||||
#
|
||||
# seq_lens = pred_dict['seq_lens']
|
||||
# masks = seq_lens_to_mask(seq_lens).float()
|
||||
# tags = tags.to(seq_lens.device)
|
||||
#
|
||||
# loss = pred_dict['loss']
|
||||
#
|
||||
# # loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
|
||||
# # tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
|
||||
# # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
|
||||
#
|
||||
# avg_loss += loss.item()
|
||||
#
|
||||
# loss.backward()
|
||||
# for group in optimizer.param_groups:
|
||||
# for param in group['params']:
|
||||
# param.grad.clamp_(-5, 5)
|
||||
#
|
||||
# optimizer.step()
|
||||
#
|
||||
# if batch_idx % print_every == 0:
|
||||
# pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
|
||||
# avg_loss = 0
|
||||
# pbar.update(print_every)
|
||||
# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
|
||||
# # 验证集
|
||||
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher, type='bmes')
|
||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
|
||||
# pre*100,
|
||||
# rec*100))
|
||||
# if best_f1<f1:
|
||||
# best_f1 = f1
|
||||
# # 缓存最佳的parameter,可能之后会用于保存
|
||||
# best_state_dict = {
|
||||
# key:value.clone() for key, value in
|
||||
# cws_model.state_dict().items()
|
||||
# }
|
||||
# best_epoch = num_epoch
|
||||
#
|
||||
# cws_model.load_state_dict(best_state_dict)
|
||||
|
||||
# 4. 组装需要存下的内容
|
||||
pp = Pipeline()
|
||||
@ -171,6 +108,7 @@ pp.add_processor(bigram_proc)
|
||||
pp.add_processor(char_vocab_proc)
|
||||
pp.add_processor(bigram_vocab_proc)
|
||||
pp.add_processor(seq_len_proc)
|
||||
pp.add_processor(input_target_proc)
|
||||
|
||||
# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
|
||||
te_filename = '/home/hyan/ctb3/test.conllx'
|
||||
@ -181,6 +119,7 @@ from fastNLP.core.tester import Tester
|
||||
|
||||
tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=False,
|
||||
verbose=1)
|
||||
tester.test()
|
||||
#
|
||||
# batch_size = 64
|
||||
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
|
||||
@ -193,7 +132,7 @@ tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64,
|
||||
|
||||
test_context_dict = {'pipeline': pp,
|
||||
'model': cws_model}
|
||||
torch.save(test_context_dict, 'models/test_context_crf.pkl')
|
||||
# torch.save(test_context_dict, 'models/test_context_crf.pkl')
|
||||
|
||||
|
||||
# 5. dev的pp
|
||||
|
Loading…
Reference in New Issue
Block a user