From 3cb98ddcf2dfb30902490c41d0d1129d8ead57ff Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 10 Nov 2018 14:46:38 +0800
Subject: [PATCH] Add a BucketSampler to Sampler; CWS training basically works now
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataset.py | 3 +-
 fastNLP/core/fieldarray.py | 12 +++--
 fastNLP/core/sampler.py | 43 +++++++++++++++-
 .../chinese_word_segment/models/cws_model.py | 25 ++--------
 .../process/cws_processor.py | 4 +-
 .../chinese_word_segment/train_context.py | 49 ++++++++++---------
 6 files changed, 85 insertions(+), 51 deletions(-)

diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index cffe95a9..e3162356 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -72,7 +72,8 @@ class DataSet(object):
             self.field_arrays[name].append(field)
 
     def add_field(self, name, fields):
-        assert len(self) == len(fields)
+        if len(self.field_arrays)!=0:
+            assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields)
 
     def delete_field(self, name):
diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py
index a08e7f12..f2d612f9 100644
--- a/fastNLP/core/fieldarray.py
+++ b/fastNLP/core/fieldarray.py
@@ -28,11 +28,15 @@ class FieldArray(object):
             return self.content[idxes]
         assert self.need_tensor is True
        batch_size = len(idxes)
-        max_len = max([len(self.content[i]) for i in idxes])
-        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
+        # TODO when this fieldArray holds single-value content such as seq_length, no padding is needed; to be discussed further
+        if isinstance(self.content[0], int) or isinstance(self.content[0], float):
+            array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
+        else:
+            max_len = max([len(self.content[i]) for i in idxes])
+            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
 
-        for i, idx in enumerate(idxes):
-            array[i][:len(self.content[idx])] = self.content[idx]
+            for i, idx in enumerate(idxes):
+                array[i][:len(self.content[idx])] = self.content[idx]
         return array
 
     def __len__(self):
diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py
index 74f67125..d2d1b301 100644
--- a/fastNLP/core/sampler.py
+++ b/fastNLP/core/sampler.py
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-
+from itertools import chain
 
 def convert_to_torch_tensor(data_list, use_cuda):
     """Convert lists into (cuda) Tensors.
@@ -43,6 +43,47 @@ class RandomSampler(BaseSampler):
     def __call__(self, data_set):
         return list(np.random.permutation(len(data_set)))
 
+
+class BucketSampler(BaseSampler):
+
+    def __init__(self, num_buckets=10, batch_size=32):
+        self.num_buckets = num_buckets
+        self.batch_size = batch_size
+
+    def __call__(self, data_set):
+        assert 'seq_lens' in data_set, "BucketSampler only supports data_set with seq_lens right now."
+
+        seq_lens = data_set['seq_lens'].content
+        total_sample_num = len(seq_lens)
+
+        bucket_indexes = []
+        num_sample_per_bucket = total_sample_num//self.num_buckets
+        for i in range(self.num_buckets):
+            bucket_indexes.append([num_sample_per_bucket*i, num_sample_per_bucket*(i+1)])
+        bucket_indexes[-1][1] = total_sample_num
+
+        sorted_seq_lens = list(sorted([(idx, seq_len) for
+                                       idx, seq_len in zip(range(total_sample_num), seq_lens)],
+                                      key=lambda x:x[1]))
+
+        batchs = []
+
+        left_init_indexes = []
+        for b_idx in range(self.num_buckets):
+            start_idx = bucket_indexes[b_idx][0]
+            end_idx = bucket_indexes[b_idx][1]
+            sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
+            left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
+            num_batch_per_bucket = len(left_init_indexes)//self.batch_size
+            np.random.shuffle(left_init_indexes)
+            for i in range(num_batch_per_bucket):
+                batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
+            left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]
+
+        np.random.shuffle(batchs)
+
+        return list(chain(*batchs))
+
+
 def simple_sort_bucketing(lengths):
     """
diff --git a/reproduction/chinese_word_segment/models/cws_model.py b/reproduction/chinese_word_segment/models/cws_model.py
index 1fc1af26..b46a1940 100644
--- a/reproduction/chinese_word_segment/models/cws_model.py
+++ b/reproduction/chinese_word_segment/models/cws_model.py
@@ -68,7 +68,6 @@ class CWSBiLSTMEncoder(BaseModel):
         if not bigrams is None:
             bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
             x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
-
         sorted_lens, sorted_indices = torch.sort(seq_lens, descending=True)
         packed_x = nn.utils.rnn.pack_padded_sequence(x_tensor[sorted_indices], sorted_lens, batch_first=True)
 
