From bc04b3e7fd7557d9fdd772f90802d4eb8369eb05 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Fri, 14 Sep 2018 12:38:38 +0800
Subject: [PATCH] add basic Field support

---
 fastNLP/data/batch.py    | 97 ++++++++++++++++++++++++++++++++++++++++++++
 fastNLP/data/dataset.py  | 32 +++++++++++++
 fastNLP/data/field.py    | 74 ++++++++++++++++++++++++++++
 fastNLP/data/instance.py | 46 ++++++++++++++++++
 4 files changed, 249 insertions(+)
 create mode 100644 fastNLP/data/batch.py
 create mode 100644 fastNLP/data/dataset.py
 create mode 100644 fastNLP/data/field.py
 create mode 100644 fastNLP/data/instance.py

diff --git a/fastNLP/data/batch.py b/fastNLP/data/batch.py
new file mode 100644
index 00000000..ef5f7d46
--- /dev/null
+++ b/fastNLP/data/batch.py
@@ -0,0 +1,97 @@
+from collections import defaultdict
+
+import torch
+
+
+class Batch(object):
+    """Iterate over a DataSet in mini-batches.
+
+    Each field is padded to the longest length found in the current batch
+    before the per-instance tensors are stacked together.
+    """
+
+    def __init__(self, dataset, sampler, batch_size):
+        self.dataset = dataset
+        self.sampler = sampler
+        self.batch_size = batch_size
+
+        self.idx_list = None
+        self.curidx = 0
+        self.lengths = None
+
+    def __iter__(self):
+        self.idx_list = self.sampler(self.dataset)
+        self.curidx = 0
+        self.lengths = self.dataset.get_length()
+        return self
+
+    def __next__(self):
+        if self.curidx >= len(self.idx_list):
+            raise StopIteration
+        else:
+            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
+            # pad every field to the longest instance inside this batch
+            padding_length = {field_name: max(field_length[self.curidx: endidx])
+                              for field_name, field_length in self.lengths.items()}
+
+            batch_x, batch_y = defaultdict(list), defaultdict(list)
+            for idx in range(self.curidx, endidx):
+                x, y = self.dataset.to_tensor(idx, padding_length)
+                for name, tensor in x.items():
+                    batch_x[name].append(tensor)
+                for name, tensor in y.items():
+                    batch_y[name].append(tensor)
+
+            # stack the per-instance tensors into (batch_size, ...) tensors
+            for batch in (batch_x, batch_y):
+                for name, tensor_list in batch.items():
+                    batch[name] = torch.stack(tensor_list, dim=0)
+
+            self.curidx = endidx
+            return batch_x, batch_y
+
+
+if __name__ == "__main__":
+    """simple running example
+    """
+    from field import TextField, LabelField
+    from instance import Instance
+    from dataset import DataSet
+
+    texts = ["i am a cat",
+             "this is a test of new batch",
+             "haha"
+             ]
+    labels = [0, 1, 0]
+
+    # prepare vocabulary; index 0 is reserved for padding
+    vocab = {"<pad>": 0}
+    for text in texts:
+        for token in text.split():
+            if token not in vocab:
+                vocab[token] = len(vocab)
+
+    # prepare input dataset
+    data = DataSet()
+    for text, label in zip(texts, labels):
+        x = TextField(text.split(), is_target=False)
+        y = LabelField(label, is_target=True)
+        ins = Instance(text=x, label=y)
+        data.append(ins)
+
+    # use vocabulary to index data
+    data.index_field("text", vocab)
+
+    # define a naive sequential sampler for the Batch class
+    class SeqSampler:
+        def __call__(self, dataset):
+            return list(range(len(dataset)))
+
+    # use Batch to iterate over the dataset
+    batcher = Batch(data, SeqSampler(), 2)
+    for epoch in range(3):
+        for batch_x, batch_y in batcher:
+            print(batch_x)
+            print(batch_y)
+            # do stuff
+
diff --git a/fastNLP/data/dataset.py b/fastNLP/data/dataset.py
new file mode 100644
index 00000000..ffe75494
--- /dev/null
+++ b/fastNLP/data/dataset.py
@@ -0,0 +1,32 @@
+from collections import defaultdict
+
+
+class DataSet(list):
+    """A list of Instance objects."""
+
+    def __init__(self, name="", instances=None):
+        list.__init__(self)
+        self.name = name
+        if instances is not None:
+            self.extend(instances)
+
+    def index_all(self, vocab):
+        for ins in self:
+            ins.index_all(vocab)
+
+    def index_field(self, field_name, vocab):
+        for ins in self:
+            ins.index_field(field_name, vocab)
+
+    def to_tensor(self, idx: int, padding_length: dict):
+        ins = self[idx]
+        return ins.to_tensor(padding_length)
+
+    def get_length(self):
+        """Return a dict mapping field name to a list of per-instance lengths."""
+        lengths = defaultdict(list)
+        for ins in self:
+            for field_name, field_length in ins.get_length().items():
+                lengths[field_name].append(field_length)
+        return lengths
+
diff --git a/fastNLP/data/field.py b/fastNLP/data/field.py
new file mode 100644
index 00000000..ada90857
--- /dev/null
+++ b/fastNLP/data/field.py
@@ -0,0 +1,74 @@
+import torch
+
+
+class Field(object):
+    """Base class for a single field of an Instance."""
+
+    def __init__(self, is_target: bool):
+        self.is_target = is_target
+
+    def index(self, vocab):
+        raise NotImplementedError
+
+    def get_length(self):
+        raise NotImplementedError
+
+    def to_tensor(self, padding_length):
+        raise NotImplementedError
+
+
+class TextField(Field):
+    def __init__(self, text: list, is_target):
+        """
+        :param list text: a list of tokens
+        :param bool is_target: whether this field is a prediction target
+        """
+        super(TextField, self).__init__(is_target)
+        self.text = text
+        self._index = None
+
+    def index(self, vocab):
+        if self._index is None:
+            self._index = [vocab[token] for token in self.text]
+        else:
+            raise RuntimeError("TextField has already been indexed.")
+        return self._index
+
+    def get_length(self):
+        return len(self.text)
+
+    def to_tensor(self, padding_length: int):
+        if self._index is None:
+            raise RuntimeError("TextField must be indexed before calling to_tensor.")
+        pads = []
+        if padding_length > self.get_length():
+            pads = [0] * (padding_length - self.get_length())
+        # shape: (padding_length, )
+        return torch.LongTensor(self._index + pads)
+
+
+class LabelField(Field):
+    def __init__(self, label, is_target=True):
+        super(LabelField, self).__init__(is_target)
+        self.label = label
+        self._index = None
+
+    def get_length(self):
+        return 1
+
+    def index(self, vocab):
+        if self._index is None:
+            self._index = vocab[self.label]
+        return self._index
+
+    def to_tensor(self, padding_length):
+        # labels need no padding; use the raw label if it was never indexed
+        if self._index is None:
+            return torch.LongTensor([self.label])
+        else:
+            return torch.LongTensor([self._index])
+
+
+if __name__ == "__main__":
+    tf = TextField("test the code".split(), is_target=False)
+
diff --git a/fastNLP/data/instance.py b/fastNLP/data/instance.py
new file mode 100644
index 00000000..4b78dfc3
--- /dev/null
+++ b/fastNLP/data/instance.py
@@ -0,0 +1,46 @@
+class Instance(object):
+    """A single example, made up of named Field objects."""
+
+    def __init__(self, **fields):
+        self.fields = fields
+        self.has_index = False
+        self.indexes = {}
+
+    def add_field(self, field_name, field):
+        self.fields[field_name] = field
+
+    def get_length(self):
+        """Return a dict mapping field name to that field's length."""
+        length = {name: field.get_length() for name, field in self.fields.items()}
+        return length
+
+    def index_field(self, field_name, vocab):
+        """use `vocab` to index a certain field
+        """
+        self.indexes[field_name] = self.fields[field_name].index(vocab)
+
+    def index_all(self, vocab):
+        """use `vocab` to index all fields
+        """
+        if self.has_index:
+            # already indexed; return the cached result instead of indexing twice
+            return self.indexes
+        indexes = {name: field.index(vocab) for name, field in self.fields.items()}
+        self.indexes = indexes
+        self.has_index = True
+        return indexes
+
+    def to_tensor(self, padding_length: dict):
+        """Convert every field to a tensor, split into inputs and targets.
+
+        :param dict padding_length: field name -> padding length for that field
+        :return: (tensor_x, tensor_y) with the non-target and target fields
+        """
+        tensor_x = {}
+        tensor_y = {}
+        for name, field in self.fields.items():
+            if field.is_target:
+                tensor_y[name] = field.to_tensor(padding_length[name])
+            else:
+                tensor_x[name] = field.to_tensor(padding_length[name])
+        return tensor_x, tensor_y
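
A minimal sketch of how a different sampler could be plugged into the Batch class
above: Batch only requires a callable that maps a DataSet to a list of indices, so
the sequential SeqSampler used in batch.py's __main__ can be swapped for a shuffling
one. `RandomSampler` below is a hypothetical name used for illustration and is not
part of this patch; it assumes the classes above are importable.

    import random

    class RandomSampler:
        """Return the dataset indices in a random order each epoch."""

        def __init__(self, seed=None):
            # hypothetical helper, not defined in this patch
            self.rng = random.Random(seed)

        def __call__(self, dataset):
            indices = list(range(len(dataset)))
            self.rng.shuffle(indices)
            return indices

    # usage, reusing the `data` DataSet built in batch.py's __main__:
    # batcher = Batch(data, RandomSampler(seed=42), batch_size=2)
    # for batch_x, batch_y in batcher:
    #     print(batch_x["text"].size())  # (batch_size, longest "text" length in this batch)

Reshuffling happens once per epoch because Batch.__iter__ calls the sampler again
each time iteration restarts.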