mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-11 10:05:30 +08:00
update
This commit is contained in:
parent
baac29cfa0
commit
6c09f53c6b
@ -99,12 +99,26 @@ class DataSet(list):
|
||||
return self
|
||||
|
||||
def update_vocab(self, **name_vocab):
    """Update vocabularies with the contents of the named fields.

    Each keyword argument maps a field name to the vocabulary object that
    should absorb that field's data. For every instance in this dataset,
    ``instance[field_name].contents()`` is handed to ``vocab.update``.

    e.g. ::

        # update word vocab and label vocab separately
        dataset.update_vocab(word_seq=word_vocab, label_seq=label_vocab)

    :param name_vocab: ``field_name=vocab`` pairs; each vocab must expose
        an ``update`` method, and each instance must be indexable by
        ``field_name`` and yield an object with a ``contents`` method.
    :return: this dataset itself, so calls can be chained.
    """
    for field_name, vocab in name_vocab.items():
        # Finish one field/vocabulary pair before moving to the next, so a
        # vocabulary sees the data in keyword order, then instance order.
        for ins in self:
            vocab.update(ins[field_name].contents())
    return self
|
||||
|
||||
def set_origin_len(self, origin_field, origin_len_name=None):
|
||||
"""make dataset tensor output contain origin_len field.
|
||||
|
||||
e.g. ::
|
||||
|
||||
# output "word_seq_origin_len", lengths based on "word_seq" field
|
||||
dataset.set_origin_len("word_seq")
|
||||
"""
|
||||
if origin_field is None:
|
||||
self.origin_len = None
|
||||
else:
|
||||
|
@ -75,6 +75,13 @@ class DataSetLoader(BaseLoader):
|
||||
super(DataSetLoader, self).__init__()
|
||||
|
||||
def load(self, path):
    """Load the data found at `path` into a dataset.

    This implementation is a placeholder that always raises; presumably
    concrete loaders are expected to override it (TODO confirm against
    the subclasses).

    :param path: location of the data to load (unused here).
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
|
||||
|
||||
def convert(self, data):
    """Convert a list of data into a dataset.

    This implementation is a placeholder that always raises; presumably
    concrete loaders are expected to override it (TODO confirm against
    the subclasses).

    :param data: the list of data to convert (unused here).
    :raises NotImplementedError: always.
    """
    raise NotImplementedError
|
||||
|
||||
class RawDataSetLoader(DataSetLoader):
|
||||
|
Loading…
Reference in New Issue
Block a user