fastNLP/reproduction/utils.py
Yige XU ca94a00f41
Merge pull request #177 from SrWYG/dev0.5.0
[new] text_classification:charcnn,dpcnn,HAN,awd-lstm
2019-07-08 12:44:49 +08:00

72 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from typing import Union, Dict
def check_dataloader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
    """
    Validate the path(s) handed to a dataloader and normalise them to a dict.

    On success the returned dict always contains the key ``'train'``, e.g.::

        {
            'train': '/some/path/to/',  # always present
            'test': 'xxx',              # may or may not be present
            ...
        }

    Invalid input raises immediately instead of returning a partial result.

    :param paths: one of

        * a file path -- the file is taken to be the train file;
        * a directory path -- every entry whose name contains ``train``,
          ``dev`` or ``test`` is collected under that key;
        * a dict -- keys are user-chosen names, values are file paths. It is
          returned unchanged once validated.
    :return: a dict mapping split names to file paths, including ``'train'``.
    :raises FileNotFoundError: ``paths`` is a str that is neither a file nor a
        directory.
    :raises Exception: a directory entry's name contains more than one of
        ``train`` / ``dev`` / ``test``.
    :raises RuntimeError: two directory entries map to the same split.
    :raises KeyError: no ``train`` entry was found in the directory, or the
        dict has no ``'train'`` key.
    :raises ValueError: ``paths`` is an empty dict.
    :raises TypeError: ``paths`` is neither str nor dict, a dict key/value is
        not a str, or a dict value is not an existing file.
    """
    if isinstance(paths, str):
        if os.path.isfile(paths):
            return {'train': paths}
        if not os.path.isdir(paths):
            raise FileNotFoundError(f"{paths} is not a valid file path.")

        files = {}
        for filename in os.listdir(paths):
            # Substring match, checked in a fixed order so that the names
            # reported for an ambiguous filename are always the first two hits.
            matched = [split for split in ('train', 'dev', 'test') if split in filename]
            if not matched:
                continue
            if len(matched) > 1:
                raise Exception("File:{} in {} contains both `{}` and `{}`.".format(
                    filename, paths, matched[0], matched[1]))
            split = matched[0]
            if split in files:
                raise RuntimeError(f"Multiple file under {paths} have '{split}' in their filename.")
            files[split] = os.path.join(paths, filename)

        # Enforce the documented guarantee; the dict branch below applies the
        # same rule with the same exception type.
        if 'train' not in files:
            raise KeyError(f"There is no train file in {paths}.")
        return files

    if isinstance(paths, dict):
        if not paths:
            raise ValueError("Empty paths is not allowed.")
        if 'train' not in paths:
            raise KeyError("You have to include `train` in your dict.")
        for key, value in paths.items():
            if not (isinstance(key, str) and isinstance(value, str)):
                raise TypeError("All keys and values in paths should be str.")
            # TypeError (not FileNotFoundError) is kept so existing callers
            # that catch it continue to work.
            if not os.path.isfile(value):
                raise TypeError(f"{value} is not a valid file.")
        return paths

    raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
def get_tokenizer():
    """
    Build a callable that splits a string into a list of token strings.

    spaCy is tried first: it is imported, ``spacy.prefer_gpu()`` is called and
    the ``'en'`` pipeline is loaded; the returned callable then yields the
    ``text`` of every token produced by that pipeline's tokenizer. If any step
    on that path raises, the returned callable is plain ``str.split`` on
    whitespace instead. A message naming the chosen tokenizer is printed.

    :return: a function mapping a string to a list of token strings.
    """
    def _whitespace_tokenize(text):
        return text.split()

    try:
        import spacy
        spacy.prefer_gpu()
        nlp = spacy.load('en')
        print('use spacy tokenizer')

        def _spacy_tokenize(text):
            return [token.text for token in nlp.tokenizer(text)]

        return _spacy_tokenize
    except Exception:
        # Deliberately broad: any failure on the spaCy path selects the
        # whitespace fallback.
        print('use raw tokenizer')
        return _whitespace_tokenize