fix bug in test

This commit is contained in:
xuyige 2019-07-06 01:36:11 +08:00
parent 089009f9f2
commit 66a7cf084e
3 changed files with 14 additions and 8 deletions

View File

@@ -4,9 +4,9 @@ from typing import Union, Dict
from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ...io.base_loader import DataInfo, DataSetLoader
from ...io.dataset_loader import JsonLoader, CSVLoader
from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ..base_loader import DataInfo, DataSetLoader
from ..dataset_loader import JsonLoader, CSVLoader
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ...modules.encoder._bert import BertTokenizer

View File

@@ -16,8 +16,6 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
]
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader, DataInfo
from .data_loader.sst import SSTLoader
from ..core.const import Const
from ..modules.encoder._bert import BertTokenizer
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
else:
instance = Instance(words=sent_words)
data_set.append(instance)
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
return data_set

View File

@@ -1,7 +1,7 @@
import unittest
import os
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader
from fastNLP.io.dataset_loader import SSTLoader, SNLILoader
from fastNLP.io.data_loader import SSTLoader, SNLILoader
from reproduction.text_classification.data.yelpLoader import yelpLoader
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase):
print(info.vocabs)
print(info.datasets)
os.remove(train), os.remove(test)
def test_import(self):
import fastNLP
from fastNLP.io import SNLILoader
ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
get_index=True, seq_len_type='seq_len')
assert 'train' in ds.datasets
assert len(ds.datasets) == 1
assert len(ds.datasets['train']) == 3