冲突解决

This commit is contained in:
yh 2018-11-11 12:43:16 +08:00
commit 0a8a76f769
5 changed files with 58 additions and 10 deletions

View File

@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import *
from fastNLP.models.biaffine_parser import BiaffineParser
import torch
class DependencyParser(API):
def __init__(self):
@ -18,19 +20,35 @@ class DependencyParser(API):
pred = Predictor()
res = pred.predict(self.model, dataset)
heads, head_tags = [], []
for batch in res:
heads.append(batch['heads'])
head_tags.append(batch['labels'])
heads, head_tags = torch.cat(heads, dim=0), torch.cat(head_tags, dim=0)
return heads, head_tags
return res
def build(self):
    """Assemble the preprocessing pipeline and the parser model.

    Stores the resulting pipeline on ``self.pipeline`` and a freshly
    constructed ``BiaffineParser`` on ``self.model``. Returns nothing.

    The pipeline is created exactly once and the model is constructed
    exactly once; earlier duplicate statements built a first pipeline and a
    first argument-less model that were immediately overwritten.
    """
    BOS = '<BOS>'
    NUM = '<NUM>'
    model_args = {}
    # NOTE(review): load_path is empty, so the vocabulary paths below resolve
    # to '/word_v.pkl' and '/pos_v.pkl' -- looks like a placeholder; confirm.
    load_path = ''
    # NOTE(review): `load` is not defined in the visible part of this file;
    # the .pkl suffix suggests pickle data -- only load trusted files.
    word_vocab = load(f'{load_path}/word_v.pkl')
    pos_vocab = load(f'{load_path}/pos_v.pkl')
    word_seq = 'word_seq'
    pos_seq = 'pos_seq'

    # build pipeline
    pipe = Pipeline()
    pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq))
    # Prepend the BOS marker to both the word and the POS sequences.
    pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None))
    pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None))
    pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx'))
    pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx'))
    # The sequence length is taken from the word sequence.
    pipe.add_processor(MapFieldProcessor(len, word_seq, 'seq_len'))

    # load model parameters
    self.model = BiaffineParser(**model_args)
    self.pipeline = pipe

View File

@ -145,7 +145,6 @@ class IndexerProcessor(Processor):
class VocabProcessor(Processor):
def __init__(self, field_name):
super(VocabProcessor, self).__init__(field_name, None)
self.vocab = Vocabulary()
@ -221,4 +220,13 @@ class ModelProcessor(Processor):
def set_model(self, model):
    """Store ``model`` on this object as the ``model`` attribute.

    :param model: the object to keep; it is stored as-is.
    """
    setattr(self, 'model', model)
class Index2WordProcessor(Processor):
    """Processor that converts a field of indices back into words.

    Each element of the input field is passed through ``vocab.to_word`` and
    the resulting list is written to the output field of the same instance.
    """

    def __init__(self, vocab, field_name, new_added_field_name):
        """Remember the vocabulary and forward the field names to the base class.

        :param vocab: object providing a ``to_word`` method.
        :param field_name: name of the field to read from each instance.
        :param new_added_field_name: name of the field to write the words to.
        """
        super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
        self.vocab = vocab

    def process(self, dataset):
        """Fill the output field of every instance and return the same dataset.

        :param dataset: iterable of instances supporting item get/set.
        :return: the ``dataset`` object that was passed in.
        """
        # presumably self.field_name / self.new_added_field_name are set by
        # the Processor base class -- verify against its __init__.
        for instance in dataset:
            words = []
            for index in instance[self.field_name]:
                words.append(self.vocab.to_word(index))
            instance[self.new_added_field_name] = words
        return dataset

View File

@ -13,3 +13,6 @@ class BaseModel(torch.nn.Module):
def fit(self, train_data, dev_data=None, **train_args):
    """Train this model with a ``Trainer`` built from ``train_args``.

    :param train_data: data handed to the trainer for training.
    :param dev_data: optional second dataset handed to the trainer.
    :param train_args: keyword arguments forwarded to the ``Trainer`` constructor.
    """
    Trainer(**train_args).train(self, train_data, dev_data)
def predict(self):
    """Placeholder prediction hook: does nothing and returns ``None``."""
    return None

View File

@ -9,6 +9,7 @@ from torch.nn import functional as F
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.models.base_model import BaseModel
def mst(scores):
"""
@ -113,7 +114,7 @@ def _find_cycle(vertices, edges):
return [SCC for SCC in _SCCs if len(SCC) > 1]
class GraphParser(nn.Module):
class GraphParser(BaseModel):
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding
"""
def __init__(self):
@ -370,4 +371,20 @@ class BiaffineParser(GraphParser):
label_nll = -(label_loss*float_mask).mean()
return arc_nll + label_nll
def predict(self, word_seq, pos_seq, word_seq_origin_len):
    """Run the model and turn its output into discrete predictions.

    The three arguments are passed unchanged to ``self(...)``.

    :param word_seq: first argument forwarded to the model call.
    :param pos_seq: second argument forwarded to the model call.
    :param word_seq_origin_len: third argument forwarded to the model call;
        also returned unchanged under ``'seq_len'``.
    :return: dict with three entries:
        ``'head_pred'``  -- taken as-is from the model output ([B, L]),
        ``'label_pred'`` -- indices of the maximum along dim 2 of the model's
        ``'label_pred'`` output ([B, L]),
        ``'seq_len'``    -- ``word_seq_origin_len`` ([B,]).
    """
    forward_out = self(word_seq, pos_seq, word_seq_origin_len)
    head_pred = forward_out.pop('head_pred')
    label_scores = forward_out.pop('label_pred')
    # max over dim 2 yields (values, indices); only the indices are kept.
    label_pred = label_scores.max(2)[1]
    return {
        'head_pred': head_pred,
        'label_pred': label_pred,
        'seq_len': word_seq_origin_len,
    }

View File

@ -30,11 +30,13 @@ class TestCase1(unittest.TestCase):
for text, label in zip(texts, labels):
x = TextField(text, is_target=False)
y = LabelField(label, is_target=True)
ins = Instance(text=x, label=y)
ins = Instance(raw_text=x, label=y)
data.append(ins)
# use vocabulary to index data
data.index_field("text", vocab)
# data.index_field("text", vocab)
for ins in data:
ins['text'] = [vocab.to_index(w) for w in ins['raw_text']]
# define naive sampler for batch class
class SeqSampler: