mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 12:47:35 +08:00
add dataset method
This commit is contained in:
parent
0098c0c896
commit
a3adafea34
@ -1037,4 +1037,18 @@ class DataSet:
|
||||
self.collator.set_ignore(*field_names)
|
||||
return self.collator
|
||||
else:
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
|
||||
@classmethod
def from_datasets(cls, dataset):
    """
    Convert a huggingface ``datasets.Dataset`` into a fastNLP DataSet.

    :param dataset: an already instantiated huggingface ``Dataset`` object.
    :return: a new instance of ``cls`` built from ``dataset.to_dict()``.
    :raises ValueError: if ``dataset`` is not a huggingface ``Dataset``.
    """
    # Imported lazily so that ``datasets`` is only required when this
    # constructor is actually used.
    from datasets import Dataset

    # The guard must test against the huggingface ``Dataset`` class, not
    # fastNLP's own ``DataSet``; otherwise every valid input is rejected.
    if not isinstance(dataset, Dataset):
        raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")

    data_dict = dataset.to_dict()
    # Build through ``cls`` so that subclasses get an instance of their own type.
    return cls(data_dict)
|
@ -522,3 +522,10 @@ class TestCase:
|
||||
ins = Instance(**fields)
|
||||
# simple print, that is enough.
|
||||
print(ins)
|
||||
|
||||
def test_dataset(self):
    """Build a small huggingface Dataset, convert it with DataSet.from_datasets and print it."""
    from datasets import Dataset as HuggingfaceDataset

    columns = {"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100}
    hf_dataset = HuggingfaceDataset.from_dict(columns)
    converted = DataSet.from_datasets(hf_dataset)
    # Printing is the only check performed by this test.
    print(converted)
|
||||
|
Loading…
Reference in New Issue
Block a user