mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
1. callback中增加GradientClip; 2.Trainer中取消_print_train()和_tqdm_train(),全部并入了_train()
This commit is contained in:
parent
59c9b894b1
commit
400552971c
@ -23,7 +23,7 @@ from fastNLP.api.processor import IndexerProcessor
|
||||
|
||||
# TODO add pretrain urls
|
||||
model_urls = {
|
||||
|
||||
'cws': "http://123.206.98.91:8888/download/cws_crf-69e357c9.pkl"
|
||||
}
|
||||
|
||||
|
||||
@ -139,6 +139,12 @@ class POS(API):
|
||||
|
||||
class CWS(API):
|
||||
def __init__(self, model_path=None, device='cpu'):
|
||||
"""
|
||||
中文分词高级接口。
|
||||
|
||||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型
|
||||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。
|
||||
"""
|
||||
super(CWS, self).__init__()
|
||||
if model_path is None:
|
||||
model_path = model_urls['cws']
|
||||
@ -146,7 +152,13 @@ class CWS(API):
|
||||
self.load(model_path, device)
|
||||
|
||||
def predict(self, content):
|
||||
"""
|
||||
分词接口。
|
||||
|
||||
:param content: str或List[str], 例如: "中文分词很重要!", 返回的结果是"中文 分词 很 重要 !"。 如果传入的为List[str],比如
|
||||
[ "中文分词很重要!", ...], 返回的结果["中文 分词 很 重要 !", ...]。
|
||||
:return: str或List[str], 根据输入的的类型决定。
|
||||
"""
|
||||
if not hasattr(self, 'pipeline'):
|
||||
raise ValueError("You have to load model first.")
|
||||
|
||||
@ -162,7 +174,10 @@ class CWS(API):
|
||||
dataset.add_field('raw_sentence', sentence_list)
|
||||
|
||||
# 3. 使用pipeline
|
||||
self.pipeline(dataset)
|
||||
pipeline = self.pipeline.pipeline[:-3] + self.pipeline.pipeline[-2:]
|
||||
pp = Pipeline(pipeline)
|
||||
pp(dataset)
|
||||
# self.pipeline(dataset)
|
||||
|
||||
output = dataset['output'].content
|
||||
if isinstance(content, str):
|
||||
@ -171,10 +186,28 @@ class CWS(API):
|
||||
return output
|
||||
|
||||
def test(self, filepath):
|
||||
"""
|
||||
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。
|
||||
分词文件应该为:
|
||||
1 编者按 编者按 NN O 11 nmod:topic
|
||||
2 : : PU O 11 punct
|
||||
3 7月 7月 NT DATE 4 compound:nn
|
||||
4 12日 12日 NT DATE 11 nmod:tmod
|
||||
5 , , PU O 11 punct
|
||||
|
||||
tag_proc = self._dict['tag_indexer']
|
||||
1 这 这 DT O 3 det
|
||||
2 款 款 M O 1 mark:clf
|
||||
3 飞行 飞行 NN O 8 nsubj
|
||||
4 从 从 P O 5 case
|
||||
5 外型 外型 NN O 8 nmod:prep
|
||||
以空行分割两个句子,有内容的每行有7列。
|
||||
|
||||
:param filepath: str, 文件路径路径。
|
||||
:return: float, float, float. 分别f1, precision, recall.
|
||||
"""
|
||||
tag_proc = self._dict['tag_proc']
|
||||
cws_model = self.pipeline.pipeline[-2].model
|
||||
pipeline = self.pipeline.pipeline[:5]
|
||||
pipeline = self.pipeline.pipeline[:-2]
|
||||
|
||||
pipeline.insert(1, tag_proc)
|
||||
pp = Pipeline(pipeline)
|
||||
@ -185,12 +218,16 @@ class CWS(API):
|
||||
te_dataset = reader.load(filepath)
|
||||
pp(te_dataset)
|
||||
|
||||
batch_size = 64
|
||||
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
|
||||
pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes')
|
||||
f1 = round(f1 * 100, 2)
|
||||
pre = round(pre * 100, 2)
|
||||
rec = round(rec * 100, 2)
|
||||
from fastNLP.core.tester import Tester
|
||||
from fastNLP.core.metrics import BMESF1PreRecMetric
|
||||
|
||||
tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64,
|
||||
verbose=0)
|
||||
eval_res = tester.test()
|
||||
|
||||
f1 = eval_res['BMESF1PreRecMetric']['f']
|
||||
pre = eval_res['BMESF1PreRecMetric']['pre']
|
||||
rec = eval_res['BMESF1PreRecMetric']['rec']
|
||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
|
||||
|
||||
return f1, pre, rec
|
||||
@ -301,25 +338,25 @@ class Analyzer:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl'
|
||||
pos = POS(pos_model_path, device='cpu')
|
||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
'那么这款无人机到底有多厉害?']
|
||||
print(pos.test("/home/zyfeng/data/sample.conllx"))
|
||||
# pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl'
|
||||
# pos = POS(pos_model_path, device='cpu')
|
||||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
|
||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
# '那么这款无人机到底有多厉害?']
|
||||
# print(pos.test("/home/zyfeng/data/sample.conllx"))
|
||||
# print(pos.predict(s))
|
||||
|
||||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
|
||||
# cws = CWS(device='cpu')
|
||||
# s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' ,
|
||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
# '那么这款无人机到底有多厉害?']
|
||||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll'))
|
||||
# print(cws.predict(s))
|
||||
cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
|
||||
cws = CWS(model_path=cws_model_path, device='cuda:0')
|
||||
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' ,
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
'那么这款无人机到底有多厉害?']
|
||||
# print(cws.test('/home/hyan/ctb3/test.conllx'))
|
||||
print(cws.predict(s))
|
||||
|
||||
# parser = Parser(device='cpu')
|
||||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll'))
|
||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
'那么这款无人机到底有多厉害?']
|
||||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
|
||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
# '那么这款无人机到底有多厉害?']
|
||||
# print(parser.predict(s))
|
||||
|
@ -8,38 +8,76 @@ class Callback(object):
|
||||
def __init__(self):
|
||||
super(Callback, self).__init__()
|
||||
|
||||
def before_train(self, *args):
|
||||
def before_train(self):
|
||||
# before the main training loop
|
||||
pass
|
||||
|
||||
def before_epoch(self, *args):
|
||||
def before_epoch(self, cur_epoch, total_epoch):
|
||||
# at the beginning of each epoch
|
||||
pass
|
||||
|
||||
def before_batch(self, *args):
|
||||
def before_batch(self, batch_x, batch_y, indices):
|
||||
# at the beginning of each step/mini-batch
|
||||
pass
|
||||
|
||||
def before_loss(self, *args):
|
||||
def before_loss(self, batch_y, predict_y):
|
||||
# after data_forward, and before loss computation
|
||||
pass
|
||||
|
||||
def before_backward(self, *args):
|
||||
def before_backward(self, loss, model):
|
||||
# after loss computation, and before gradient backward
|
||||
pass
|
||||
|
||||
def after_backward(self, model):
|
||||
pass
|
||||
|
||||
def after_step(self, optimizer):
|
||||
pass
|
||||
|
||||
def after_batch(self, *args):
|
||||
# at the end of each step/mini-batch
|
||||
pass
|
||||
|
||||
def after_epoch(self, *args):
|
||||
# at the end of each epoch
|
||||
def after_valid(self, eval_result, metric_key, optimizer):
|
||||
"""
|
||||
每次执行验证机的evaluation后会调用。传入eval_result
|
||||
|
||||
:param eval_result: Dict[str: Dict[str: float]], evaluation的结果
|
||||
:param metric_key: str
|
||||
:param optimizer:
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train(self, *args):
|
||||
# after training loop
|
||||
def after_epoch(self, cur_epoch, n_epoch, optimizer):
|
||||
"""
|
||||
每个epoch结束将会调用该方法
|
||||
|
||||
:param cur_epoch: int, 当前的batch。从1开始。
|
||||
:param n_epoch: int, 总的batch数
|
||||
:param optimizer: 传入Trainer的optimizer。
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train(self, model):
|
||||
"""
|
||||
训练结束,调用该方法
|
||||
|
||||
:param model: nn.Module, 传入Trainer的模型
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_exception(self, exception, model, indices):
|
||||
"""
|
||||
当训练过程出现异常,会触发该方法
|
||||
:param exception: 某种类型的Exception,比如KeyboardInterrupt等
|
||||
:param model: 传入Trainer的模型
|
||||
:param indices: 当前batch的index
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def transfer(func):
|
||||
"""装饰器,将对CallbackManager的调用转发到各个Callback子类.
|
||||
@ -111,7 +149,7 @@ class CallbackManager(Callback):
|
||||
pass
|
||||
|
||||
@transfer
|
||||
def after_step(self):
|
||||
def after_step(self, optimizer):
|
||||
pass
|
||||
|
||||
@transfer
|
||||
@ -169,6 +207,35 @@ class EchoCallback(Callback):
|
||||
def after_train(self):
|
||||
print("after_train")
|
||||
|
||||
class GradientClipCallback(Callback):
|
||||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
|
||||
"""
|
||||
每次backward前,将parameter的gradient clip到某个范围。
|
||||
|
||||
:param parameters: None, torch.Tensor或List[torch.Tensor], 一般通过model.parameters()获得。如果为None则默认对Trainer
|
||||
的model中所有参数进行clip
|
||||
:param clip_value: float, 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
|
||||
:param clip_type: str, 支持'norm', 'value'两种。
|
||||
(1) 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
|
||||
(2) 'value', 将gradient限制在[-clip_value, clip_value], 小于-clip_value的gradient被赋值为-clip_value; 大于
|
||||
clip_value的gradient被赋值为clip_value.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
from torch import nn
|
||||
if clip_type == 'norm':
|
||||
self.clip_fun = nn.utils.clip_grad_norm_
|
||||
elif clip_type == 'value':
|
||||
self.clip_fun = nn.utils.clip_grad_value_
|
||||
else:
|
||||
raise ValueError("Only supports `norm` or `value` right now.")
|
||||
self.parameters = parameters
|
||||
self.clip_value = clip_value
|
||||
|
||||
def after_backward(self, model):
|
||||
self.clip_fun(model.parameters(), self.clip_value)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
|
||||
|
@ -7,7 +7,11 @@ import numpy as np
|
||||
import torch
|
||||
from tensorboardX import SummaryWriter
|
||||
from torch import nn
|
||||
from tqdm.autonotebook import tqdm
|
||||
|
||||
try:
|
||||
from tqdm.autonotebook import tqdm
|
||||
except:
|
||||
from fastNLP.core.utils import pseudo_tqdm as tqdm
|
||||
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.callback import CallbackManager
|
||||
@ -108,7 +112,7 @@ class Trainer(object):
|
||||
self.use_cuda = bool(use_cuda)
|
||||
self.save_path = save_path
|
||||
self.print_every = int(print_every)
|
||||
self.validate_every = int(validate_every)
|
||||
self.validate_every = int(validate_every) if validate_every!=0 else -1
|
||||
self.best_metric_indicator = None
|
||||
self.sampler = sampler
|
||||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
|
||||
@ -119,11 +123,7 @@ class Trainer(object):
|
||||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
|
||||
|
||||
self.use_tqdm = use_tqdm
|
||||
if self.use_tqdm:
|
||||
tester_verbose = 0
|
||||
self.print_every = abs(self.print_every)
|
||||
else:
|
||||
tester_verbose = 1
|
||||
self.print_every = abs(self.print_every)
|
||||
|
||||
if self.dev_data is not None:
|
||||
self.tester = Tester(model=self.model,
|
||||
@ -131,7 +131,7 @@ class Trainer(object):
|
||||
metrics=self.metrics,
|
||||
batch_size=self.batch_size,
|
||||
use_cuda=self.use_cuda,
|
||||
verbose=tester_verbose)
|
||||
verbose=0)
|
||||
|
||||
self.step = 0
|
||||
self.start_time = None # start timestamp
|
||||
@ -199,10 +199,7 @@ class Trainer(object):
|
||||
self._summary_writer = SummaryWriter(path)
|
||||
|
||||
self.callback_manager.before_train()
|
||||
if self.use_tqdm:
|
||||
self._tqdm_train()
|
||||
else:
|
||||
self._print_train()
|
||||
self._train()
|
||||
self.callback_manager.after_train(self.model)
|
||||
|
||||
if self.dev_data is not None:
|
||||
@ -225,12 +222,16 @@ class Trainer(object):
|
||||
|
||||
return results
|
||||
|
||||
def _tqdm_train(self):
|
||||
def _train(self):
|
||||
if not self.use_tqdm:
|
||||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
|
||||
else:
|
||||
inner_tqdm = tqdm
|
||||
self.step = 0
|
||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
|
||||
as_numpy=False)
|
||||
start = time.time()
|
||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
|
||||
total_steps = data_iterator.num_batches * self.n_epochs
|
||||
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
|
||||
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
|
||||
avg_loss = 0
|
||||
for epoch in range(1, self.n_epochs+1):
|
||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
|
||||
@ -265,18 +266,26 @@ class Trainer(object):
|
||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
|
||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
|
||||
if (self.step+1) % self.print_every == 0:
|
||||
pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
|
||||
if self.use_tqdm:
|
||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
|
||||
pbar.update(self.print_every)
|
||||
else:
|
||||
end = time.time()
|
||||
diff = timedelta(seconds=round(end - start))
|
||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
|
||||
epoch, self.step, avg_loss, diff)
|
||||
pbar.set_postfix_str(print_output)
|
||||
avg_loss = 0
|
||||
pbar.update(self.print_every)
|
||||
self.step += 1
|
||||
# do nothing
|
||||
self.callback_manager.after_batch()
|
||||
|
||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
|
||||
(self.validate_every < 0 and self.step % self.batch_size == len(data_iterator))) \
|
||||
(self.validate_every < 0 and self.step % len(data_iterator)) == 0) \
|
||||
and self.dev_data is not None:
|
||||
eval_res = self._do_validation(epoch=epoch, step=self.step)
|
||||
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
|
||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
|
||||
total_steps) + \
|
||||
self.tester._format_eval_results(eval_res)
|
||||
pbar.write(eval_str)
|
||||
|
||||
@ -292,54 +301,6 @@ class Trainer(object):
|
||||
self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
|
||||
pbar.close()
|
||||
|
||||
def _print_train(self):
|
||||
epoch = 1
|
||||
start = time.time()
|
||||
while epoch <= self.n_epochs:
|
||||
self.callback_manager.before_epoch()
|
||||
|
||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
|
||||
as_numpy=False)
|
||||
|
||||
for batch_x, batch_y in data_iterator:
|
||||
self.callback_manager.before_batch()
|
||||
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题
|
||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
|
||||
prediction = self._data_forward(self.model, batch_x)
|
||||
|
||||
self.callback_manager.before_loss()
|
||||
loss = self._compute_loss(prediction, batch_y)
|
||||
|
||||
self.callback_manager.before_backward()
|
||||
self._grad_backward(loss)
|
||||
self._update()
|
||||
|
||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
|
||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
|
||||
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
|
||||
if self.print_every > 0 and self.step % self.print_every == 0:
|
||||
end = time.time()
|
||||
diff = timedelta(seconds=round(end - start))
|
||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
|
||||
epoch, self.step, loss.data, diff)
|
||||
print(print_output)
|
||||
|
||||
if (self.validate_every > 0 and self.step % self.validate_every == 0 and
|
||||
self.dev_data is not None):
|
||||
self._do_validation(epoch=epoch, step=self.step)
|
||||
|
||||
self.step += 1
|
||||
self.callback_manager.after_batch()
|
||||
|
||||
# validate_every override validation at end of epochs
|
||||
if self.dev_data and self.validate_every <= 0:
|
||||
self._do_validation(epoch=epoch, step=self.step)
|
||||
epoch += 1
|
||||
self.callback_manager.after_epoch()
|
||||
|
||||
def _do_validation(self, epoch, step):
|
||||
res = self.tester.test()
|
||||
for name, metric in res.items():
|
||||
|
@ -430,3 +430,30 @@ def seq_mask(seq_len, max_len):
|
||||
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1]
|
||||
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len]
|
||||
return torch.gt(seq_len, seq_range) # [batch_size, max_len]
|
||||
|
||||
|
||||
class pseudo_tqdm:
|
||||
"""
|
||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def write(self, info):
|
||||
print(info)
|
||||
|
||||
def set_postfix_str(self, info):
|
||||
print(info)
|
||||
|
||||
def __getattr__(self, item):
|
||||
def pass_func(*args, **kwargs):
|
||||
pass
|
||||
|
||||
return pass_func
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
del self
|
||||
|
@ -65,7 +65,7 @@ class CWSBiLSTMEncoder(BaseModel):
|
||||
|
||||
x_tensor = self.char_embedding(chars)
|
||||
|
||||
if not bigrams is None:
|
||||
if hasattr(self, 'bigram_embedding'):
|
||||
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
|
||||
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
|
||||
x_tensor = self.embedding_drop(x_tensor)
|
||||
@ -185,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel):
|
||||
feats = self.decoder_model(feats)
|
||||
probs = self.crf.viterbi_decode(feats, masks, get_score=False)
|
||||
|
||||
return {'pred': probs}
|
||||
return {'pred': probs, 'seq_lens':seq_lens}
|
||||
|
||||
|
@ -378,7 +378,7 @@ class BMES2OutputProcessor(Processor):
|
||||
prediction为BSEMS,会被认为是SSSSS.
|
||||
|
||||
"""
|
||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output',
|
||||
def __init__(self, chars_field_name='chars_list', tag_field_name='pred', new_added_field_name='output',
|
||||
b_idx = 0, m_idx = 1, e_idx = 2, s_idx = 3):
|
||||
"""
|
||||
|
||||
|
@ -11,7 +11,6 @@ from reproduction.chinese_word_segment.process.cws_processor import InputTargetP
|
||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
|
||||
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF
|
||||
|
||||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
|
||||
|
||||
ds_name = 'msr'
|
||||
|
||||
@ -39,8 +38,6 @@ bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='b
|
||||
|
||||
seq_len_proc = SeqLenProcessor('chars')
|
||||
|
||||
input_target_proc = InputTargetProcessor(input_fields=['chars', 'bigrams', 'seq_lens', "target"],
|
||||
target_fields=['target', 'seq_lens'])
|
||||
# 2. 使用processor
|
||||
fs2hs_proc(tr_dataset)
|
||||
|
||||
@ -63,15 +60,15 @@ char_vocab_proc(dev_dataset)
|
||||
bigram_vocab_proc(dev_dataset)
|
||||
seq_len_proc(dev_dataset)
|
||||
|
||||
input_target_proc(tr_dataset)
|
||||
input_target_proc(dev_dataset)
|
||||
dev_dataset.set_input('target')
|
||||
tr_dataset.set_input('target')
|
||||
|
||||
|
||||
print("Finish preparing data.")
|
||||
|
||||
# 3. 得到数据集可以用于训练了
|
||||
# TODO pretrain的embedding是怎么解决的?
|
||||
|
||||
import torch
|
||||
from torch import optim
|
||||
|
||||
|
||||
@ -79,8 +76,8 @@ tag_size = tag_proc.tag_size
|
||||
|
||||
cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
|
||||
bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
|
||||
bigram_embed_dim=100, num_bigram_per_char=8,
|
||||
hidden_size=200, bidirectional=True, embed_drop_p=0.2,
|
||||
bigram_embed_dim=30, num_bigram_per_char=8,
|
||||
hidden_size=200, bidirectional=True, embed_drop_p=0.3,
|
||||
num_layers=1, tag_size=tag_size)
|
||||
cws_model.cuda()
|
||||
|
||||
@ -108,7 +105,7 @@ pp.add_processor(bigram_proc)
|
||||
pp.add_processor(char_vocab_proc)
|
||||
pp.add_processor(bigram_vocab_proc)
|
||||
pp.add_processor(seq_len_proc)
|
||||
pp.add_processor(input_target_proc)
|
||||
# pp.add_processor(input_target_proc)
|
||||
|
||||
# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
|
||||
te_filename = '/home/hyan/ctb3/test.conllx'
|
||||
@ -142,7 +139,7 @@ from fastNLP.api.processor import ModelProcessor
|
||||
from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor
|
||||
|
||||
model_proc = ModelProcessor(cws_model)
|
||||
output_proc = BMES2OutputProcessor()
|
||||
output_proc = BMES2OutputProcessor(tag_field_name='pred')
|
||||
|
||||
pp = Pipeline()
|
||||
pp.add_processor(fs2hs_proc)
|
||||
@ -158,9 +155,11 @@ pp.add_processor(output_proc)
|
||||
|
||||
|
||||
# TODO 这里貌似需要区分test pipeline与infer pipeline
|
||||
|
||||
infer_context_dict = {'pipeline': pp}
|
||||
# torch.save(infer_context_dict, 'models/cws_crf.pkl')
|
||||
import torch
|
||||
import datetime
|
||||
now = datetime.datetime.now()
|
||||
infer_context_dict = {'pipeline': pp, 'tag_proc': tag_proc}
|
||||
torch.save(infer_context_dict, 'models/cws_crf_{}_{}.pkl'.format(now.month, now.day))
|
||||
|
||||
|
||||
# TODO 还需要考虑如何替换回原文的问题?
|
||||
|
Loading…
Reference in New Issue
Block a user