@@ -97,36 +96,22 @@ class CWSBiLSTMSegApp(BaseModel):
 
     def forward(self, batch_dict):
         device = self.parameters().__next__().device
-        chars = batch_dict['indexed_chars_list'].to(device)
-        if 'bigram' in batch_dict:
-            bigrams = batch_dict['indexed_chars_list'].to(device)
+        chars = batch_dict['indexed_chars_list'].to(device).long()
+        if 'indexed_bigrams_list' in batch_dict:
+            bigrams = batch_dict['indexed_bigrams_list'].to(device).long()
         else:
             bigrams = None
-        seq_lens = batch_dict['seq_lens'].to(device)
+        seq_lens = batch_dict['seq_lens'].to(device).long()
 
         feats = self.encoder_model(chars, bigrams, seq_lens)
         probs = self.decoder_model(feats)
 
         pred_dict = {}
         pred_dict['seq_lens'] = seq_lens
-        pred_dict['pred_prob'] = probs
+        pred_dict['pred_probs'] = probs
 
         return pred_dict
 
     def predict(self, batch_dict):
         pass
-
-    def loss_fn(self, pred_dict, true_dict):
-        seq_lens = pred_dict['seq_lens']
-        masks = seq_lens_to_mask(seq_lens).float()
-
-        pred_prob = pred_dict['pred_prob']
-        true_y = true_dict['tags']
-
-        # TODO the loss is currently hard-coded here
-        loss = F.cross_entropy(pred_prob.view(-1, self.tag_size),
-                               true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)
-
-
-        return loss
\ No newline at end of file
diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py
index 27a6fb1d..e93431ff 100644
--- a/reproduction/chinese_word_segment/process/cws_processor.py
+++ b/reproduction/chinese_word_segment/process/cws_processor.py
@@ -110,9 +110,9 @@ class CWSTagProcessor(Processor):
         for ins in dataset:
             sentence = ins[self.field_name]
             tag_list = self._generate_tag(sentence)
-            new_tag_field = SeqLabelField(tag_list)
-            ins[self.new_added_field_name] = new_tag_field
+            ins[self.new_added_field_name] = tag_list
         dataset.set_is_target(**{self.new_added_field_name:True})
+        dataset.set_need_tensor(**{self.new_added_field_name:True})
         return dataset
 
     def _tags_from_word_len(self, word_len):
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
index c5e7b2a4..e43f8a24 100644
--- a/reproduction/chinese_word_segment/train_context.py
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -1,6 +1,4 @@
-from fastNLP.core.instance import Instance
-from fastNLP.core.dataset import DataSet
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
 from fastNLP.api.processor import IndexerProcessor
@@ -143,7 +141,7 @@ def decode_iterator(model, batcher):
 from reproduction.chinese_word_segment.utils import FocalLoss
 from reproduction.chinese_word_segment.utils import seq_lens_to_mask
 from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import BucketSampler
 from fastNLP.core.sampler import SequentialSampler
 import torch
@@ -159,6 +157,7 @@ cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
                             bigram_embed_dim=100, num_bigram_per_char=8,
                             hidden_size=200, bidirectional=True, embed_drop_p=None,
                             num_layers=1, tag_size=tag_size)
+cws_model.cuda()
 
 num_epochs = 3
 loss_fn = FocalLoss(class_num=tag_size)
@@ -167,7 +166,7 @@ optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
 
 print_every = 50
 batch_size = 32
-tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
+tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
 dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
 num_batch_per_epoch = len(tr_dataset) // batch_size
 best_f1 = 0
@@ -181,10 +180,12 @@ for num_epoch in range(num_epochs):
     cws_model.train()
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
         pred_dict = cws_model(batch_x) # B x L x tag_size
-        seq_lens = batch_x['seq_lens']
-        masks = seq_lens_to_mask(seq_lens)
-        tags = batch_y['tags']
-        loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
+
+        seq_lens = pred_dict['seq_lens']
+        masks = seq_lens_to_mask(seq_lens).float()
+        tags = batch_y['tags'].long().to(seq_lens.device)
+
+        loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
                                  tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
 
         # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
@@ -201,20 +202,20 @@ for num_epoch in range(num_epochs):
                 pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
                 avg_loss = 0
                 pbar.update(print_every)
-
-    # validation set
-    pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
     print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100,
                                                      pre*100,
                                                      rec*100))
     if best_f1