# First, add the following paths to the environment variables. The download
# service is currently only open for internal testing, so the paths have to be
# declared manually.
import os

os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'

import sys
sys.path.append('../..')

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

from fastNLP.core.const import Const as C
from fastNLP.core.trainer import Trainer
from fastNLP.core import LRScheduler
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.io.data_loader import YelpLoader
# from data.sstLoader import sst2Loader
from model.char_cnn import CharacterLevelCNN


## hyperparameters
# TODO: log these through fastNLP's own record keeping
class Config():
    # seed = 7777
    model_dir_or_name = "en-base-uncased"
    embedding_grad = False
    bert_embedding_layers = '4,-2,-1'
    train_epoch = 50
    num_classes = 2
    task = "yelp_p"
    # yelp_p
    datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
                "test": "/remote-home/ygwang/yelp_polarity/test.csv"}
    # IMDB
    # datapath = {"train": "/remote-home/ygwang/IMDB_data/train.csv",
    #             "test": "/remote-home/ygwang/IMDB_data/test.csv"}
    # sst
    # datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
    #             "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}

    lr = 0.01
    batch_size = 128
    model_size = "large"
    number_of_characters = 69
    extra_characters = ''
    max_length = 1014
    weight_decay = 1e-5

    char_cnn_config = {
        "alphabet": {
            "en": {
                "lower": {
                    "alphabet": "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
                    "number_of_characters": 69
                },
                "both": {
                    "alphabet": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
                    "number_of_characters": 95
                }
            }
        },
        "model_parameters": {
            "small": {
                "conv": [
                    # each entry is [channels, kernel_size, max_pooling_size]
                    [256, 7, 3],
                    [256, 7, 3],
                    [256, 3, -1],
                    [256, 3, -1],
                    [256, 3, -1],
                    [256, 3, 3]
                ],
                "fc": [1024, 1024]
            },
            "large": {
                "conv": [
                    [1024, 7, 3],
                    [1024, 7, 3],
                    [1024, 3, -1],
                    [1024, 3, -1],
                    [1024, 3, -1],
                    [1024, 3, 3]
                ],
                "fc": [2048, 2048]
            }
        },
        "data": {
            "text_column": "SentimentText",
            "label_column": "Sentiment",
            "max_length": 1014,
            "num_of_classes": 2,
            "encoding": None,
            "chunksize": 50000,
            "max_rows": 100000,
            "preprocessing_steps": ["lower", "remove_hashtags", "remove_urls", "remove_user_mentions"]
        },
        "training": {
            "batch_size": 128,
            "learning_rate": 0.01,
            "epochs": 10,
            "optimizer": "sgd"
        }
    }


ops = Config()
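# Illustrative sketch (not part of the original script): model_size selects the
# conv/fc stack, presumably consumed inside model/char_cnn.py.
conv_spec = ops.char_cnn_config["model_parameters"][ops.model_size]["conv"]
assert len(conv_spec) == 6 and conv_spec[0] == [1024, 7, 3]  # "large"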

# set_rng_seeds(ops.seed)
# print('RNG SEED: {}'.format(ops.seed))


## 1. Task information: load the dataInfo with the dataloader
# dataloader = SST2Loader()
# dataloader = IMDBLoader()
dataloader = YelpLoader(fine_grained=True)
datainfo = dataloader.process(ops.datapath, char_level_op=True, split_dev_op=False)
char_vocab = ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
ops.number_of_characters = len(char_vocab)
ops.embedding_dim = ops.number_of_characters

# Map a character sequence to a fixed-length list of alphabet indices.
def chartoindex(chars):
    max_seq_len = ops.max_length
    zero_index = len(char_vocab)
    char_index_list = []
    for char in chars:
        if char in char_vocab:
            char_index_list.append(char_vocab.index(char))
        else:
            # both <unk> and <pad> use the last index as their embedding
            char_index_list.append(zero_index)
    if len(char_index_list) > max_seq_len:
        char_index_list = char_index_list[:max_seq_len]
    elif 0 < len(char_index_list) < max_seq_len:
        char_index_list = char_index_list + [zero_index] * (max_seq_len - len(char_index_list))
    elif len(char_index_list) == 0:
        char_index_list = [zero_index] * max_seq_len
    return char_index_list
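
# Minimal sanity check (illustrative, not part of the original script): with the
# 69-character "lower" alphabet, 'a' -> 0, 'b' -> 1, '?' -> 41, and 69 doubles
# as the shared <pad>/<unk> index, so a short string is padded out to max_length.
assert chartoindex("ab?") == [0, 1, 41] + [69] * (ops.max_length - 3)
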
for dataset in datainfo.datasets.values():
    dataset.apply_field(chartoindex, field_name='chars', new_field_name='chars')

datainfo.datasets['train'].set_input('chars')
datainfo.datasets['test'].set_input('chars')
datainfo.datasets['train'].set_target('target')
datainfo.datasets['test'].set_target('target')
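# In fastNLP, fields marked with set_input are fed to model.forward, while
# fields marked with set_target are consumed by the loss and metrics.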

## 2. Define/assemble the model. Anything goes here: for a model that fastNLP
## already packages, such as CNNText, simply instantiate it. The class below is
## only a pseudo-framework placeholder showing how to build a model that follows
## fastNLP's input/output conventions.
class ModelFactory(nn.Module):
    """
    Assembles the embedding, encoder and decoder, and defines the forward pass.

    :param embedding: embedding model
    :param encoder: encoder model
    :param decoder: decoder model
    """
    def __init__(self, embedding, encoder, decoder, **kwargs):
        super(ModelFactory, self).__init__()
        self.embedding = embedding
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        return {C.OUTPUT: None}

## 2 (alternative). Or directly reuse one of fastNLP's models
# vocab = datainfo.vocabs['words']
vocab_label = datainfo.vocabs['target']
'''
# emded_char = CNNCharEmbedding(vocab)
# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
# embedding = StackEmbedding([emded_char, embed_word])
# cnn_char_embed = CNNCharEmbedding(vocab)
# lstm_char_embed = LSTMCharEmbedding(vocab)
# embedding = StackEmbedding([cnn_char_embed, lstm_char_embed])
'''
# One-hot embedding: row i is the one-hot vector for character i, and the extra
# last row (the shared <pad>/<unk> index) stays all-zero.
embedding_weight = torch.zeros(len(char_vocab) + 1, len(char_vocab))

for i in range(len(char_vocab)):
    embedding_weight[i][i] = 1
embedding = nn.Embedding(num_embeddings=len(char_vocab) + 1, embedding_dim=len(char_vocab),
                         padding_idx=len(char_vocab), _weight=embedding_weight)
for para in embedding.parameters():
    para.requires_grad = False
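
# Quick check of the one-hot construction (illustrative, not part of the
# original script): index i looks up the i-th one-hot row, and the padding
# index looks up the zero vector.
assert embedding(torch.tensor([0]))[0, 0] == 1
assert embedding(torch.tensor([len(char_vocab)])).abs().sum() == 0
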
# CNNText would be too simplistic for this task
# model = CNNText(init_embed=embedding, num_classes=ops.num_classes)
model = CharacterLevelCNN(ops, embedding)


## 3. Declare the loss, metric and optimizer
# the loss and metric classes are instantiated inside train() below
loss = CrossEntropyLoss
metric = AccuracyMetric
# optimizer = SGD([param for param in model.parameters() if param.requires_grad], lr=ops.lr)
optimizer = SGD([param for param in model.parameters() if param.requires_grad],
                lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)

callbacks = []
# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
# LambdaLR multiplies the base lr by the lambda's return value, so the lambda
# returns scale factors (1.0, then 0.1) rather than absolute learning rates.
callbacks.append(
    LRScheduler(LambdaLR(optimizer, lambda epoch: 1.0 if epoch <
                         ops.train_epoch * 0.8 else 0.1))
)
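# With train_epoch = 50 this keeps the lr at 0.01 for epochs 0-39 and drops it
# to 0.001 for the final 20% of epochs.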


## 4. Define the train routine
def train(model, datainfo, loss, metrics, optimizer, num_epochs=100):
    # device=[0, 1, 2] assumes three visible GPUs
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer,
                      loss=loss(target='target'), batch_size=ops.batch_size,
                      metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'],
                      device=[0, 1, 2], check_code_level=-1,
                      n_epochs=num_epochs, callbacks=callbacks)
    print(trainer.train())


if __name__ == "__main__":
    # print(vocab_label)
    # print(datainfo.datasets["train"])
    train(model, datainfo, loss, metric, optimizer, num_epochs=ops.train_epoch)