mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 05:07:44 +08:00
add dataset method
This commit is contained in:
parent
0098c0c896
commit
a3adafea34
@ -1037,4 +1037,18 @@ class DataSet:
|
|||||||
self.collator.set_ignore(*field_names)
|
self.collator.set_ignore(*field_names)
|
||||||
return self.collator
|
return self.collator
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||||
|
|
||||||
|
@classmethod
def from_datasets(cls, dataset):
    """
    Convert a Hugging Face ``datasets.Dataset`` into a fastNLP ``DataSet``.

    :param dataset: an already-instantiated Hugging Face ``Dataset`` object.
    :return: a fastNLP ``DataSet`` constructed from ``dataset.to_dict()``.
    :raises ValueError: if ``dataset`` is not a Hugging Face ``Dataset``.
    """
    # Imported lazily so that ``datasets`` is only required when this method is used.
    # The alias keeps the Hugging Face ``Dataset`` visually distinct from the
    # fastNLP ``DataSet``; the two names differ only by the case of one letter.
    from datasets import Dataset as HuggingfaceDataset

    # The check must be against the Hugging Face class, not the fastNLP
    # ``DataSet``: checking against ``DataSet`` rejects every valid input.
    if not isinstance(dataset, HuggingfaceDataset):
        raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")

    data_dict = dataset.to_dict()
    return DataSet(data_dict)
|
@ -522,3 +522,10 @@ class TestCase:
|
|||||||
ins = Instance(**fields)
|
ins = Instance(**fields)
|
||||||
# simple print, that is enough.
|
# simple print, that is enough.
|
||||||
print(ins)
|
print(ins)
|
||||||
|
|
||||||
|
def test_dataset(self):
    """Smoke test: build a Hugging Face dataset and print its fastNLP conversion."""
    from datasets import Dataset as HuggingfaceDataset

    # Two columns of 200 rows each, built by repeating a two-element pattern.
    columns = {"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100}
    hf_dataset = HuggingfaceDataset.from_dict(columns)

    converted = DataSet.from_datasets(hf_dataset)
    # simple print, that is enough.
    print(converted)
|
||||||
|
Loading…
Reference in New Issue
Block a user