# Training / evaluation script for MwanModel on sentence-matching (NLI) tasks:
# SNLI / RTE / QNLI / MNLI / Quora. Part of the fastNLP reproduction suite.
import argparse
import random

import numpy as np
import torch
from torch.optim import Adadelta
from torch.optim.lr_scheduler import StepLR

import fitlog

from fastNLP import CrossEntropyLoss
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
from fastNLP.core.callback import LRScheduler, EvaluateCallback
from fastNLP.embeddings import StaticEmbedding
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe

from reproduction.matching.model.mwan import MwanModel

# Keep fitlog in debug mode so no experiment records are written while
# reproducing results.
fitlog.debug()
# Command-line configuration.
# Fix: 'quora' was missing from --task choices even though the dataset-loading
# code below explicitly supports QuoraPipe, making that branch unreachable.
argument = argparse.ArgumentParser()
argument.add_argument('--task', choices=['snli', 'rte', 'qnli', 'mnli', 'quora'], default='snli',
                      help='which matching dataset to train on')
argument.add_argument('--batch-size', type=int, default=128)
argument.add_argument('--n-epochs', type=int, default=50)
argument.add_argument('--lr', type=float, default=1,
                      help='initial learning rate for Adadelta')
argument.add_argument('--testset-name', type=str, default='test')
argument.add_argument('--devset-name', type=str, default='dev')
argument.add_argument('--seed', type=int, default=42)
argument.add_argument('--hidden-size', type=int, default=150)
argument.add_argument('--dropout', type=float, default=0.3)
arg = argument.parse_args()
# Seed every RNG source (Python, NumPy, Torch CPU, and all CUDA devices)
# so runs are reproducible.
seed_value = arg.seed
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed_value)
print(n_gpu)

# Echo the parsed configuration so it appears in the training log.
for name, value in vars(arg).items():
    print(name, value, type(value))
# Load the requested dataset; every task shares the same preprocessing
# (lower-cased text, spaCy tokenization).
_PIPES = {
    'snli': SNLIPipe,
    'rte': RTEPipe,
    'qnli': QNLIPipe,
    'mnli': MNLIPipe,
    'quora': QuoraPipe,
}
if arg.task not in _PIPES:
    raise RuntimeError(f'NOT support {arg.task} task yet!')
data_bundle = _PIPES[arg.task](lower=True, tokenizer='spacy').process_from_file()

print(data_bundle)
print(len(data_bundle.vocabs[Const.INPUTS(0)]))
# Build the multiway attention network. Word vectors come from a frozen,
# un-normalized static embedding over the first input field's vocabulary.
embedding_layer = StaticEmbedding(
    data_bundle.vocabs[Const.INPUTS(0)],
    requires_grad=False,
    normalize=False,
)
model = MwanModel(
    num_class=len(data_bundle.vocabs[Const.TARGET]),
    EmbLayer=embedding_layer,
    ElmoLayer=None,
    args_of_imm={
        "input_size": 300,
        "hidden_size": arg.hidden_size,
        "dropout": arg.dropout,
        "use_allennlp": False,
    },
)
# Adadelta with the learning rate halved every 10 epochs.
optimizer = Adadelta(params=model.parameters(), lr=arg.lr)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

callbacks = [LRScheduler(scheduler)]
# Extra evaluation during training: SNLI also scores the test split each
# epoch; MNLI scores both the matched and mismatched dev splits.
if arg.task == 'snli':
    callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.testset_name]))
elif arg.task == 'mnli':
    mnli_dev_sets = {
        'dev_matched': data_bundle.datasets['dev_matched'],
        'dev_mismatched': data_bundle.datasets['dev_mismatched'],
    }
    callbacks.append(EvaluateCallback(data=mnli_dev_sets))
# Train with dev-set model selection on accuracy; the best checkpoint is
# reloaded when training finishes (load_best_model=True).
trainer = Trainer(
    train_data=data_bundle.datasets['train'],
    dev_data=data_bundle.datasets[arg.devset_name],
    model=model,
    optimizer=optimizer,
    loss=CrossEntropyLoss(pred="pred", target="target"),
    metrics=AccuracyMetric(pred="pred", target="target"),
    metric_key='acc',
    batch_size=arg.batch_size,
    n_epochs=arg.n_epochs,
    num_workers=0,
    print_every=-1,                 # only log at epoch boundaries
    device=list(range(torch.cuda.device_count())),
    check_code_level=-1,            # skip fastNLP's pre-training code check
    callbacks=callbacks,
)
trainer.train(load_best_model=True)
# Final evaluation of the (best) trained model on the held-out test split.
tester = Tester(
    data=data_bundle.datasets[arg.testset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=arg.batch_size,
    device=list(range(torch.cuda.device_count())),
)
tester.test()