mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
全部改为相对路径引用
This commit is contained in:
parent
a1f8cdec48
commit
f66012a640
@ -1 +1,2 @@
|
||||
__all__ = ["CWS", "POS", "Parser"]
|
||||
from .api import CWS, POS, Parser
|
||||
|
@ -1,6 +1,3 @@
|
||||
"""
|
||||
api.api的介绍文档
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
@ -8,15 +5,14 @@ import torch
|
||||
warnings.filterwarnings('ignore')
|
||||
import os
|
||||
|
||||
from fastNLP.core.dataset import DataSet
|
||||
|
||||
from fastNLP.api.utils import load_url
|
||||
from fastNLP.api.processor import ModelProcessor
|
||||
from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader
|
||||
from fastNLP.core.instance import Instance
|
||||
from fastNLP.api.pipeline import Pipeline
|
||||
from fastNLP.core.metrics import SpanFPreRecMetric
|
||||
from fastNLP.api.processor import IndexerProcessor
|
||||
from ..core.dataset import DataSet
|
||||
from .utils import load_url
|
||||
from .processor import ModelProcessor
|
||||
from ..io.dataset_loader import _cut_long_sentence, ConllLoader
|
||||
from ..core.instance import Instance
|
||||
from ..api.pipeline import Pipeline
|
||||
from ..core.metrics import SpanFPreRecMetric
|
||||
from .processor import IndexerProcessor
|
||||
|
||||
# TODO add pretrain urls
|
||||
model_urls = {
|
||||
@ -28,9 +24,10 @@ model_urls = {
|
||||
|
||||
class ConllCWSReader(object):
|
||||
"""Deprecated. Use ConllLoader for all types of conll-format files."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def load(self, path, cut_long_sent=False):
|
||||
"""
|
||||
返回的DataSet只包含raw_sentence这个field,内容为str。
|
||||
@ -63,7 +60,7 @@ class ConllCWSReader(object):
|
||||
sample.append(line.strip().split())
|
||||
if len(sample) > 0:
|
||||
datalist.append(sample)
|
||||
|
||||
|
||||
ds = DataSet()
|
||||
for sample in datalist:
|
||||
# print(sample)
|
||||
@ -78,7 +75,7 @@ class ConllCWSReader(object):
|
||||
for raw_sentence in sents:
|
||||
ds.append(Instance(raw_sentence=raw_sentence))
|
||||
return ds
|
||||
|
||||
|
||||
def get_char_lst(self, sample):
|
||||
if len(sample) == 0:
|
||||
return None
|
||||
@ -90,11 +87,13 @@ class ConllCWSReader(object):
|
||||
text.append(t1)
|
||||
return text
|
||||
|
||||
|
||||
class ConllxDataLoader(ConllLoader):
|
||||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。
|
||||
|
||||
Deprecated. Use ConllLoader for all types of conll-format files.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
headers = [
|
||||
'words', 'pos_tags', 'heads', 'labels',
|
||||
@ -106,18 +105,15 @@ class ConllxDataLoader(ConllLoader):
|
||||
|
||||
|
||||
class API:
|
||||
"""
|
||||
这是 API 类的文档
|
||||
"""
|
||||
def __init__(self):
|
||||
self.pipeline = None
|
||||
self._dict = None
|
||||
|
||||
|
||||
def predict(self, *args, **kwargs):
|
||||
"""Do prediction for the given input.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test(self, file_path):
|
||||
"""Test performance over the given data set.
|
||||
|
||||
@ -125,7 +121,7 @@ class API:
|
||||
:return: a dictionary of metric values
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def load(self, path, device):
|
||||
if os.path.exists(os.path.expanduser(path)):
|
||||
_dict = torch.load(path, map_location='cpu')
|
||||
@ -145,14 +141,14 @@ class POS(API):
|
||||
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, model_path=None, device='cpu'):
|
||||
super(POS, self).__init__()
|
||||
if model_path is None:
|
||||
model_path = model_urls['pos']
|
||||
|
||||
|
||||
self.load(model_path, device)
|
||||
|
||||
|
||||
def predict(self, content):
|
||||
"""predict函数的介绍,
|
||||
函数介绍的第二句,这句话不会换行
|
||||
@ -162,48 +158,48 @@ class POS(API):
|
||||
"""
|
||||
if not hasattr(self, "pipeline"):
|
||||
raise ValueError("You have to load model first.")
|
||||
|
||||
|
||||
sentence_list = content
|
||||
# 1. 检查sentence的类型
|
||||
for sentence in sentence_list:
|
||||
if not all((type(obj) == str for obj in sentence)):
|
||||
raise ValueError("Input must be list of list of string.")
|
||||
|
||||
|
||||
# 2. 组建dataset
|
||||
dataset = DataSet()
|
||||
dataset.add_field("words", sentence_list)
|
||||
|
||||
|
||||
# 3. 使用pipeline
|
||||
self.pipeline(dataset)
|
||||
|
||||
|
||||
def merge_tag(words_list, tags_list):
|
||||
rtn = []
|
||||
for words, tags in zip(words_list, tags_list):
|
||||
rtn.append([w + "/" + t for w, t in zip(words, tags)])
|
||||
return rtn
|
||||
|
||||
|
||||
output = dataset.field_arrays["tag"].content
|
||||
if isinstance(content, str):
|
||||
return output[0]
|
||||
elif isinstance(content, list):
|
||||
return merge_tag(content, output)
|
||||
|
||||
|
||||
def test(self, file_path):
|
||||
test_data = ConllxDataLoader().load(file_path)
|
||||
|
||||
|
||||
save_dict = self._dict
|
||||
tag_vocab = save_dict["tag_vocab"]
|
||||
pipeline = save_dict["pipeline"]
|
||||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
|
||||
pipeline.pipeline = [index_tag] + pipeline.pipeline
|
||||
|
||||
|
||||
test_data.rename_field("pos_tags", "tag")
|
||||
pipeline(test_data)
|
||||
test_data.set_target("truth")
|
||||
prediction = test_data.field_arrays["predict"].content
|
||||
truth = test_data.field_arrays["truth"].content
|
||||
seq_len = test_data.field_arrays["word_seq_origin_len"].content
|
||||
|
||||
|
||||
# padding by hand
|
||||
max_length = max([len(seq) for seq in prediction])
|
||||
for idx in range(len(prediction)):
|
||||
@ -217,7 +213,7 @@ class POS(API):
|
||||
f1 = round(test_result['f'] * 100, 2)
|
||||
pre = round(test_result['pre'] * 100, 2)
|
||||
rec = round(test_result['rec'] * 100, 2)
|
||||
|
||||
|
||||
return {"F1": f1, "precision": pre, "recall": rec}
|
||||
|
||||
|
||||
@ -228,14 +224,15 @@ class CWS(API):
|
||||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型
|
||||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。
|
||||
"""
|
||||
|
||||
def __init__(self, model_path=None, device='cpu'):
|
||||
|
||||
super(CWS, self).__init__()
|
||||
if model_path is None:
|
||||
model_path = model_urls['cws']
|
||||
|
||||
|
||||
self.load(model_path, device)
|
||||
|
||||
|
||||
def predict(self, content):
|
||||
"""
|
||||
分词接口。
|
||||
@ -246,27 +243,27 @@ class CWS(API):
|
||||
"""
|
||||
if not hasattr(self, 'pipeline'):
|
||||
raise ValueError("You have to load model first.")
|
||||
|
||||
|
||||
sentence_list = []
|
||||
# 1. 检查sentence的类型
|
||||
if isinstance(content, str):
|
||||
sentence_list.append(content)
|
||||
elif isinstance(content, list):
|
||||
sentence_list = content
|
||||
|
||||
|
||||
# 2. 组建dataset
|
||||
dataset = DataSet()
|
||||
dataset.add_field('raw_sentence', sentence_list)
|
||||
|
||||
|
||||
# 3. 使用pipeline
|
||||
self.pipeline(dataset)
|
||||
|
||||
|
||||
output = dataset.get_field('output').content
|
||||
if isinstance(content, str):
|
||||
return output[0]
|
||||
elif isinstance(content, list):
|
||||
return output
|
||||
|
||||
|
||||
def test(self, filepath):
|
||||
"""
|
||||
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。
|
||||
@ -292,28 +289,28 @@ class CWS(API):
|
||||
tag_proc = self._dict['tag_proc']
|
||||
cws_model = self.pipeline.pipeline[-2].model
|
||||
pipeline = self.pipeline.pipeline[:-2]
|
||||
|
||||
|
||||
pipeline.insert(1, tag_proc)
|
||||
pp = Pipeline(pipeline)
|
||||
|
||||
|
||||
reader = ConllCWSReader()
|
||||
|
||||
|
||||
# te_filename = '/home/hyan/ctb3/test.conllx'
|
||||
te_dataset = reader.load(filepath)
|
||||
pp(te_dataset)
|
||||
|
||||
|
||||
from fastNLP.core.tester import Tester
|
||||
from fastNLP.core.metrics import BMESF1PreRecMetric
|
||||
|
||||
|
||||
tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64,
|
||||
verbose=0)
|
||||
eval_res = tester.test()
|
||||
|
||||
|
||||
f1 = eval_res['BMESF1PreRecMetric']['f']
|
||||
pre = eval_res['BMESF1PreRecMetric']['pre']
|
||||
rec = eval_res['BMESF1PreRecMetric']['rec']
|
||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
|
||||
|
||||
|
||||
return {"F1": f1, "precision": pre, "recall": rec}
|
||||
|
||||
|
||||
@ -322,25 +319,25 @@ class Parser(API):
|
||||
super(Parser, self).__init__()
|
||||
if model_path is None:
|
||||
model_path = model_urls['parser']
|
||||
|
||||
|
||||
self.pos_tagger = POS(device=device)
|
||||
self.load(model_path, device)
|
||||
|
||||
|
||||
def predict(self, content):
|
||||
if not hasattr(self, 'pipeline'):
|
||||
raise ValueError("You have to load model first.")
|
||||
|
||||
|
||||
# 1. 利用POS得到分词和pos tagging结果
|
||||
pos_out = self.pos_tagger.predict(content)
|
||||
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()]
|
||||
|
||||
|
||||
# 2. 组建dataset
|
||||
dataset = DataSet()
|
||||
dataset.add_field('wp', pos_out)
|
||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
|
||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
|
||||
dataset.rename_field("words", "raw_words")
|
||||
|
||||
|
||||
# 3. 使用pipeline
|
||||
self.pipeline(dataset)
|
||||
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred')
|
||||
@ -348,7 +345,7 @@ class Parser(API):
|
||||
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output')
|
||||
# output like: [['2/top', '0/root', '4/nn', '2/dep']]
|
||||
return dataset.field_arrays['output'].content
|
||||
|
||||
|
||||
def load_test_file(self, path):
|
||||
def get_one(sample):
|
||||
sample = list(map(list, zip(*sample)))
|
||||
@ -360,7 +357,7 @@ class Parser(API):
|
||||
return None
|
||||
# return word_seq, pos_seq, head_seq, head_tag_seq
|
||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7]
|
||||
|
||||
|
||||
datalist = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
sample = []
|
||||
@ -374,14 +371,14 @@ class Parser(API):
|
||||
sample.append(line.split('\t'))
|
||||
if len(sample) > 0:
|
||||
datalist.append(sample)
|
||||
|
||||
|
||||
data = [get_one(sample) for sample in datalist]
|
||||
data_list = list(filter(lambda x: x is not None, data))
|
||||
return data_list
|
||||
|
||||
|
||||
def test(self, filepath):
|
||||
data = self.load_test_file(filepath)
|
||||
|
||||
|
||||
def convert(data):
|
||||
BOS = '<BOS>'
|
||||
dataset = DataSet()
|
||||
@ -396,7 +393,7 @@ class Parser(API):
|
||||
arc_true=heads,
|
||||
tags=head_tags))
|
||||
return dataset
|
||||
|
||||
|
||||
ds = convert(data)
|
||||
pp = self.pipeline
|
||||
for p in pp:
|
||||
@ -417,23 +414,23 @@ class Parser(API):
|
||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0
|
||||
uas = head_cor / total
|
||||
# print('uas:{:.2f}'.format(uas))
|
||||
|
||||
|
||||
for p in pp:
|
||||
if p.field_name == 'gold_words':
|
||||
p.field_name = 'word_list'
|
||||
elif p.field_name == 'gold_pos':
|
||||
p.field_name = 'pos_list'
|
||||
|
||||
|
||||
return {"USA": round(uas, 5)}
|
||||
|
||||
|
||||
class Analyzer:
|
||||
def __init__(self, device='cpu'):
|
||||
|
||||
|
||||
self.cws = CWS(device=device)
|
||||
self.pos = POS(device=device)
|
||||
self.parser = Parser(device=device)
|
||||
|
||||
|
||||
def predict(self, content, seg=False, pos=False, parser=False):
|
||||
if seg is False and pos is False and parser is False:
|
||||
seg = True
|
||||
@ -447,9 +444,9 @@ class Analyzer:
|
||||
if parser:
|
||||
parser_output = self.parser.predict(content)
|
||||
output_dict['parser'] = parser_output
|
||||
|
||||
|
||||
return output_dict
|
||||
|
||||
|
||||
def test(self, filepath):
|
||||
output_dict = {}
|
||||
if self.cws:
|
||||
@ -461,5 +458,5 @@ class Analyzer:
|
||||
if self.parser:
|
||||
parser_output = self.parser.test(filepath)
|
||||
output_dict['parser'] = parser_output
|
||||
|
||||
|
||||
return output_dict
|
||||
|
@ -3,7 +3,7 @@ api/example.py contains all API examples provided by fastNLP.
|
||||
It is used as a tutorial for API or a test script since it is difficult to test APIs in travis.
|
||||
|
||||
"""
|
||||
from fastNLP.api import CWS, POS, Parser
|
||||
from . import CWS, POS, Parser
|
||||
|
||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
|
||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
|
||||
|
@ -1,4 +1,4 @@
|
||||
from fastNLP.api.processor import Processor
|
||||
from ..api.processor import Processor
|
||||
|
||||
|
||||
class Pipeline:
|
||||
|
@ -3,10 +3,10 @@ from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.sampler import SequentialSampler
|
||||
from fastNLP.core.vocabulary import Vocabulary
|
||||
from ..core.batch import Batch
|
||||
from ..core.dataset import DataSet
|
||||
from ..core.sampler import SequentialSampler
|
||||
from ..core.vocabulary import Vocabulary
|
||||
|
||||
|
||||
class Processor(object):
|
||||
|
@ -11,15 +11,15 @@ import torch
|
||||
try:
|
||||
from tqdm.autonotebook import tqdm
|
||||
except:
|
||||
from fastNLP.core.utils import _pseudo_tqdm as tqdm
|
||||
from ..core.utils import _pseudo_tqdm as tqdm
|
||||
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.callback import CallbackException
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.utils import _move_dict_value_to_device
|
||||
from ..core.batch import Batch
|
||||
from ..core.callback import CallbackException
|
||||
from ..core.dataset import DataSet
|
||||
from ..core.utils import _move_dict_value_to_device
|
||||
import fastNLP
|
||||
import fastNLP.automl.enas_utils as utils
|
||||
from fastNLP.core.utils import _build_args
|
||||
from . import enas_utils as utils
|
||||
from ..core.utils import _build_args
|
||||
|
||||
from torch.optim import Adam
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from fastNLP.modules.decoder.MLP import MLP
|
||||
from ..modules.decoder.MLP import MLP
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
|
@ -6,7 +6,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from fastNLP.modules.encoder import BertModel
|
||||
from ..modules.encoder import BertModel
|
||||
|
||||
|
||||
class BertForSequenceClassification(BaseModel):
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastNLP.modules.encoder.lstm import LSTM
|
||||
from ..modules.encoder.lstm import LSTM
|
||||
|
||||
|
||||
class Highway(nn.Module):
|
||||
|
@ -5,9 +5,8 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import fastNLP
|
||||
import fastNLP.models.enas_utils as utils
|
||||
from fastNLP.models.enas_utils import Node
|
||||
from . import enas_utils as utils
|
||||
from .enas_utils import Node
|
||||
|
||||
|
||||
def _construct_dags(prev_nodes, activations, func_names, num_blocks):
|
||||
|
@ -9,9 +9,8 @@ from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
import fastNLP.models.enas_utils as utils
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
import fastNLP.modules.encoder as encoder
|
||||
from . import enas_utils as utils
|
||||
from .base_model import BaseModel
|
||||
|
||||
def _get_dropped_weights(w_raw, dropout_p, is_training):
|
||||
"""Drops out weights to implement DropConnect.
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@ -8,21 +7,19 @@ from datetime import timedelta
|
||||
import numpy as np
|
||||
import torch
|
||||
import math
|
||||
from torch import nn
|
||||
|
||||
try:
|
||||
from tqdm.autonotebook import tqdm
|
||||
except:
|
||||
from fastNLP.core.utils import _pseudo_tqdm as tqdm
|
||||
from ..core.utils import _pseudo_tqdm as tqdm
|
||||
|
||||
from fastNLP.core.batch import Batch
|
||||
from fastNLP.core.callback import CallbackManager, CallbackException
|
||||
from fastNLP.core.dataset import DataSet
|
||||
from fastNLP.core.utils import _CheckError
|
||||
from fastNLP.core.utils import _move_dict_value_to_device
|
||||
import fastNLP
|
||||
import fastNLP.models.enas_utils as utils
|
||||
from fastNLP.core.utils import _build_args
|
||||
from ..core.trainer import Trainer
|
||||
from ..core.batch import Batch
|
||||
from ..core.callback import CallbackManager, CallbackException
|
||||
from ..core.dataset import DataSet
|
||||
from ..core.utils import _move_dict_value_to_device
|
||||
from . import enas_utils as utils
|
||||
from ..core.utils import _build_args
|
||||
|
||||
from torch.optim import Adam
|
||||
|
||||
@ -34,7 +31,7 @@ def _get_no_grad_ctx_mgr():
|
||||
return torch.no_grad()
|
||||
|
||||
|
||||
class ENASTrainer(fastNLP.Trainer):
|
||||
class ENASTrainer(Trainer):
|
||||
"""A class to wrap training code."""
|
||||
def __init__(self, train_data, model, controller, **kwargs):
|
||||
"""Constructor for training algorithm.
|
||||
|
@ -4,21 +4,20 @@ from __future__ import print_function
|
||||
|
||||
from collections import defaultdict
|
||||
import collections
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def detach(h):
|
||||
if type(h) == Variable:
|
||||
return Variable(h.data)
|
||||
else:
|
||||
return tuple(detach(v) for v in h)
|
||||
|
||||
|
||||
def get_variable(inputs, cuda=False, **kwargs):
|
||||
if type(inputs) in [list, np.ndarray]:
|
||||
inputs = torch.Tensor(inputs)
|
||||
@ -28,10 +27,12 @@ def get_variable(inputs, cuda=False, **kwargs):
|
||||
out = Variable(inputs, **kwargs)
|
||||
return out
|
||||
|
||||
|
||||
def update_lr(optimizer, lr):
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
|
||||
Node = collections.namedtuple('Node', ['id', 'name'])
|
||||
|
||||
|
||||
@ -48,9 +49,9 @@ def to_item(x):
|
||||
"""Converts x, possibly scalar and possibly tensor, to a Python scalar."""
|
||||
if isinstance(x, (float, int)):
|
||||
return x
|
||||
|
||||
|
||||
if float(torch.__version__[0:3]) < 0.4:
|
||||
assert (x.dim() == 1) and (len(x) == 1)
|
||||
return x[0]
|
||||
|
||||
|
||||
return x.item()
|
||||
|
@ -1,9 +1,9 @@
|
||||
import torch
|
||||
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules import decoder, encoder
|
||||
from fastNLP.modules.decoder.CRF import allowed_transitions
|
||||
from fastNLP.modules.utils import seq_mask
|
||||
from .base_model import BaseModel
|
||||
from ..modules import decoder, encoder
|
||||
from ..modules.decoder.CRF import allowed_transitions
|
||||
from ..modules.utils import seq_mask
|
||||
|
||||
|
||||
class SeqLabeling(BaseModel):
|
||||
|
@ -1,11 +1,11 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
from fastNLP.modules import decoder as Decoder
|
||||
from fastNLP.modules import encoder as Encoder
|
||||
from fastNLP.modules import aggregator as Aggregator
|
||||
from fastNLP.modules.utils import seq_mask
|
||||
from .base_model import BaseModel
|
||||
from ..modules import decoder as Decoder
|
||||
from ..modules import encoder as Encoder
|
||||
from ..modules import aggregator as Aggregator
|
||||
from ..modules.utils import seq_mask
|
||||
|
||||
|
||||
my_inf = 10e12
|
||||
|
@ -7,7 +7,6 @@ from ..core.const import Const
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class StarTransEnc(nn.Module):
|
||||
|
@ -4,10 +4,10 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.modules.dropout import TimestepDropout
|
||||
from fastNLP.modules.utils import mask_softmax
|
||||
from ..dropout import TimestepDropout
|
||||
from ..utils import mask_softmax
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from ..utils import initial_parameter
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
@ -1,17 +1,12 @@
|
||||
# python: 3.6
|
||||
# encoding: utf-8
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MaxPool(nn.Module):
|
||||
"""Max-pooling模块。"""
|
||||
|
||||
def __init__(
|
||||
self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None,
|
||||
return_indices=False, ceil_mode=False
|
||||
):
|
||||
|
||||
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None,
|
||||
return_indices=False, ceil_mode=False):
|
||||
"""
|
||||
:param stride: 窗口移动大小,默认为kernel_size
|
||||
:param padding: padding的内容,默认为0
|
||||
@ -30,7 +25,7 @@ class MaxPool(nn.Module):
|
||||
self.kernel_size = kernel_size
|
||||
self.return_indices = return_indices
|
||||
self.ceil_mode = ceil_mode
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.dimension == 1:
|
||||
pooling = nn.MaxPool1d(
|
||||
@ -57,10 +52,11 @@ class MaxPool(nn.Module):
|
||||
|
||||
class MaxPoolWithMask(nn.Module):
|
||||
"""带mask矩阵的1维max pooling"""
|
||||
|
||||
def __init__(self):
|
||||
super(MaxPoolWithMask, self).__init__()
|
||||
self.inf = 10e12
|
||||
|
||||
|
||||
def forward(self, tensor, mask, dim=1):
|
||||
"""
|
||||
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor
|
||||
@ -75,11 +71,11 @@ class MaxPoolWithMask(nn.Module):
|
||||
|
||||
class KMaxPool(nn.Module):
|
||||
"""K max-pooling module."""
|
||||
|
||||
|
||||
def __init__(self, k=1):
|
||||
super(KMaxPool, self).__init__()
|
||||
self.k = k
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param torch.Tensor x: [N, C, L] 初始tensor
|
||||
@ -92,12 +88,12 @@ class KMaxPool(nn.Module):
|
||||
|
||||
class AvgPool(nn.Module):
|
||||
"""1-d average pooling module."""
|
||||
|
||||
|
||||
def __init__(self, stride=None, padding=0):
|
||||
super(AvgPool, self).__init__()
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param torch.Tensor x: [N, C, L] 初始tensor
|
||||
@ -117,7 +113,7 @@ class MeanPoolWithMask(nn.Module):
|
||||
def __init__(self):
|
||||
super(MeanPoolWithMask, self).__init__()
|
||||
self.inf = 10e12
|
||||
|
||||
|
||||
def forward(self, tensor, mask, dim=1):
|
||||
"""
|
||||
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor
|
||||
@ -127,7 +123,3 @@ class MeanPoolWithMask(nn.Module):
|
||||
"""
|
||||
masks = mask.view(mask.size(0), mask.size(1), -1).float()
|
||||
return torch.sum(tensor * masks.float(), dim=dim) / torch.sum(masks.float(), dim=1)
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from fastNLP.modules.decoder.utils import log_sum_exp
|
||||
from ..utils import initial_parameter
|
||||
from ..decoder.utils import log_sum_exp
|
||||
|
||||
|
||||
def seq_len_to_byte_mask(seq_lens):
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from ..utils import initial_parameter
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from ..utils import initial_parameter
|
||||
|
||||
|
||||
# from torch.nn.init import xavier_uniform
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from ..utils import initial_parameter
|
||||
|
||||
|
||||
class ConvMaxpool(nn.Module):
|
||||
|
@ -1,5 +1,5 @@
|
||||
import torch.nn as nn
|
||||
from fastNLP.modules.utils import get_embeddings
|
||||
from ..utils import get_embeddings
|
||||
|
||||
class Embedding(nn.Embedding):
|
||||
"""Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度"""
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from ..utils import initial_parameter
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.utils.rnn as rnn
|
||||
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from ..utils import initial_parameter
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
|
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from ..utils import initial_parameter
|
||||
|
||||
try:
|
||||
from torch import flip
|
||||
|
Loading…
Reference in New Issue
Block a user