"""Train and evaluate an ESIM model on sentence-matching tasks (SNLI/RTE/QNLI/MNLI/Quora) with fastNLP."""
import random

import numpy as np
import torch
from torch.optim import Adamax
from torch.optim.lr_scheduler import StepLR

from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
from fastNLP.core.callback import GradientClipCallback, LRScheduler, EvaluateCallback
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.embeddings import StaticEmbedding
from fastNLP.embeddings import ElmoEmbedding
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
from fastNLP.models.snli import ESIM
# define hyper-parameters
class ESIMConfig:
    """Hyper-parameter container for the ESIM matching experiment.

    Attributes are read once at module level (via ``arg = ESIMConfig()``);
    edit them here to change the experiment.
    """
    task = 'snli'               # one of: snli / rte / qnli / mnli / quora
    embedding = 'glove'         # one of: glove / elmo
    batch_size_per_gpu = 196
    n_epochs = 30
    lr = 2e-3
    seed = 42
    save_path = None            # where to store the model; None means do not save
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'
    to_lower = True             # lowercase all tokens
    tokenizer = 'spacy'         # tokenize with spaCy
arg = ESIMConfig()

# set random seed in every library we use, for reproducibility
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)
n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    # also seed every CUDA device when GPUs are present
    torch.cuda.manual_seed_all(arg.seed)
# load data set: dispatch the task name to its fastNLP pipe.
# All pipes share the same constructor signature, so a lookup table
# replaces the original if/elif chain.
_TASK_PIPES = {
    'snli': SNLIPipe,
    'rte': RTEPipe,
    'qnli': QNLIPipe,
    'mnli': MNLIPipe,
    'quora': QuoraPipe,
}
if arg.task not in _TASK_PIPES:
    raise RuntimeError(f'NOT support {arg.task} task yet!')
data_bundle = _TASK_PIPES[arg.task](lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()

print(data_bundle)  # print details in data_bundle
# load embedding for the premise/hypothesis vocabulary (both inputs share vocab INPUTS(0))
if arg.embedding == 'elmo':
    embedding = ElmoEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-medium',
                              requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
                                requires_grad=True, normalize=False)
else:
    raise RuntimeError(f'NOT support {arg.embedding} embedding yet!')
# define model: output size equals the number of target labels
model = ESIM(embedding, num_labels=len(data_bundle.vocabs[Const.TARGET]))

# define optimizer and callbacks
optimizer = Adamax(lr=arg.lr, params=model.parameters())
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # halve the learning rate every 10 epochs
callbacks = [
    GradientClipCallback(clip_value=10),  # equivalent to torch.nn.utils.clip_grad_norm_(10)
    LRScheduler(scheduler),
]

if arg.task in ['snli']:
    # evaluate test set in every epoch if task is snli
    callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
# define trainer.
# Bug fix: the original used torch.cuda.device_count() * batch_size_per_gpu,
# which is 0 on a CPU-only machine; treat "no GPU" as one worker and fall
# back to CPU. (fastNLP's Trainer accepts 'cpu' as a device spec — confirm
# against the installed fastNLP version.)
trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
                  optimizer=optimizer,
                  loss=CrossEntropyLoss(),
                  batch_size=max(1, n_gpu) * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_bundle.datasets[arg.dev_dataset_name],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=list(range(n_gpu)) if n_gpu > 0 else 'cpu',
                  check_code_level=-1,
                  save_path=arg.save_path,
                  callbacks=callbacks)
# train model; load_best_model restores the checkpoint with the best dev metric
trainer.train(load_best_model=True)

# define tester for the final evaluation on the test set.
# Same zero-GPU fix as above: batch_size must not be 0 on CPU-only machines.
tester = Tester(
    data=data_bundle.datasets[arg.test_dataset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=max(1, n_gpu) * arg.batch_size_per_gpu,
    device=list(range(n_gpu)) if n_gpu > 0 else 'cpu',
)
# test model
tester.test()