fastNLP/legacy/api/converter.py
ChenXin 881ce01762
Dev0.4.0 (#149)
* 1. CRF增加支持bmeso类型的tag 2. vocabulary中增加注释

* BucketSampler增加一条错误检测

* 1.修改ClipGradientCallback的bug;删除LRSchedulerCallback中的print,之后应该传入pbar进行打印;2.增加MLP注释

* update MLP module

* 增加metric注释;修改trainer save过程中的bug

* Update README.md

fix tutorial link

* Add ENAS (Efficient Neural Architecture Search)

* add ignore_type in DataSet.add_field

* * AutoPadder will not pad when dtype is None
* add ignore_type in DataSet.apply

* 修复fieldarray中padder潜在bug

* 修复crf中typo; 以及可能导致数值不稳定的地方

* 修复CRF中可能存在的bug

* change two default init arguments of Trainer into None

* Changes to Callbacks:
* 给callback添加给定几个只读属性
* 通过manager设置这些属性
* 代码优化,减轻@transfer的负担

* * 将enas相关代码放到automl目录下
* 修复fast_param_mapping的一个bug
* Trainer添加自动创建save目录
* Vocabulary的打印,显示内容

* * 给vocabulary添加遍历方法

* 修复CRF为负数的bug

* add SQuAD metric

* add sigmoid activate function in MLP

* - add star transformer model
- add ConllLoader, for all kinds of conll-format files
- add JsonLoader, for json-format files
- add SSTLoader, for SST-2 & SST-5
- change Callback interface
- fix batch multi-process when killed
- add README to list models and their performance

* - fix test

* - fix callback & tests

* - update README

* 修改部分bug;调整callback

* 准备发布0.4.0版本“

* update readme

* support parallel loss

* 防止多卡的情况导致无法正确计算loss“

* update advance_tutorial jupyter notebook

* 1. 在embedding_loader中增加新的读取函数load_with_vocab(), load_without_vocab, 比之前的函数改变主要在(1)不再需要传入embed_dim(2)自动判断当前是word2vec还是glove.
2. vocabulary增加from_dataset(), index_dataset()函数。避免需要多行写index dataset的问题。
3. 在utils中新增一个cache_result()修饰器,用于cache函数的返回值。
4. callback中新增update_every属性

* 1.DataSet.apply()报错时提供错误的index
2.Vocabulary.from_dataset(), index_dataset()提供报错时的vocab顺序
3.embedloader在embed读取时遇到不规则的数据跳过这一行.

* update attention

* doc tools

* fix some doc errors

* 修改为中文注释,增加viterbi解码方法

* 样例版本

* - add pad sequence for lstm
- add csv, conll, json filereader
- update dataloader
- remove useless dataloader
- fix trainer loss print
- fix tests

* - fix test_tutorial

* 注释增加

* 测试文档

* 本地暂存

* 本地暂存

* 修改文档的顺序

* - add document

* 本地暂存

* update pooling

* update bert

* update documents in MLP

* update documents in snli

* combine self attention module to attention.py

* update documents on losses.py

* 对DataSet的文档进行更新

* update documents on metrics

* 1. 删除了LSTM中print的内容; 2. 将Trainer和Tester的use_cuda修改为了device; 3.补充Trainer的文档

* 增加对Trainer的注释

* 完善了trainer,callback等的文档; 修改了部分代码的命名以使得代码从文档中隐藏

* update char level encoder

* update documents on embedding.py

* - update doc

* 补充注释,并修改部分代码

* - update doc
- add get_embeddings

* 修改了文档配置项

* 修改embedding为init_embed初始化

* 1.增加对Trainer和Tester的多卡支持;

* - add test
- fix jsonloader

* 删除了注释教程

* 给 dataset 增加了get_field_names

* 修复bug

* - add Const
- fix bugs

* 修改部分注释

* - add model runner for easier test models
- add model tests

* 修改了 docs 的配置和架构

* 修改了核心部分的一大部分文档,TODO:
1. 完善 trainer 和 tester 部分的文档
2. 研究注释样例与测试

* core部分的注释基本检查完成

* 修改了 io 部分的注释

* 全部改为相对路径引用

* 全部改为相对路径引用

* small change

* 1. 从安装文件中删除api/automl的安装
2. metric中存在seq_len的bug
3. sampler中存在命名错误,已修改

* 修复 bug :兼容 cpu 版本的 PyTorch
TODO:其它地方可能也存在类似的 bug

* 修改文档中的引用部分

* 把 tqdm.autonotebook 换成tqdm.auto

* - fix batch & vocab

* 上传了文档文件 *.rst

* 上传了文档文件和若干 TODO

* 讨论并整合了若干模块

* core部分的测试和一些小修改

* 删除了一些冗余文档

* update init files

* update const files

* update const files

* 增加cnn的测试

* fix a little bug

* - update attention
- fix tests

* 完善测试

* 完成快速入门教程

* 修改了sequence_modeling 命名为 sequence_labeling 的文档

* 重新 apidoc 解决改名的遗留问题

* 修改文档格式

* 统一不同位置的seq_len_to_mask, 现统一到core.utils.seq_len_to_mask

* 增加了一行提示

* 在文档中展示 dataset_loader

* 提示 Dataset.read_csv 会被 CSVLoader 替换

* 完成 Callback 和 Trainer 之间的文档

* index更新了部分

* 删除冗余的print

* 删除用于分词的metric,因为有可能引起错误

* 修改文档中的中文名称

* 完成了详细介绍文档

* tutorial 的 ipynb 文件

* 修改了一些介绍文档

* 修改了 models 和 modules 的主页介绍

* 加上了 titlesonly 这个设置

* 修改了模块文档展示的标题

* 修改了 core 和 io 的开篇介绍

* 修改了 modules 和 models 开篇介绍

* 使用 .. todo:: 隐藏了可能被抽到文档中的 TODO 注释

* 修改了一些注释

* delete an old metric in test

* 修改 tutorials 的测试文件

* 把暂不发布的功能移到 legacy 文件夹

* 删除了不能运行的测试

* 修改 callback 的测试文件

* 删除了过时的教程和测试文件

* cache_results 参数的修改

* 修改 io 的测试文件; 删除了一些过时的测试

* 修复bug

* 修复无法通过test_utils.py的测试

* 修复与pytorch1.1中的padsequence的兼容问题; 修改Trainer的pbar

* 1. 修复metric中的bug; 2.增加metric测试

* add model summary

* 增加别名

* 删除encoder中的嵌套层

* 修改了 core 部分 import 的顺序,__all__ 暴露的内容

* 修改了 models 部分 import 的顺序,__all__ 暴露的内容

* 修改了文件名

* 修改了 modules 模块的__all__ 和 import

* fix var runn

* 增加vocab的clear方法

* 一些符合 PEP8 的微调

* 更新了cache_results的例子

* 1. 对callback中indices潜在None作出提示;2.DataSet支持通过List进行index

* 修改了一个typo

* 修改了 README.md

* update documents on bert

* update documents on encoder/bert

* 增加一个fitlog callback,实现与fitlog实验记录

* typo

* - update dataset_loader

* 增加了到 fitlog 文档的链接。

* 增加了 DataSet Loader 的文档

* - add star-transformer reproduction
2019-05-22 18:43:56 +08:00

182 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
class SpanConverter:
def __init__(self, replace_tag, pattern):
super(SpanConverter, self).__init__()
self.replace_tag = replace_tag
self.pattern = pattern
def find_certain_span_and_replace(self, sentence):
replaced_sentence = ''
prev_end = 0
for match in re.finditer(self.pattern, sentence):
start, end = match.span()
span = sentence[start:end]
replaced_sentence += sentence[prev_end:start] + self.span_to_special_tag(span)
prev_end = end
replaced_sentence += sentence[prev_end:]
return replaced_sentence
def span_to_special_tag(self, span):
return self.replace_tag
def find_certain_span(self, sentence):
spans = []
for match in re.finditer(self.pattern, sentence):
spans.append(match.span())
return spans
class AlphaSpanConverter(SpanConverter):
def __init__(self):
replace_tag = '<ALPHA>'
# 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag).
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])'
super(AlphaSpanConverter, self).__init__(replace_tag, pattern)
class DigitSpanConverter(SpanConverter):
def __init__(self):
replace_tag = '<NUM>'
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])'
super(DigitSpanConverter, self).__init__(replace_tag, pattern)
def span_to_special_tag(self, span):
# return self.special_tag
if span[0] == '0' and len(span) > 2:
return '<NUM>'
decimal_point_count = 0 # one might have more than one decimal pointers
for idx, char in enumerate(span):
if char == '.' or char == '' or char == '·':
decimal_point_count += 1
if span[-1] == '.' or span[-1] == '' or span[-1] == '·':
# last digit being decimal point means this is not a number
if decimal_point_count == 1:
return span
else:
return '<UNKDGT>'
if decimal_point_count == 1:
return '<DEC>'
elif decimal_point_count > 1:
return '<UNKDGT>'
else:
return '<NUM>'
class TimeConverter(SpanConverter):
def __init__(self):
replace_tag = '<TOC>'
pattern = '\d+[:][\d:]+(?=[\u4e00-\u9fff ,%.!<-])'
super().__init__(replace_tag, pattern)
class MixNumAlphaConverter(SpanConverter):
def __init__(self):
replace_tag = '<MIX>'
pattern = None
super().__init__(replace_tag, pattern)
def find_certain_span_and_replace(self, sentence):
replaced_sentence = ''
start = 0
matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False
for idx in range(len(sentence)):
if re.match('[0-9a-zA-Z/\\(\\)\'&\\-]', sentence[idx]):
if not matching_flag:
replaced_sentence += sentence[start:idx]
start = idx
if re.match('[0-9]', sentence[idx]):
number_flag = True
elif re.match('[\'&\\-]', sentence[idx]):
link_flag = True
elif re.match('/', sentence[idx]):
slash_flag = True
elif re.match('[\\(\\)]', sentence[idx]):
bracket_flag = True
else:
alpha_flag = True
matching_flag = True
elif re.match('[\\.]', sentence[idx]):
pass
else:
if matching_flag:
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
or (slash_flag and alpha_flag) or (link_flag and number_flag) \
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
span = sentence[start:idx]
start = idx
replaced_sentence += self.span_to_special_tag(span)
matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False
replaced_sentence += sentence[start:]
return replaced_sentence
def find_certain_span(self, sentence):
spans = []
start = 0
matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False
for idx in range(len(sentence)):
if re.match('[0-9a-zA-Z/\\(\\)\'&\\-]', sentence[idx]):
if not matching_flag:
start = idx
if re.match('[0-9]', sentence[idx]):
number_flag = True
elif re.match('[\'&\\-]', sentence[idx]):
link_flag = True
elif re.match('/', sentence[idx]):
slash_flag = True
elif re.match('[\\(\\)]', sentence[idx]):
bracket_flag = True
else:
alpha_flag = True
matching_flag = True
elif re.match('[\\.]', sentence[idx]):
pass
else:
if matching_flag:
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
or (slash_flag and alpha_flag) or (link_flag and number_flag) \
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
spans.append((start, idx))
start = idx
matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False
return spans
class EmailConverter(SpanConverter):
def __init__(self):
replaced_tag = "<EML>"
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])'
super(EmailConverter, self).__init__(replaced_tag, pattern)