fastNLP/reproduction/utils.py
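yh_cc 2f5d8967a3 1. Adapt to the change that replaces Batch with PyTorch's DataLoader
2. Fix a bug in embedding.py
3. ConllReader now skips all DOCSTART tags by default
4. Move the heavy lifting of bert into _bert, and expose BertEncoder in bert.py
5. Change include_end_start of allow_transition in crf to False, to match the CRF default
6. allow_transition and SpanMetric support BIOES-style tags
7. Add formatted print output to datainfo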
2019-06-17 20:18:07 +08:00

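import os
from typing import Union, Dict


def check_dataloader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
    """
    Check that the paths passed to a dataloader are valid. If they are, return a dict that contains
    at least the key 'train', similar to:

        {
            'train': '/some/path/to/',  # always present; the vocabulary should be built from this file,
                                        # and the remaining files only need to be processed and indexed.
            'test': 'xxx'               # may or may not be present
            ...
        }

    If paths is invalid, the corresponding error is raised.

    :param paths: the path(s). Either a file path (the file is then treated as the train file); a directory,
        in which train.txt (required), dev.txt and test.txt are looked up; or a dict whose keys are
        user-defined names for the files and whose values are the corresponding file paths.
    :return: a dict mapping each split name (e.g. 'train', 'dev', 'test') to its file path.
    """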
    if isinstance(paths, str):
        if os.path.isfile(paths):
            # A single file is treated as the training set.
            return {'train': paths}
        elif os.path.isdir(paths):
            # A directory must contain train.txt; dev.txt and test.txt are optional.
            train_fp = os.path.join(paths, 'train.txt')
            if not os.path.isfile(train_fp):
                raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
            files = {'train': train_fp}
            for filename in ['dev.txt', 'test.txt']:
                fp = os.path.join(paths, filename)
                if os.path.isfile(fp):
                    files[filename.split('.')[0]] = fp  # 'dev.txt' -> 'dev', 'test.txt' -> 'test'
            return files
        else:
            raise FileNotFoundError(f"{paths} is not a valid file path.")
    elif isinstance(paths, dict):
        if paths:
            if 'train' not in paths:
                raise KeyError("You have to include `train` in your dict.")
            for key, value in paths.items():
                if isinstance(key, str) and isinstance(value, str):
                    if not os.path.isfile(value):
                        # A missing file is a lookup failure, not a type error.
                        raise FileNotFoundError(f"{value} is not a valid file.")
                else:
                    raise TypeError("All keys and values in paths should be str.")
            return paths
        else:
            raise ValueError("Empty paths is not allowed.")
    else:
        raise TypeError(f"paths only supports str and dict, not {type(paths)}.")
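The docstring says the vocabulary should be built from the 'train' file only, while the other files only need to be processed and indexed. The sketch below shows one way a loader might consume a dict of the shape `check_dataloader_paths` returns. It assumes whitespace-tokenized text with one sentence per line. The function `build_vocab_and_index` and the reserved `<pad>`/`<unk>` ids are hypothetical and are not part of fastNLP's API.

```python
import os
import tempfile
from collections import Counter
from typing import Dict, List, Tuple


def build_vocab_and_index(
    paths: Dict[str, str], min_freq: int = 1
) -> Tuple[Dict[str, int], Dict[str, List[List[int]]]]:
    """Hypothetical consumer of a {'train': ..., 'dev': ..., 'test': ...} dict.

    Assumption: each file holds one whitespace-tokenized sentence per line.
    """
    def read_lines(fp: str) -> List[List[str]]:
        with open(fp, encoding="utf-8") as f:
            return [line.split() for line in f if line.strip()]

    # Count tokens in the 'train' split only.
    counter = Counter(tok for sent in read_lines(paths["train"]) for tok in sent)

    # Reserve ids 0 and 1 for padding and unknown tokens (an assumption of this sketch).
    vocab = {"<pad>": 0, "<unk>": 1}
    for tok, freq in sorted(counter.items()):  # sorted for deterministic ids
        if freq >= min_freq:
            vocab[tok] = len(vocab)

    # Index every split with the train vocabulary; unseen tokens map to <unk>.
    indexed = {
        name: [[vocab.get(tok, vocab["<unk>"]) for tok in sent] for sent in read_lines(fp)]
        for name, fp in paths.items()
    }
    return vocab, indexed


# Usage: write a tiny train/test pair to a temporary directory.
with tempfile.TemporaryDirectory() as d:
    for name, text in {"train.txt": "a b\nb c\n", "test.txt": "a d\n"}.items():
        with open(os.path.join(d, name), "w", encoding="utf-8") as f:
            f.write(text)
    paths = {"train": os.path.join(d, "train.txt"), "test": os.path.join(d, "test.txt")}
    vocab, indexed = build_vocab_and_index(paths)

print(vocab)            # {'<pad>': 0, '<unk>': 1, 'a': 2, 'b': 3, 'c': 4}
print(indexed["test"])  # [[2, 1]]  ('d' never appears in train, so it becomes <unk>)
```

Building the vocabulary from 'train' alone keeps dev and test tokens from leaking into the model's vocabulary. Those tokens are still indexed, and any word unseen in training falls back to `<unk>`.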