mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 19:57:34 +08:00
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0
This commit is contained in:
commit
9706c8bd66
@ -1036,4 +1036,18 @@ class DataSet:
|
||||
self.collator.set_ignore(*field_names)
|
||||
return self.collator
|
||||
else:
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
|
||||
|
||||
@classmethod
def from_datasets(cls, dataset):
    """
    Convert a Huggingface ``Dataset`` into a fastNLP ``DataSet``.

    :param dataset: an already-instantiated huggingface ``datasets.Dataset`` object.
    :return: a fastNLP ``DataSet`` built from ``dataset.to_dict()``.
    :raises ValueError: if ``dataset`` is not a huggingface ``datasets.Dataset``.
    """
    # Imported inside the method, so ``datasets`` is only needed when it is called.
    from datasets import Dataset

    # Validate against the huggingface ``Dataset`` class. Checking against
    # fastNLP's own ``DataSet`` would reject every valid huggingface input.
    if not isinstance(dataset, Dataset):
        raise ValueError(f"Support huggingface dataset, but is {type(dataset)}!")

    data_dict = dataset.to_dict()
    return DataSet(data_dict)
|
@ -522,3 +522,10 @@ class TestCase:
|
||||
ins = Instance(**fields)
|
||||
# simple print, that is enough.
|
||||
print(ins)
|
||||
|
||||
def test_dataset(self):
    """Build a huggingface Dataset from a dict and print its fastNLP conversion."""
    from datasets import Dataset as HuggingfaceDataset

    columns = {"x": ["11sxa", "1sasz"]*100, "y": [0, 1]*100}
    hf_dataset = HuggingfaceDataset.from_dict(columns)
    converted = DataSet.from_datasets(hf_dataset)
    # Printing exercises the conversion end to end.
    print(converted)
|
||||
|
Loading…
Reference in New Issue
Block a user