# Source: fastNLP/reproduction/matching/matching_mwan.py
# (file-viewer metadata: 116 lines, 3.9 KiB, Python)

import random
import numpy as np
import torch
from torch.optim import Adadelta
from torch.optim.lr_scheduler import StepLR
from fastNLP import CrossEntropyLoss
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
from fastNLP.core.callback import LRScheduler, EvaluateCallback
from fastNLP.embeddings import StaticEmbedding
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
from reproduction.matching.model.mwan import MwanModel
import fitlog
fitlog.debug()
import argparse
# Command-line options for the MwAN matching experiment.
argument = argparse.ArgumentParser()
# 'quora' is listed because the dataset-loading code in this script has a
# dedicated QuoraPipe branch; when the choice is missing argparse rejects the
# value before that branch can ever run.
argument.add_argument('--task', choices=['snli', 'rte', 'qnli', 'mnli', 'quora'], default='snli')
argument.add_argument('--batch-size', type=int, default=128)
argument.add_argument('--n-epochs', type=int, default=50)
# argparse does not run non-string defaults through `type`, so the default is
# written as a float to match the declared option type.
argument.add_argument('--lr', type=float, default=1.0)
# Names of the splits (keys of data_bundle.datasets) used for testing and
# for validation during training.
argument.add_argument('--testset-name', type=str, default='test')
argument.add_argument('--devset-name', type=str, default='dev')
argument.add_argument('--seed', type=int, default=42)
argument.add_argument('--hidden-size', type=int, default=150)
argument.add_argument('--dropout', type=float, default=0.3)
# Parsed options; dashes in option names become underscores (arg.batch_size, ...).
arg = argument.parse_args()
# Seed the Python, NumPy and torch random number generators with the same
# user-supplied seed.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(arg.seed)

# Seed the CUDA generators as well, but only when at least one GPU is visible.
n_gpu = torch.cuda.device_count()
if n_gpu:
    torch.cuda.manual_seed_all(arg.seed)
print(n_gpu)

# Dump every parsed option with its value and type.
for option_name, option_value in vars(arg).items():
    print(option_name, option_value, type(option_value))
# Load the dataset for the requested task: look up the matching pipe class
# and run it with lower-casing and the spaCy tokenizer.
task_pipes = {
    'snli': SNLIPipe,
    'rte': RTEPipe,
    'qnli': QNLIPipe,
    'mnli': MNLIPipe,
    'quora': QuoraPipe,
}
if arg.task not in task_pipes:
    raise RuntimeError(f'NOT support {arg.task} task yet!')
data_bundle = task_pipes[arg.task](lower=True, tokenizer='spacy').process_from_file()

# Show the loaded bundle and the size of the first input's vocabulary.
print(data_bundle)
print(len(data_bundle.vocabs[Const.INPUTS(0)]))
# Assemble the MwAN model: one output class per target label, static word
# embeddings (requires_grad=False, normalize=False) and no ELMo layer.
n_classes = len(data_bundle.vocabs[Const.TARGET])
word_embedding = StaticEmbedding(
    data_bundle.vocabs[Const.INPUTS(0)],
    requires_grad=False,
    normalize=False,
)
# NOTE(review): input_size=300 presumably has to equal the embedding
# dimension produced by StaticEmbedding - confirm against MwanModel.
imm_options = dict(
    input_size=300,
    hidden_size=arg.hidden_size,
    dropout=arg.dropout,
    use_allennlp=False,
)
model = MwanModel(
    num_class=n_classes,
    EmbLayer=word_embedding,
    ElmoLayer=None,
    args_of_imm=imm_options,
)
# Adadelta optimiser; StepLR multiplies its learning rate by 0.5 every
# 10 scheduler steps.
optimizer = Adadelta(params=model.parameters(), lr=arg.lr)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# The LR scheduler callback is always installed.
callbacks = [LRScheduler(scheduler)]

# Some tasks get an extra evaluation callback: snli on the configured test
# split, mnli on both of its dev splits. Other tasks get none.
extra_eval_data = None
if arg.task == 'snli':
    extra_eval_data = data_bundle.datasets[arg.testset_name]
elif arg.task == 'mnli':
    extra_eval_data = {
        split: data_bundle.datasets[split]
        for split in ('dev_matched', 'dev_mismatched')
    }
if extra_eval_data is not None:
    callbacks.append(EvaluateCallback(data=extra_eval_data))
# Train on the 'train' split and validate on the configured dev split,
# selecting the best model by accuracy.
#
# Use every visible GPU. When no CUDA device is present the id list would be
# empty, which is not a usable device specification, so fall back to 'cpu'.
train_gpu_ids = list(range(torch.cuda.device_count()))
train_device = train_gpu_ids if train_gpu_ids else 'cpu'

# NOTE(review): the mnli evaluation callback in this script reads the
# 'dev_matched' / 'dev_mismatched' splits, so the default --devset-name 'dev'
# may not exist for that task - confirm and pass --devset-name explicitly.
trainer = Trainer(
    train_data=data_bundle.datasets['train'],
    model=model,
    optimizer=optimizer,
    num_workers=0,
    batch_size=arg.batch_size,
    n_epochs=arg.n_epochs,
    print_every=-1,
    dev_data=data_bundle.datasets[arg.devset_name],
    metrics=AccuracyMetric(pred="pred", target="target"),
    metric_key='acc',
    device=train_device,
    check_code_level=-1,
    callbacks=callbacks,
    loss=CrossEntropyLoss(pred="pred", target="target"),
)
# Reload the best checkpoint once training finishes.
trainer.train(load_best_model=True)
# Final evaluation of the trained model on the configured test split.
#
# Use every visible GPU. When no CUDA device is present the id list would be
# empty, which is not a usable device specification, so fall back to 'cpu'.
test_gpu_ids = list(range(torch.cuda.device_count()))
test_device = test_gpu_ids if test_gpu_ids else 'cpu'

tester = Tester(
    data=data_bundle.datasets[arg.testset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=arg.batch_size,
    device=test_device,
)
tester.test()