fastNLP/reproduction/matching/matching_cntn.py

99 lines
3.4 KiB
Python

import argparse
import torch
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss
from fastNLP.embeddings import StaticEmbedding
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe
from reproduction.matching.model.cntn import CNTNModel
# Command-line hyper-parameters for the CNTN matching experiment.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--embedding', choices=['glove', 'word2vec'], default='glove')
arg_parser.add_argument('--batch-size-per-gpu', type=int, default=256)
arg_parser.add_argument('--n-epochs', type=int, default=200)
arg_parser.add_argument('--lr', type=float, default=1e-5)
arg_parser.add_argument('--save-dir', type=str, default=None)
# CNTN-specific architecture knobs (depth, feature maps, k-max pooling, tensor rank).
arg_parser.add_argument('--cntn-depth', type=int, default=1)
arg_parser.add_argument('--cntn-ns', type=int, default=200)
arg_parser.add_argument('--cntn-k-top', type=int, default=10)
arg_parser.add_argument('--cntn-r', type=int, default=5)
arg_parser.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli')
arg = arg_parser.parse_args()
# Which split each dataset uses for validation / final evaluation.
# MNLI ships matched/mismatched dev sets; QNLI and RTE have no public
# test labels, so their dev split doubles as the test set.
dev_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'dev',
    'mnli': 'dev_matched',
}
test_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'test',
    'mnli': 'dev_matched',
}
# QNLI and RTE are binary entailment tasks; SNLI/MNLI use three labels.
num_labels = 2 if arg.dataset in ('qnli', 'rte') else 3
# Load and preprocess the chosen dataset (lower-cased, raw tokenization).
if arg.dataset == 'snli':
    data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'rte':
    data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'qnli':
    data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'mnli':
    data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file()
else:
    # Fix: original interpolated `arg.task`, but the parser only defines
    # `--dataset`; hitting this branch raised AttributeError instead of the
    # intended RuntimeError. (Unreachable via argparse `choices`, but kept
    # as a guard against future edits.)
    raise RuntimeError(f'NOT support {arg.dataset} task yet!')
print(data_bundle)  # print details in data_bundle
# Resolve the pre-trained embedding name, then build one trainable
# StaticEmbedding over the shared word vocabulary.
if arg.embedding == 'word2vec':
    pretrained_name = 'en-word2vec-300'
elif arg.embedding == 'glove':
    pretrained_name = 'en-glove-840b-300d'
else:
    raise ValueError(f'now we only support word2vec or glove embedding for cntn model!')
embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)],
                            model_dir_or_name=pretrained_name,
                            requires_grad=True)
# Build the CNTN matching model on top of the chosen embedding.
cntn_kwargs = dict(ns=arg.cntn_ns,
                   k_top=arg.cntn_k_top,
                   num_labels=num_labels,
                   depth=arg.cntn_depth,
                   r=arg.cntn_r)
model = CNTNModel(embedding, **cntn_kwargs)
print(model)
# Train with one device entry per visible GPU; the effective batch size
# scales with the GPU count.
n_gpus = torch.cuda.device_count()
trainer = Trainer(
    train_data=data_bundle.datasets['train'],
    model=model,
    optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
    loss=CrossEntropyLoss(),
    batch_size=n_gpus * arg.batch_size_per_gpu,
    n_epochs=arg.n_epochs,
    print_every=-1,
    dev_data=data_bundle.datasets[dev_dict[arg.dataset]],
    metrics=AccuracyMetric(),
    metric_key='acc',
    device=list(range(n_gpus)),
    check_code_level=-1,
)
# Run training; reload the best dev-accuracy checkpoint afterwards.
trainer.train(load_best_model=True)
# Evaluate the trained model on the held-out split for this dataset.
num_devices = torch.cuda.device_count()
tester = Tester(
    data=data_bundle.datasets[test_dict[arg.dataset]],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=num_devices * arg.batch_size_per_gpu,
    device=list(range(num_devices)),
)
tester.test()