mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
Sampler中增加了一个BucketSampler, CWS的训练基本可以实现
This commit is contained in:
parent
69a138eb18
commit
3cb98ddcf2
@ -72,7 +72,8 @@ class DataSet(object):
|
||||
self.field_arrays[name].append(field)
|
||||
|
||||
def add_field(self, name, fields):
|
||||
assert len(self) == len(fields)
|
||||
if len(self.field_arrays)!=0:
|
||||
assert len(self) == len(fields)
|
||||
self.field_arrays[name] = FieldArray(name, fields)
|
||||
|
||||
def delete_field(self, name):
|
||||
|
@ -28,11 +28,15 @@ class FieldArray(object):
|
||||
return self.content[idxes]
|
||||
assert self.need_tensor is True
|
||||
batch_size = len(idxes)
|
||||
max_len = max([len(self.content[i]) for i in idxes])
|
||||
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
|
||||
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
|
||||
if isinstance(self.content[0], int) or isinstance(self.content[0], float):
|
||||
array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
|
||||
else:
|
||||
max_len = max([len(self.content[i]) for i in idxes])
|
||||
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
|
||||
|
||||
for i, idx in enumerate(idxes):
|
||||
array[i][:len(self.content[idx])] = self.content[idx]
|
||||
for i, idx in enumerate(idxes):
|
||||
array[i][:len(self.content[idx])] = self.content[idx]
|
||||
return array
|
||||
|
||||
def __len__(self):
|
||||
|
@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from itertools import chain
|
||||
|
||||
def convert_to_torch_tensor(data_list, use_cuda):
|
||||
"""Convert lists into (cuda) Tensors.
|
||||
@ -43,6 +43,47 @@ class RandomSampler(BaseSampler):
|
||||
def __call__(self, data_set):
|
||||
return list(np.random.permutation(len(data_set)))
|
||||
|
||||
class BucketSampler(BaseSampler):
|
||||
|
||||
def __init__(self, num_buckets=10, batch_size=32):
|
||||
self.num_buckets = num_buckets
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __call__(self, data_set):
|
||||
assert 'seq_lens' in data_set, "BuckectSampler only support data_set with seq_lens right now."
|
||||
|
||||
seq_lens = data_set['seq_lens'].content
|
||||
total_sample_num = len(seq_lens)
|
||||
|
||||
bucket_indexes = []
|
||||
num_sample_per_bucket = total_sample_num//self.num_buckets
|
||||
for i in range(self.num_buckets):
|
||||
bucket_indexes.append([num_sample_per_bucket*i, num_sample_per_bucket*(i+1)])
|
||||
bucket_indexes[-1][1] = total_sample_num
|
||||
|
||||
sorted_seq_lens = list(sorted([(idx, seq_len) for
|
||||
idx, seq_len in zip(range(total_sample_num), seq_lens)],
|
||||
key=lambda x:x[1]))
|
||||
|
||||
batchs = []
|
||||
|
||||
left_init_indexes = []
|
||||
for b_idx in range(self.num_buckets):
|
||||
start_idx = bucket_indexes[b_idx][0]
|
||||
end_idx = bucket_indexes[b_idx][1]
|
||||
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
|
||||
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
|
||||
num_batch_per_bucket = len(left_init_indexes)//self.batch_size
|
||||
np.random.shuffle(left_init_indexes)
|
||||
for i in range(num_batch_per_bucket):
|
||||
batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
|
||||
left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]
|
||||
|
||||
np.random.shuffle(batchs)
|
||||
|
||||
return list(chain(*batchs))
|
||||
|
||||
|
||||
|
||||
def simple_sort_bucketing(lengths):
|
||||
"""
|
||||
|
@ -68,7 +68,6 @@ class CWSBiLSTMEncoder(BaseModel):
|
||||
if not bigrams is None:
|
||||
bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
|
||||
x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
|
||||
|
||||
sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
|
||||
packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
|
||||
|
||||
@ -97,36 +96,22 @@ class CWSBiLSTMSegApp(BaseModel):
|
||||
|
||||
def forward(self, batch_dict):
|
||||
device = self.parameters().__next__().device
|
||||
chars = batch_dict['indexed_chars_list'].to(device)
|
||||
if 'bigram' in batch_dict:
|
||||
bigrams = batch_dict['indexed_chars_list'].to(device)
|
||||
chars = batch_dict['indexed_chars_list'].to(device).long()
|
||||
if 'indexed_bigrams_list' in batch_dict:
|
||||
bigrams = batch_dict['indexed_bigrams_list'].to(device).long()
|
||||
else:
|
||||
bigrams = None
|
||||
seq_lens = batch_dict['seq_lens'].to(device)
|
||||
seq_lens = batch_dict['seq_lens'].to(device).long()
|
||||
|
||||
feats = self.encoder_model(chars, bigrams, seq_lens)
|
||||
probs = self.decoder_model(feats)
|
||||
|
||||
pred_dict = {}
|
||||
pred_dict['seq_lens'] = seq_lens
|
||||
pred_dict['pred_prob'] = probs
|
||||
pred_dict['pred_probs'] = probs
|
||||
|
||||
return pred_dict
|
||||
|
||||
def predict(self, batch_dict):
|
||||
pass
|
||||
|
||||
|
||||
def loss_fn(self, pred_dict, true_dict):
|
||||
seq_lens = pred_dict['seq_lens']
|
||||
masks = seq_lens_to_mask(seq_lens).float()
|
||||
|
||||
pred_prob = pred_dict['pred_prob']
|
||||
true_y = true_dict['tags']
|
||||
|
||||
# TODO 当前把loss写死了
|
||||
loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
|
||||
true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)
|
||||
|
||||
|
||||
return loss
|
@ -110,9 +110,9 @@ class CWSTagProcessor(Processor):
|
||||
for ins in dataset:
|
||||
sentence = ins[self.field_name]
|
||||
tag_list = self._generate_tag(sentence)
|
||||
new_tag_field = SeqLabelField(tag_list)
|
||||
ins[self.new_added_field_name] = new_tag_field
|
||||
ins[self.new_added_field_name] = tag_list
|
||||
dataset.set_is_target(**{self.new_added_field_name:True})
|
||||
dataset.set_need_tensor(**{self.new_added_field_name:True})
|
||||
return dataset
|
||||
|
||||
def _tags_from_word_len(self, word_len):
|
||||
|
@ -1,6 +1,4 @@
|
||||
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.api.pipeline import Pipeline
|
||||
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
|
||||
from fastNLP.api.processor import IndexerProcessor
|
||||
@ -143,7 +141,7 @@ def decode_iterator(model, batcher):
|
||||
from reproduction.chinese_word_segment.utils import FocalLoss
|
||||
from reproduction.chinese_word_segment.utils import seq_lens_to_mask
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.sampler import RandomSampler
|
||||
from fastNLP.core.sampler import BucketSampler
|
||||
from fastNLP.core.sampler import SequentialSampler
|
||||
|
||||
import torch
|
||||
@ -159,6 +157,7 @@ cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
|
||||
bigram_embed_dim=100, num_bigram_per_char=8,
|
||||
hidden_size=200, bidirectional=True, embed_drop_p=None,
|
||||
num_layers=1, tag_size=tag_size)
|
||||
cws_model.cuda()
|
||||
|
||||
num_epochs = 3
|
||||
loss_fn = FocalLoss(class_num=tag_size)
|
||||
@ -167,7 +166,7 @@ optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
|
||||
|
||||
print_every = 50
|
||||
batch_size = 32
|
||||
tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
|
||||
tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
|
||||
dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
|
||||
num_batch_per_epoch = len(tr_dataset) // batch_size
|
||||
best_f1 = 0
|
||||
@ -181,10 +180,12 @@ for num_epoch in range(num_epochs):
|
||||
cws_model.train()
|
||||
for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
|
||||
pred_dict = cws_model(batch_x) # B x L x tag_size
|
||||
seq_lens = batch_x['seq_lens']
|
||||
masks = seq_lens_to_mask(seq_lens)
|
||||
tags = batch_y['tags']
|
||||
loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
|
||||
|
||||
seq_lens = pred_dict['seq_lens']
|
||||
masks = seq_lens_to_mask(seq_lens).float()
|
||||
tags = batch_y['tags'].long().to(seq_lens.device)
|
||||
|
||||
loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
|
||||
tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
|
||||
# loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
|
||||
|
||||
@ -201,20 +202,20 @@ for num_epoch in range(num_epochs):
|
||||
pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
|
||||
avg_loss = 0
|
||||
pbar.update(print_every)
|
||||
|
||||
# 验证集
|
||||
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
|
||||
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
|
||||
pre*100,
|
||||
rec*100))
|
||||
if best_f1<f1:
|
||||
best_f1 = f1
|
||||
# 缓存最佳的parameter,可能之后会用于保存
|
||||
best_state_dict = {
|
||||
key:value.clone() for key, value in
|
||||
cws_model.state_dict().items()
|
||||
}
|
||||
best_epoch = num_epoch
|
||||
tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
|
||||
# 验证集
|
||||
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
|
||||
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
|
||||
pre*100,
|
||||
rec*100))
|
||||
if best_f1<f1:
|
||||
best_f1 = f1
|
||||
# 缓存最佳的parameter,可能之后会用于保存
|
||||
best_state_dict = {
|
||||
key:value.clone() for key, value in
|
||||
cws_model.state_dict().items()
|
||||
}
|
||||
best_epoch = num_epoch
|
||||
|
||||
|
||||
# 4. 组装需要存下的内容
|
||||
@ -224,4 +225,6 @@ pp.add_processor(sp_proc)
|
||||
pp.add_processor(char_proc)
|
||||
pp.add_processor(bigram_proc)
|
||||
pp.add_processor(char_index_proc)
|
||||
pp.add_processor(bigram_index_proc)
|
||||
pp.add_processor(bigram_index_proc)
|
||||
pp.add_processor(seq_len_proc)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user