Sampler中增加了一个BucketSampler, CWS的训练基本可以实现

This commit is contained in:
yh_cc 2018-11-10 14:46:38 +08:00
parent 69a138eb18
commit 3cb98ddcf2
6 changed files with 85 additions and 51 deletions

View File

@ -72,7 +72,8 @@ class DataSet(object):
self.field_arrays[name].append(field)
def add_field(self, name, fields):
assert len(self) == len(fields)
if len(self.field_arrays)!=0:
assert len(self) == len(fields)
self.field_arrays[name] = FieldArray(name, fields)
def delete_field(self, name):

View File

@ -28,11 +28,15 @@ class FieldArray(object):
return self.content[idxes]
assert self.need_tensor is True
batch_size = len(idxes)
max_len = max([len(self.content[i]) for i in idxes])
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
# TODO 当这个fieldArray是seq_length这种只有一位的内容时不需要padding需要再讨论一下
if isinstance(self.content[0], int) or isinstance(self.content[0], float):
array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
else:
max_len = max([len(self.content[i]) for i in idxes])
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
for i, idx in enumerate(idxes):
array[i][:len(self.content[idx])] = self.content[idx]
for i, idx in enumerate(idxes):
array[i][:len(self.content[idx])] = self.content[idx]
return array
def __len__(self):

View File

@ -1,6 +1,6 @@
import numpy as np
import torch
from itertools import chain
def convert_to_torch_tensor(data_list, use_cuda):
"""Convert lists into (cuda) Tensors.
@ -43,6 +43,47 @@ class RandomSampler(BaseSampler):
def __call__(self, data_set):
return list(np.random.permutation(len(data_set)))
class BucketSampler(BaseSampler):
def __init__(self, num_buckets=10, batch_size=32):
self.num_buckets = num_buckets
self.batch_size = batch_size
def __call__(self, data_set):
assert 'seq_lens' in data_set, "BuckectSampler only support data_set with seq_lens right now."
seq_lens = data_set['seq_lens'].content
total_sample_num = len(seq_lens)
bucket_indexes = []
num_sample_per_bucket = total_sample_num//self.num_buckets
for i in range(self.num_buckets):
bucket_indexes.append([num_sample_per_bucket*i, num_sample_per_bucket*(i+1)])
bucket_indexes[-1][1] = total_sample_num
sorted_seq_lens = list(sorted([(idx, seq_len) for
idx, seq_len in zip(range(total_sample_num), seq_lens)],
key=lambda x:x[1]))
batchs = []
left_init_indexes = []
for b_idx in range(self.num_buckets):
start_idx = bucket_indexes[b_idx][0]
end_idx = bucket_indexes[b_idx][1]
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
num_batch_per_bucket = len(left_init_indexes)//self.batch_size
np.random.shuffle(left_init_indexes)
for i in range(num_batch_per_bucket):
batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]
np.random.shuffle(batchs)
return list(chain(*batchs))
def simple_sort_bucketing(lengths):
"""

View File

@ -68,7 +68,6 @@ class CWSBiLSTMEncoder(BaseModel):
if not bigrams is None:
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
@ -97,36 +96,22 @@ class CWSBiLSTMSegApp(BaseModel):
def forward(self, batch_dict):
device = self.parameters().__next__().device
chars = batch_dict['indexed_chars_list'].to(device)
if 'bigram' in batch_dict:
bigrams = batch_dict['indexed_chars_list'].to(device)
chars = batch_dict['indexed_chars_list'].to(device).long()
if 'indexed_bigrams_list' in batch_dict:
bigrams = batch_dict['indexed_bigrams_list'].to(device).long()
else:
bigrams = None
seq_lens = batch_dict['seq_lens'].to(device)
seq_lens = batch_dict['seq_lens'].to(device).long()
feats = self.encoder_model(chars, bigrams, seq_lens)
probs = self.decoder_model(feats)
pred_dict = {}
pred_dict['seq_lens'] = seq_lens
pred_dict['pred_prob'] = probs
pred_dict['pred_probs'] = probs
return pred_dict
def predict(self, batch_dict):
pass
def loss_fn(self, pred_dict, true_dict):
seq_lens = pred_dict['seq_lens']
masks = seq_lens_to_mask(seq_lens).float()
pred_prob = pred_dict['pred_prob']
true_y = true_dict['tags']
# TODO 当前把loss写死了
loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)
return loss

View File

@ -110,9 +110,9 @@ class CWSTagProcessor(Processor):
for ins in dataset:
sentence = ins[self.field_name]
tag_list = self._generate_tag(sentence)
new_tag_field = SeqLabelField(tag_list)
ins[self.new_added_field_name] = new_tag_field
ins[self.new_added_field_name] = tag_list
dataset.set_is_target(**{self.new_added_field_name:True})
dataset.set_need_tensor(**{self.new_added_field_name:True})
return dataset
def _tags_from_word_len(self, word_len):

View File

@ -1,6 +1,4 @@
from fastNLP.core.instance import Instance
from fastNLP.core.dataset import DataSet
from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.api.processor import IndexerProcessor
@ -143,7 +141,7 @@ def decode_iterator(model, batcher):
from reproduction.chinese_word_segment.utils import FocalLoss
from reproduction.chinese_word_segment.utils import seq_lens_to_mask
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.sampler import SequentialSampler
import torch
@ -159,6 +157,7 @@ cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
bigram_embed_dim=100, num_bigram_per_char=8,
hidden_size=200, bidirectional=True, embed_drop_p=None,
num_layers=1, tag_size=tag_size)
cws_model.cuda()
num_epochs = 3
loss_fn = FocalLoss(class_num=tag_size)
@ -167,7 +166,7 @@ optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
print_every = 50
batch_size = 32
tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
num_batch_per_epoch = len(tr_dataset) // batch_size
best_f1 = 0
@ -181,10 +180,12 @@ for num_epoch in range(num_epochs):
cws_model.train()
for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
pred_dict = cws_model(batch_x) # B x L x tag_size
seq_lens = batch_x['seq_lens']
masks = seq_lens_to_mask(seq_lens)
tags = batch_y['tags']
loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
seq_lens = pred_dict['seq_lens']
masks = seq_lens_to_mask(seq_lens).float()
tags = batch_y['tags'].long().to(seq_lens.device)
loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
# loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
@ -201,20 +202,20 @@ for num_epoch in range(num_epochs):
pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
avg_loss = 0
pbar.update(print_every)
# 验证集
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
pre*100,
rec*100))
if best_f1<f1:
best_f1 = f1
# 缓存最佳的parameter可能之后会用于保存
best_state_dict = {
key:value.clone() for key, value in
cws_model.state_dict().items()
}
best_epoch = num_epoch
tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
# 验证集
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
pre*100,
rec*100))
if best_f1<f1:
best_f1 = f1
# 缓存最佳的parameter可能之后会用于保存
best_state_dict = {
key:value.clone() for key, value in
cws_model.state_dict().items()
}
best_epoch = num_epoch
# 4. 组装需要存下的内容
@ -224,4 +225,6 @@ pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_index_proc)
pp.add_processor(bigram_index_proc)
pp.add_processor(bigram_index_proc)
pp.add_processor(seq_len_proc)