mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-01 19:57:34 +08:00
fix bug in test
This commit is contained in:
parent
089009f9f2
commit
66a7cf084e
@ -4,9 +4,9 @@ from typing import Union, Dict
|
||||
|
||||
from ...core.const import Const
|
||||
from ...core.vocabulary import Vocabulary
|
||||
from ...io.base_loader import DataInfo, DataSetLoader
|
||||
from ...io.dataset_loader import JsonLoader, CSVLoader
|
||||
from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
|
||||
from ..base_loader import DataInfo, DataSetLoader
|
||||
from ..dataset_loader import JsonLoader, CSVLoader
|
||||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
|
||||
from ...modules.encoder._bert import BertTokenizer
|
||||
|
||||
|
||||
|
@ -16,8 +16,6 @@ __all__ = [
|
||||
'CSVLoader',
|
||||
'JsonLoader',
|
||||
'ConllLoader',
|
||||
'SNLILoader',
|
||||
'SSTLoader',
|
||||
'PeopleDailyCorpusLoader',
|
||||
'Conll2003Loader',
|
||||
]
|
||||
@ -30,7 +28,6 @@ from ..core.dataset import DataSet
|
||||
from ..core.instance import Instance
|
||||
from .file_reader import _read_csv, _read_json, _read_conll
|
||||
from .base_loader import DataSetLoader, DataInfo
|
||||
from .data_loader.sst import SSTLoader
|
||||
from ..core.const import Const
|
||||
from ..modules.encoder._bert import BertTokenizer
|
||||
|
||||
@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
|
||||
else:
|
||||
instance = Instance(words=sent_words)
|
||||
data_set.append(instance)
|
||||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
|
||||
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
|
||||
return data_set
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
import os
|
||||
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader
|
||||
from fastNLP.io.dataset_loader import SSTLoader, SNLILoader
|
||||
from fastNLP.io.data_loader import SSTLoader, SNLILoader
|
||||
from reproduction.text_classification.data.yelpLoader import yelpLoader
|
||||
|
||||
|
||||
@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase):
|
||||
print(info.vocabs)
|
||||
print(info.datasets)
|
||||
os.remove(train), os.remove(test)
|
||||
|
||||
def test_import(self):
|
||||
import fastNLP
|
||||
from fastNLP.io import SNLILoader
|
||||
ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
|
||||
get_index=True, seq_len_type='seq_len')
|
||||
assert 'train' in ds.datasets
|
||||
assert len(ds.datasets) == 1
|
||||
assert len(ds.datasets['train']) == 3
|
||||
|
Loading…
Reference in New Issue
Block a user