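# Train and evaluate an ESIM sentence-matching model with fastNLP on one of the
# SNLI / RTE / QNLI / MNLI / Quora tasks (selected via ESIMConfig below).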
import random
import numpy as np
import torch
from torch.optim import Adamax
from torch.optim.lr_scheduler import StepLR
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
from fastNLP.core.callback import GradientClipCallback, LRScheduler, EvaluateCallback
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.embeddings import StaticEmbedding
from fastNLP.embeddings import ElmoEmbedding
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
from fastNLP.models.snli import ESIM


# define hyper-parameters
class ESIMConfig:

    task = 'snli'  # one of 'snli', 'rte', 'qnli', 'mnli', 'quora'
    embedding = 'glove'  # 'glove' or 'elmo'

    batch_size_per_gpu = 196
    n_epochs = 30
    lr = 2e-3
    seed = 42
    save_path = None  # where to save the model; None means the model is not saved

    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'

    to_lower = True  # lowercase the input (ignore case)
    tokenizer = 'spacy'  # use spaCy for tokenization


arg = ESIMConfig()
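# To try another task or embedding, one could simply edit the class attributes above,
# e.g. (hypothetical overrides, not part of the original script):
#     ESIMConfig.task = 'rte'
#     ESIMConfig.embedding = 'elmo'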

# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)
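# Seeding Python, NumPy and all CUDA RNGs keeps data shuffling and weight
# initialisation reproducible; bitwise-identical GPU runs may additionally require
# deterministic cuDNN settings, which this script does not configure.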

# load data set
if arg.task == 'snli':
    data_bundle = SNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'rte':
    data_bundle = RTEPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'qnli':
    data_bundle = QNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'mnli':
    data_bundle = MNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'quora':
    data_bundle = QuoraPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
else:
    raise RuntimeError(f'Task {arg.task} is not supported yet!')

print(data_bundle)  # print details of the data_bundle
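# Each *Pipe reads the raw dataset (downloading it first where fastNLP supports that),
# tokenizes both sentences, builds the word and target vocabularies, and converts
# tokens to indices. The resulting DataBundle holds the processed DataSets
# (train/dev/test splits; exact names vary slightly by task) and the vocabularies.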

# load embedding
if arg.embedding == 'elmo':
    embedding = ElmoEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-medium',
                              requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
                                requires_grad=True, normalize=False)
else:
    raise RuntimeError(f'Embedding {arg.embedding} is not supported yet!')
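# StaticEmbedding looks up fixed 300d GloVe-840B vectors (fine-tuned here since
# requires_grad=True), while ElmoEmbedding produces contextual representations from
# the 'en-medium' ELMo model; both are built over the word vocabulary stored under
# Const.INPUTS(0) in the data bundle.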

# define model
model = ESIM(embedding, num_labels=len(data_bundle.vocabs[Const.TARGET]))
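# ESIM (Enhanced Sequential Inference Model, Chen et al. 2017): BiLSTM input encoding,
# soft attention alignment between the two sentences, composition with a second BiLSTM,
# then pooling and an MLP classifier over the target label vocabulary.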

# define optimizer and callbacks
optimizer = Adamax(lr=arg.lr, params=model.parameters())
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # multiply the learning rate by 0.5 every 10 epochs

callbacks = [
    GradientClipCallback(clip_value=10),  # equivalent to torch.nn.utils.clip_grad_norm_ with max_norm=10
    LRScheduler(scheduler),
]

if arg.task in ['snli']:
    # evaluate the test set at the end of every epoch when the task is snli
    callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
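# The callbacks hook into fastNLP's training loop: GradientClipCallback clips gradients
# after each backward pass, LRScheduler steps the wrapped scheduler once per epoch, and
# EvaluateCallback reports results on the extra dataset whenever the dev set is evaluated.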

# define trainer
trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
                  optimizer=optimizer,
                  loss=CrossEntropyLoss(),
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_bundle.datasets[arg.dev_dataset_name],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1,
                  save_path=arg.save_path,
                  callbacks=callbacks)
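# Notes: batch_size and device assume at least one visible GPU (otherwise both evaluate
# to 0/[] and would need adjusting); metric_key='acc' makes dev accuracy the criterion
# for picking the best model, and check_code_level=-1 disables fastNLP's pre-run code check.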

# train model
trainer.train(load_best_model=True)
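# With load_best_model=True the trainer restores the parameters that achieved the best
# dev accuracy before returning, so the Tester below evaluates that checkpoint rather
# than the final-epoch weights.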

# define tester
tester = Tester(
    data=data_bundle.datasets[arg.test_dataset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
    device=[i for i in range(torch.cuda.device_count())],
)

# test model
tester.test()
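# tester.test() prints the evaluation result and returns it as a dict of the form
# {'AccuracyMetric': {'acc': ...}}.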