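"""Train and evaluate the CNTN model on a sentence-matching task (SNLI, RTE,
QNLI or MNLI) with fastNLP.

Usage sketch (the script filename below is an assumption; adjust as needed):

    python matching_cntn.py --dataset snli --embedding glove --n-epochs 200
"""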
import argparse
import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss
from fastNLP.embeddings import StaticEmbedding
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe

from reproduction.matching.model.cntn import CNTNModel

# define hyper-parameters
argument = argparse.ArgumentParser()
argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glove')
argument.add_argument('--batch-size-per-gpu', type=int, default=256)
argument.add_argument('--n-epochs', type=int, default=200)
argument.add_argument('--lr', type=float, default=1e-5)
argument.add_argument('--save-dir', type=str, default=None)  # note: currently unused in this script
argument.add_argument('--cntn-depth', type=int, default=1)
argument.add_argument('--cntn-ns', type=int, default=200)
argument.add_argument('--cntn-k-top', type=int, default=10)
argument.add_argument('--cntn-r', type=int, default=5)
argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli')
arg = argument.parse_args()

# dataset dict: which split serves as validation and final evaluation data
# (for the GLUE tasks the test labels are not public, so a dev split is reused)
dev_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'dev',
    'mnli': 'dev_matched',
}

test_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'test',
    'mnli': 'dev_matched',
}

# set num_labels: qnli and rte are binary tasks; snli and mnli have three classes
if arg.dataset in ('qnli', 'rte'):
    num_labels = 2
else:
    num_labels = 3

# load data set
if arg.dataset == 'snli':
    data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'rte':
    data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'qnli':
    data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'mnli':
    data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file()
else:
    # unreachable given the argparse choices above; kept as a safety guard
    raise RuntimeError(f'dataset {arg.dataset} is not supported yet!')

print(data_bundle)  # print details of the loaded data_bundle

# load embedding
if arg.embedding == 'word2vec':
    embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-word2vec-300',
                                requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
                                requires_grad=True)
else:
    raise ValueError('only word2vec and glove embeddings are supported for the cntn model!')

# define model
model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=num_labels, depth=arg.cntn_depth,
                  r=arg.cntn_r)
print(model)

# define trainer; the effective batch size scales with the number of visible
# GPUs, so this script assumes at least one CUDA device is available
trainer = Trainer(train_data=data_bundle.datasets['train'], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  loss=CrossEntropyLoss(),
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_bundle.datasets[dev_dict[arg.dataset]],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=list(range(torch.cuda.device_count())),
                  check_code_level=-1)

# train model (with load_best_model=True, the weights that scored best on the
# dev set are restored after training, so the tester below evaluates them)
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_bundle.datasets[test_dict[arg.dataset]],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
    device=list(range(torch.cuda.device_count()))
)

# test model
tester.test()