Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

This commit is contained in:
x54-729 2022-07-07 14:51:49 +00:00
commit 9706c8bd66
2 changed files with 22 additions and 1 deletions

View File

@ -1036,4 +1036,18 @@ class DataSet:
self.collator.set_ignore(*field_names)
return self.collator
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
@classmethod
def from_datasets(cls, dataset):
    """
    Build a fastNLP DataSet from a Hugging Face ``datasets.Dataset``.

    :param dataset: an already instantiated Hugging Face ``Dataset`` object.
    :return: a new instance of this class built from ``dataset.to_dict()``.
    :raises ValueError: if ``dataset`` is not a Hugging Face ``Dataset``.
    """
    # Imported lazily so ``datasets`` is only needed when this method is called.
    from datasets import Dataset

    # Check against the Hugging Face class (``Dataset``), not this fastNLP
    # class (``DataSet``); the two names differ only by the capital "S".
    if not isinstance(dataset, Dataset):
        raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")

    data_dict = dataset.to_dict()
    # Use ``cls`` so subclasses get an instance of their own type.
    return cls(data_dict)

View File

@ -522,3 +522,10 @@ class TestCase:
ins = Instance(**fields)
# simple print, that is enough.
print(ins)
def test_dataset(self):
    """Smoke test: convert a Hugging Face dataset with DataSet.from_datasets and print it."""
    from datasets import Dataset as HuggingfaceDataset

    # 200 rows: two alternating string samples with alternating 0/1 labels.
    columns = {"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100}
    hf_dataset = HuggingfaceDataset.from_dict(columns)
    converted = DataSet.from_datasets(hf_dataset)
    print(converted)