mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 04:37:37 +08:00
- add star-transformer reproduction
This commit is contained in:
parent
6a8f50e73e
commit
040bd2ab07
@ -26,13 +26,11 @@ class StarTransEnc(nn.Module):
|
||||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
|
||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
|
||||
此时就以传入的对象作为embedding
|
||||
:param num_cls: 输出类别个数
|
||||
:param hidden_size: 模型中特征维度.
|
||||
:param num_layers: 模型层数.
|
||||
:param num_head: 模型中multi-head的head个数.
|
||||
:param head_dim: 模型中multi-head中每个head特征维度.
|
||||
:param max_len: 模型能接受的最大输入长度.
|
||||
:param cls_hidden_size: 分类器隐层维度.
|
||||
:param emb_dropout: 词嵌入的dropout概率.
|
||||
:param dropout: 模型除词嵌入外的dropout概率.
|
||||
"""
|
||||
@ -59,7 +57,7 @@ class StarTransEnc(nn.Module):
|
||||
|
||||
def forward(self, x, mask):
|
||||
"""
|
||||
:param FloatTensor data: [batch, length, hidden] 输入的序列
|
||||
:param FloatTensor x: [batch, length, hidden] 输入的序列
|
||||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0,
|
||||
否则为 1
|
||||
:return: [batch, length, hidden] 编码后的输出序列
|
||||
@ -110,8 +108,9 @@ class STSeqLabel(nn.Module):
|
||||
|
||||
用于序列标注的Star-Transformer模型
|
||||
|
||||
:param vocab_size: 词嵌入的词典大小
|
||||
:param emb_dim: 每个词嵌入的特征维度
|
||||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
|
||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
|
||||
此时就以传入的对象作为embedding
|
||||
:param num_cls: 输出类别个数
|
||||
:param hidden_size: 模型中特征维度. Default: 300
|
||||
:param num_layers: 模型层数. Default: 4
|
||||
@ -174,8 +173,9 @@ class STSeqCls(nn.Module):
|
||||
|
||||
用于分类任务的Star-Transformer
|
||||
|
||||
:param vocab_size: 词嵌入的词典大小
|
||||
:param emb_dim: 每个词嵌入的特征维度
|
||||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
|
||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
|
||||
此时就以传入的对象作为embedding
|
||||
:param num_cls: 输出类别个数
|
||||
:param hidden_size: 模型中特征维度. Default: 300
|
||||
:param num_layers: 模型层数. Default: 4
|
||||
@ -238,8 +238,9 @@ class STNLICls(nn.Module):
|
||||
|
||||
用于自然语言推断(NLI)的Star-Transformer
|
||||
|
||||
:param vocab_size: 词嵌入的词典大小
|
||||
:param emb_dim: 每个词嵌入的特征维度
|
||||
:param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
|
||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
|
||||
此时就以传入的对象作为embedding
|
||||
:param num_cls: 输出类别个数
|
||||
:param hidden_size: 模型中特征维度. Default: 300
|
||||
:param num_layers: 模型层数. Default: 4
|
||||
|
@ -43,7 +43,7 @@ class StarTransformer(nn.Module):
|
||||
for _ in range(self.iters)])
|
||||
|
||||
if max_len is not None:
|
||||
self.pos_emb = self.pos_emb = nn.Embedding(max_len, hidden_size)
|
||||
self.pos_emb = nn.Embedding(max_len, hidden_size)
|
||||
else:
|
||||
self.pos_emb = None
|
||||
|
||||
|
157
reproduction/Star_transformer/datasets.py
Normal file
157
reproduction/Star_transformer/datasets.py
Normal file
@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import json
|
||||
import os
|
||||
from fastNLP import Vocabulary
|
||||
from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader
|
||||
from fastNLP.core import Const as C
|
||||
import numpy as np
|
||||
|
||||
MAX_LEN = 128
|
||||
|
||||
def update_v(vocab, data, field):
|
||||
data.apply(lambda x: vocab.add_word_lst(x[field]), new_field_name=None)
|
||||
|
||||
|
||||
def to_index(vocab, data, field, name):
|
||||
def func(x):
|
||||
try:
|
||||
return [vocab.to_index(w) for w in x[field]]
|
||||
except ValueError:
|
||||
return [vocab.padding_idx for _ in x[field]]
|
||||
data.apply(func, new_field_name=name)
|
||||
|
||||
|
||||
def load_seqtag(path, files, indexs):
|
||||
word_h, tag_h = 'words', 'tags'
|
||||
loader = ConllLoader(headers=[word_h, tag_h], indexes=indexs)
|
||||
ds_list = []
|
||||
for fn in files:
|
||||
ds_list.append(loader.load(os.path.join(path, fn)))
|
||||
word_v = Vocabulary(min_freq=2)
|
||||
tag_v = Vocabulary(unknown=None)
|
||||
update_v(word_v, ds_list[0], word_h)
|
||||
update_v(tag_v, ds_list[0], tag_h)
|
||||
|
||||
def process_data(ds):
|
||||
to_index(word_v, ds, word_h, C.INPUT)
|
||||
to_index(tag_v, ds, tag_h, C.TARGET)
|
||||
ds.apply(lambda x: x[C.INPUT][:MAX_LEN], new_field_name=C.INPUT)
|
||||
ds.apply(lambda x: x[C.TARGET][:MAX_LEN], new_field_name=C.TARGET)
|
||||
ds.apply(lambda x: len(x[word_h]), new_field_name=C.INPUT_LEN)
|
||||
ds.set_input(C.INPUT, C.INPUT_LEN)
|
||||
ds.set_target(C.TARGET, C.INPUT_LEN)
|
||||
for i in range(len(ds_list)):
|
||||
process_data(ds_list[i])
|
||||
return ds_list, word_v, tag_v
|
||||
|
||||
|
||||
def load_sst(path, files):
|
||||
loaders = [SSTLoader(subtree=sub, fine_grained=True)
|
||||
for sub in [True, False, False]]
|
||||
ds_list = [loader.load(os.path.join(path, fn))
|
||||
for fn, loader in zip(files, loaders)]
|
||||
word_v = Vocabulary(min_freq=2)
|
||||
tag_v = Vocabulary(unknown=None, padding=None)
|
||||
for ds in ds_list:
|
||||
ds.apply(lambda x: [w.lower()
|
||||
for w in x['words']], new_field_name='words')
|
||||
ds_list[0].drop(lambda x: len(x['words']) < 3)
|
||||
update_v(word_v, ds_list[0], 'words')
|
||||
ds_list[0].apply(lambda x: tag_v.add_word(
|
||||
x['target']), new_field_name=None)
|
||||
|
||||
def process_data(ds):
|
||||
to_index(word_v, ds, 'words', C.INPUT)
|
||||
ds.apply(lambda x: tag_v.to_index(x['target']), new_field_name=C.TARGET)
|
||||
ds.apply(lambda x: x[C.INPUT][:MAX_LEN], new_field_name=C.INPUT)
|
||||
ds.apply(lambda x: len(x['words']), new_field_name=C.INPUT_LEN)
|
||||
ds.set_input(C.INPUT, C.INPUT_LEN)
|
||||
ds.set_target(C.TARGET)
|
||||
for i in range(len(ds_list)):
|
||||
process_data(ds_list[i])
|
||||
return ds_list, word_v, tag_v
|
||||
|
||||
|
||||
def load_snli(path, files):
|
||||
loader = SNLILoader()
|
||||
ds_list = [loader.load(os.path.join(path, f)) for f in files]
|
||||
word_v = Vocabulary(min_freq=2)
|
||||
tag_v = Vocabulary(unknown=None, padding=None)
|
||||
for ds in ds_list:
|
||||
ds.apply(lambda x: [w.lower()
|
||||
for w in x['words1']], new_field_name='words1')
|
||||
ds.apply(lambda x: [w.lower()
|
||||
for w in x['words2']], new_field_name='words2')
|
||||
update_v(word_v, ds_list[0], 'words1')
|
||||
update_v(word_v, ds_list[0], 'words2')
|
||||
ds_list[0].apply(lambda x: tag_v.add_word(
|
||||
x['target']), new_field_name=None)
|
||||
|
||||
def process_data(ds):
|
||||
to_index(word_v, ds, 'words1', C.INPUTS(0))
|
||||
to_index(word_v, ds, 'words2', C.INPUTS(1))
|
||||
ds.apply(lambda x: tag_v.to_index(x['target']), new_field_name=C.TARGET)
|
||||
ds.apply(lambda x: x[C.INPUTS(0)][:MAX_LEN], new_field_name=C.INPUTS(0))
|
||||
ds.apply(lambda x: x[C.INPUTS(1)][:MAX_LEN], new_field_name=C.INPUTS(1))
|
||||
ds.apply(lambda x: len(x[C.INPUTS(0)]), new_field_name=C.INPUT_LENS(0))
|
||||
ds.apply(lambda x: len(x[C.INPUTS(1)]), new_field_name=C.INPUT_LENS(1))
|
||||
ds.set_input(C.INPUTS(0), C.INPUTS(1), C.INPUT_LENS(0), C.INPUT_LENS(1))
|
||||
ds.set_target(C.TARGET)
|
||||
for i in range(len(ds_list)):
|
||||
process_data(ds_list[i])
|
||||
return ds_list, word_v, tag_v
|
||||
|
||||
|
||||
class EmbedLoader:
|
||||
@staticmethod
|
||||
def parse_glove_line(line):
|
||||
line = line.split()
|
||||
if len(line) <= 2:
|
||||
raise RuntimeError(
|
||||
"something goes wrong in parsing glove embedding")
|
||||
return line[0], line[1:]
|
||||
|
||||
@staticmethod
|
||||
def str_list_2_vec(line):
|
||||
return torch.Tensor(list(map(float, line)))
|
||||
|
||||
@staticmethod
|
||||
def fast_load_embedding(emb_dim, emb_file, vocab):
|
||||
"""Fast load the pre-trained embedding and combine with the given dictionary.
|
||||
This loading method uses line-by-line operation.
|
||||
|
||||
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
|
||||
:param str emb_file: the pre-trained embedding file path.
|
||||
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
|
||||
:return embedding_matrix: numpy.ndarray
|
||||
|
||||
"""
|
||||
if vocab is None:
|
||||
raise RuntimeError("You must provide a vocabulary.")
|
||||
embedding_matrix = np.zeros(
|
||||
shape=(len(vocab), emb_dim), dtype=np.float32)
|
||||
hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
|
||||
with open(emb_file, "r", encoding="utf-8") as f:
|
||||
startline = f.readline()
|
||||
if len(startline.split()) > 2:
|
||||
f.seek(0)
|
||||
for line in f:
|
||||
word, vector = EmbedLoader.parse_glove_line(line)
|
||||
try:
|
||||
if word in vocab:
|
||||
vector = EmbedLoader.str_list_2_vec(vector)
|
||||
if emb_dim != vector.size(0):
|
||||
continue
|
||||
embedding_matrix[vocab[word]] = vector
|
||||
hit_flags[vocab[word]] = 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if np.sum(hit_flags) < len(vocab):
|
||||
# some words from vocab are missing in pre-trained embedding
|
||||
# we normally sample each dimension
|
||||
vocab_embed = embedding_matrix[np.where(hit_flags)]
|
||||
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
|
||||
size=(len(vocab) - np.sum(hit_flags), emb_dim))
|
||||
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
|
||||
return embedding_matrix
|
56
reproduction/Star_transformer/modules.py
Normal file
56
reproduction/Star_transformer/modules.py
Normal file
@ -0,0 +1,56 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from fastNLP.core.losses import LossBase
|
||||
|
||||
|
||||
reduce_func = {
|
||||
'none': lambda x, mask: x*mask,
|
||||
'sum': lambda x, mask: (x*mask).sum(),
|
||||
'mean': lambda x, mask: (x*mask).sum() / mask.sum(),
|
||||
}
|
||||
|
||||
|
||||
class LabelSmoothCrossEntropy(nn.Module):
|
||||
def __init__(self, smoothing=0.1, ignore_index=-100, reduction='mean'):
|
||||
global reduce_func
|
||||
super().__init__()
|
||||
if smoothing < 0 or smoothing > 1:
|
||||
raise ValueError('invalid smoothing value: {}'.format(smoothing))
|
||||
self.smoothing = smoothing
|
||||
self.ignore_index = ignore_index
|
||||
if reduction not in reduce_func:
|
||||
raise ValueError('invalid reduce type: {}'.format(reduction))
|
||||
self.reduce_func = reduce_func[reduction]
|
||||
|
||||
def forward(self, input, target):
|
||||
input = F.log_softmax(input, dim=1) # [N, C, ...]
|
||||
smooth_val = self.smoothing / input.size(1) # [N, C, ...]
|
||||
target_logit = input.new_full(input.size(), fill_value=smooth_val)
|
||||
target_logit.scatter_(1, target[:, None], 1 - self.smoothing)
|
||||
result = -(target_logit * input).sum(1) # [N, ...]
|
||||
mask = (target != self.ignore_index).float()
|
||||
return self.reduce_func(result, mask)
|
||||
|
||||
|
||||
class SmoothCE(LossBase):
|
||||
def __init__(self, pred=None, target=None, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_fn = LabelSmoothCrossEntropy(**kwargs)
|
||||
self._init_param_map(pred=pred, target=target)
|
||||
|
||||
def get_loss(self, pred, target):
|
||||
return self.loss_fn(pred, target)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
|
||||
sm_loss_fn = LabelSmoothCrossEntropy(smoothing=0, ignore_index=0)
|
||||
predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
|
||||
[0, 0.9, 0.2, 0.1, 0],
|
||||
[1, 0.2, 0.7, 0.1, 0]])
|
||||
target = torch.tensor([2, 1, 0])
|
||||
loss = loss_fn(predict, target)
|
||||
sm_loss = sm_loss_fn(predict, target)
|
||||
print(loss, sm_loss)
|
5
reproduction/Star_transformer/run.sh
Normal file
5
reproduction/Star_transformer/run.sh
Normal file
@ -0,0 +1,5 @@
|
||||
#python -u train.py --task pos --ds conll --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > conll_pos102.log 2>&1 &
|
||||
#python -u train.py --task pos --ds ctb --mode train --gpu 1 --lr 3e-4 --w_decay 2e-5 --lr_decay .95 --drop 0.3 --ep 25 --bsz 64 > ctb_pos101.log 2>&1 &
|
||||
#python -u train.py --task cls --ds sst --mode train --gpu 2 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.5 --ep 50 --bsz 128 > sst_cls201.log &
|
||||
#python -u train.py --task nli --ds snli --mode train --gpu 1 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 128 > snli_nli201.log &
|
||||
python -u train.py --task ner --ds conll --mode train --gpu 0 --lr 1e-4 --w_decay 1e-5 --lr_decay 0.9 --drop 0.4 --ep 120 --bsz 64 > conll_ner201.log &
|
214
reproduction/Star_transformer/train.py
Normal file
214
reproduction/Star_transformer/train.py
Normal file
@ -0,0 +1,214 @@
|
||||
from util import get_argparser, set_gpu, set_rng_seeds, add_model_args
|
||||
from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import numpy as np
|
||||
import fastNLP as FN
|
||||
from fastNLP.models.star_transformer import STSeqLabel, STSeqCls, STNLICls
|
||||
from fastNLP.core.const import Const as C
|
||||
import sys
|
||||
sys.path.append('/remote-home/yfshao/workdir/dev_fastnlp/')
|
||||
|
||||
|
||||
g_model_select = {
|
||||
'pos': STSeqLabel,
|
||||
'ner': STSeqLabel,
|
||||
'cls': STSeqCls,
|
||||
'nli': STNLICls,
|
||||
}
|
||||
|
||||
g_emb_file_path = {'en': '/remote-home/yfshao/workdir/datasets/word_vector/glove.840B.300d.txt',
|
||||
'zh': '/remote-home/yfshao/workdir/datasets/word_vector/cc.zh.300.vec'}
|
||||
|
||||
g_args = None
|
||||
g_model_cfg = None
|
||||
|
||||
|
||||
def get_ptb_pos():
|
||||
pos_dir = '/remote-home/yfshao/workdir/datasets/pos'
|
||||
pos_files = ['train.pos', 'dev.pos', 'test.pos', ]
|
||||
return load_seqtag(pos_dir, pos_files, [0, 1])
|
||||
|
||||
|
||||
def get_ctb_pos():
|
||||
ctb_dir = '/remote-home/yfshao/workdir/datasets/ctb9_hy'
|
||||
files = ['train.conllx', 'dev.conllx', 'test.conllx']
|
||||
return load_seqtag(ctb_dir, files, [1, 4])
|
||||
|
||||
|
||||
def get_conll2012_pos():
|
||||
path = '/remote-home/yfshao/workdir/datasets/ontonotes/pos'
|
||||
files = ['ontonotes-conll.train',
|
||||
'ontonotes-conll.dev',
|
||||
'ontonotes-conll.conll-2012-test']
|
||||
return load_seqtag(path, files, [0, 1])
|
||||
|
||||
|
||||
def get_conll2012_ner():
|
||||
path = '/remote-home/yfshao/workdir/datasets/ontonotes/ner'
|
||||
files = ['bieso-ontonotes-conll-ner.train',
|
||||
'bieso-ontonotes-conll-ner.dev',
|
||||
'bieso-ontonotes-conll-ner.conll-2012-test']
|
||||
return load_seqtag(path, files, [0, 1])
|
||||
|
||||
|
||||
def get_sst():
|
||||
path = '/remote-home/yfshao/workdir/datasets/SST'
|
||||
files = ['train.txt', 'dev.txt', 'test.txt']
|
||||
return load_sst(path, files)
|
||||
|
||||
|
||||
def get_snli():
|
||||
path = '/remote-home/yfshao/workdir/datasets/nli-data/snli_1.0'
|
||||
files = ['snli_1.0_train.jsonl',
|
||||
'snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']
|
||||
return load_snli(path, files)
|
||||
|
||||
|
||||
g_datasets = {
|
||||
'ptb-pos': get_ptb_pos,
|
||||
'ctb-pos': get_ctb_pos,
|
||||
'conll-pos': get_conll2012_pos,
|
||||
'conll-ner': get_conll2012_ner,
|
||||
'sst-cls': get_sst,
|
||||
'snli-nli': get_snli,
|
||||
}
|
||||
|
||||
|
||||
def load_pretrain_emb(word_v, lang='en'):
|
||||
print('loading pre-train embeddings')
|
||||
emb = EmbedLoader.fast_load_embedding(300, g_emb_file_path[lang], word_v)
|
||||
emb /= np.linalg.norm(emb, axis=1, keepdims=True)
|
||||
emb = torch.tensor(emb, dtype=torch.float32)
|
||||
print('embedding mean: {:.6}, std: {:.6}'.format(emb.mean(), emb.std()))
|
||||
emb[word_v.padding_idx].fill_(0)
|
||||
return emb
|
||||
|
||||
|
||||
class MyCallback(FN.core.callback.Callback):
|
||||
def on_train_begin(self):
|
||||
super(MyCallback, self).on_train_begin()
|
||||
self.init_lrs = [pg['lr'] for pg in self.optimizer.param_groups]
|
||||
|
||||
def on_backward_end(self):
|
||||
nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5.0)
|
||||
|
||||
def on_step_end(self):
|
||||
warm_steps = 6000
|
||||
# learning rate warm-up & decay
|
||||
if self.step <= warm_steps:
|
||||
for lr, pg in zip(self.init_lrs, self.optimizer.param_groups):
|
||||
pg['lr'] = lr * (self.step / float(warm_steps))
|
||||
|
||||
elif self.step % 3000 == 0:
|
||||
for pg in self.optimizer.param_groups:
|
||||
cur_lr = pg['lr']
|
||||
pg['lr'] = max(1e-5, cur_lr*g_args.lr_decay)
|
||||
|
||||
|
||||
|
||||
def train():
|
||||
seed = set_rng_seeds(1234)
|
||||
print('RNG SEED {}'.format(seed))
|
||||
print('loading data')
|
||||
ds_list, word_v, tag_v = g_datasets['{}-{}'.format(
|
||||
g_args.ds, g_args.task)]()
|
||||
print(ds_list[0][:2])
|
||||
embed = load_pretrain_emb(word_v, lang='zh' if g_args.ds == 'ctb' else 'en')
|
||||
g_model_cfg['num_cls'] = len(tag_v)
|
||||
print(g_model_cfg)
|
||||
g_model_cfg['init_embed'] = embed
|
||||
model = g_model_select[g_args.task.lower()](**g_model_cfg)
|
||||
|
||||
def init_model(model):
|
||||
for p in model.parameters():
|
||||
if p.size(0) != len(word_v):
|
||||
nn.init.normal_(p, 0.0, 0.05)
|
||||
init_model(model)
|
||||
train_data = ds_list[0]
|
||||
dev_data = ds_list[2]
|
||||
test_data = ds_list[1]
|
||||
print(tag_v.word2idx)
|
||||
|
||||
if g_args.task in ['pos', 'ner']:
|
||||
padding_idx = tag_v.padding_idx
|
||||
else:
|
||||
padding_idx = -100
|
||||
print('padding_idx ', padding_idx)
|
||||
loss = FN.CrossEntropyLoss(padding_idx=padding_idx)
|
||||
metrics = {
|
||||
'pos': (None, FN.AccuracyMetric()),
|
||||
'ner': ('f', FN.core.metrics.SpanFPreRecMetric(
|
||||
tag_vocab=tag_v, encoding_type='bmeso', ignore_labels=[''], )),
|
||||
'cls': (None, FN.AccuracyMetric()),
|
||||
'nli': (None, FN.AccuracyMetric()),
|
||||
}
|
||||
metric_key, metric = metrics[g_args.task]
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
ex_param = [x for x in model.parameters(
|
||||
) if x.requires_grad and x.size(0) != len(word_v)]
|
||||
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
|
||||
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
|
||||
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
|
||||
loss=loss, metrics=metric, metric_key=metric_key,
|
||||
optimizer=torch.optim.Adam(optim_cfg),
|
||||
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000,
|
||||
device=device,
|
||||
use_tqdm=False, prefetch=False,
|
||||
save_path=g_args.log,
|
||||
callbacks=[MyCallback()])
|
||||
|
||||
trainer.train()
|
||||
tester = FN.Tester(data=test_data, model=model, metrics=metric,
|
||||
batch_size=128, device=device)
|
||||
tester.test()
|
||||
|
||||
|
||||
def test():
|
||||
pass
|
||||
|
||||
|
||||
def infer():
|
||||
pass
|
||||
|
||||
|
||||
run_select = {
|
||||
'train': train,
|
||||
'test': test,
|
||||
'infer': infer,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
global g_args, g_model_cfg
|
||||
import signal
|
||||
|
||||
def signal_handler(signal, frame):
|
||||
raise KeyboardInterrupt
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
parser = get_argparser()
|
||||
parser.add_argument('--task', choices=['pos', 'ner', 'cls', 'nli'])
|
||||
parser.add_argument('--mode', choices=['train', 'test', 'infer'])
|
||||
parser.add_argument('--ds', type=str)
|
||||
add_model_args(parser)
|
||||
g_args = parser.parse_args()
|
||||
print(g_args.__dict__)
|
||||
set_gpu(g_args.gpu)
|
||||
g_model_cfg = {
|
||||
'init_embed': (None, 300),
|
||||
'num_cls': None,
|
||||
'hidden_size': g_args.hidden,
|
||||
'num_layers': 4,
|
||||
'num_head': g_args.nhead,
|
||||
'head_dim': g_args.hdim,
|
||||
'max_len': MAX_LEN,
|
||||
'cls_hidden_size': 600,
|
||||
'emb_dropout': 0.3,
|
||||
'dropout': g_args.drop,
|
||||
}
|
||||
run_select[g_args.mode.lower()]()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
112
reproduction/Star_transformer/util.py
Normal file
112
reproduction/Star_transformer/util.py
Normal file
@ -0,0 +1,112 @@
|
||||
import fastNLP as FN
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
|
||||
def get_argparser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--lr', type=float, required=True)
|
||||
parser.add_argument('--w_decay', type=float, required=True)
|
||||
parser.add_argument('--lr_decay', type=float, required=True)
|
||||
parser.add_argument('--bsz', type=int, required=True)
|
||||
parser.add_argument('--ep', type=int, required=True)
|
||||
parser.add_argument('--drop', type=float, required=True)
|
||||
parser.add_argument('--gpu', type=str, required=True)
|
||||
parser.add_argument('--log', type=str, default=None)
|
||||
return parser
|
||||
|
||||
|
||||
def add_model_args(parser):
|
||||
parser.add_argument('--nhead', type=int, default=6)
|
||||
parser.add_argument('--hdim', type=int, default=50)
|
||||
parser.add_argument('--hidden', type=int, default=300)
|
||||
return parser
|
||||
|
||||
|
||||
def set_gpu(gpu_str):
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str
|
||||
|
||||
|
||||
def set_rng_seeds(seed=None):
|
||||
if seed is None:
|
||||
seed = numpy.random.randint(0, 65536)
|
||||
random.seed(seed)
|
||||
numpy.random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# print('RNG_SEED {}'.format(seed))
|
||||
return seed
|
||||
|
||||
|
||||
class TensorboardCallback(FN.Callback):
|
||||
"""
|
||||
接受以下一个或多个字符串作为参数:
|
||||
- "model"
|
||||
- "loss"
|
||||
- "metric"
|
||||
"""
|
||||
|
||||
def __init__(self, *options):
|
||||
super(TensorboardCallback, self).__init__()
|
||||
args = {"model", "loss", "metric"}
|
||||
for opt in options:
|
||||
if opt not in args:
|
||||
raise ValueError(
|
||||
"Unrecognized argument {}. Expect one of {}".format(opt, args))
|
||||
self.options = options
|
||||
self._summary_writer = None
|
||||
self.graph_added = False
|
||||
|
||||
def on_train_begin(self):
|
||||
save_dir = self.trainer.save_path
|
||||
if save_dir is None:
|
||||
path = os.path.join(
|
||||
"./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
|
||||
else:
|
||||
path = os.path.join(
|
||||
save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
|
||||
self._summary_writer = SummaryWriter(path)
|
||||
|
||||
def on_batch_begin(self, batch_x, batch_y, indices):
|
||||
if "model" in self.options and self.graph_added is False:
|
||||
# tesorboardX 这里有大bug,暂时没法画模型图
|
||||
# from fastNLP.core.utils import _build_args
|
||||
# inputs = _build_args(self.trainer.model, **batch_x)
|
||||
# args = tuple([value for value in inputs.values()])
|
||||
# args = args[0] if len(args) == 1 else args
|
||||
# self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
|
||||
self.graph_added = True
|
||||
|
||||
def on_backward_begin(self, loss):
|
||||
if "loss" in self.options:
|
||||
self._summary_writer.add_scalar(
|
||||
"loss", loss.item(), global_step=self.trainer.step)
|
||||
|
||||
if "model" in self.options:
|
||||
for name, param in self.trainer.model.named_parameters():
|
||||
if param.requires_grad:
|
||||
self._summary_writer.add_scalar(
|
||||
name + "_mean", param.mean(), global_step=self.trainer.step)
|
||||
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step)
|
||||
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
|
||||
global_step=self.trainer.step)
|
||||
|
||||
def on_valid_end(self, eval_result, metric_key):
|
||||
if "metric" in self.options:
|
||||
for name, metric in eval_result.items():
|
||||
for metric_key, metric_val in metric.items():
|
||||
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
|
||||
global_step=self.trainer.step)
|
||||
|
||||
def on_train_end(self):
|
||||
self._summary_writer.close()
|
||||
del self._summary_writer
|
||||
|
||||
def on_exception(self, exception):
|
||||
if hasattr(self, "_summary_writer"):
|
||||
self._summary_writer.close()
|
||||
del self._summary_writer
|
Loading…
Reference in New Issue
Block a user