mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-03 20:57:37 +08:00
add index to word processor
This commit is contained in:
parent
7df33b23ea
commit
82f4351540
@ -5,6 +5,8 @@ from fastNLP.api.pipeline import Pipeline
|
||||
from fastNLP.api.processor import *
|
||||
from fastNLP.models.biaffine_parser import BiaffineParser
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class DependencyParser(API):
|
||||
def __init__(self):
|
||||
@ -18,19 +20,35 @@ class DependencyParser(API):
|
||||
|
||||
pred = Predictor()
|
||||
res = pred.predict(self.model, dataset)
|
||||
heads, head_tags = [], []
|
||||
for batch in res:
|
||||
heads.append(batch['heads'])
|
||||
head_tags.append(batch['labels'])
|
||||
heads, head_tags = torch.cat(heads, dim=0), torch.cat(head_tags, dim=0)
|
||||
return heads, head_tags
|
||||
|
||||
return res
|
||||
|
||||
def build(self):
|
||||
pipe = Pipeline()
|
||||
|
||||
# build pipeline
|
||||
BOS = '<BOS>'
|
||||
NUM = '<NUM>'
|
||||
model_args = {}
|
||||
load_path = ''
|
||||
word_vocab = load(f'{load_path}/word_v.pkl')
|
||||
pos_vocab = load(f'{load_path}/pos_v.pkl')
|
||||
word_seq = 'word_seq'
|
||||
pos_seq = 'pos_seq'
|
||||
pipe.add_processor(Num2TagProcessor('<NUM>', 'raw_sentence', word_seq))
|
||||
|
||||
pipe = Pipeline()
|
||||
# build pipeline
|
||||
pipe.add_processor(Num2TagProcessor(NUM, 'raw_sentence', word_seq))
|
||||
pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, word_seq, None))
|
||||
pipe.add_processor(MapFieldProcessor(lambda x: [BOS] + x, pos_seq, None))
|
||||
pipe.add_processor(IndexerProcessor(word_vocab, word_seq, word_seq+'_idx'))
|
||||
pipe.add_processor(IndexerProcessor(pos_vocab, pos_seq, pos_seq+'_idx'))
|
||||
pipe.add_processor(MapFieldProcessor(lambda x: len(x), word_seq, 'seq_len'))
|
||||
|
||||
|
||||
# load model parameters
|
||||
self.model = BiaffineParser()
|
||||
self.model = BiaffineParser(**model_args)
|
||||
self.pipeline = pipe
|
||||
|
||||
|
@ -145,7 +145,6 @@ class IndexerProcessor(Processor):
|
||||
|
||||
class VocabProcessor(Processor):
|
||||
def __init__(self, field_name):
|
||||
|
||||
super(VocabProcessor, self).__init__(field_name, None)
|
||||
self.vocab = Vocabulary()
|
||||
|
||||
@ -172,3 +171,15 @@ class SeqLenProcessor(Processor):
|
||||
ins[self.new_added_field_name] = length
|
||||
dataset.set_need_tensor(**{self.new_added_field_name: True})
|
||||
return dataset
|
||||
|
||||
class Index2WordProcessor(Processor):
|
||||
def __init__(self, vocab, field_name, new_added_field_name):
|
||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
|
||||
self.vocab = vocab
|
||||
|
||||
def process(self, dataset):
|
||||
for ins in dataset:
|
||||
new_sent = [self.vocab.to_word(w) for w in ins[self.field_name]]
|
||||
ins[self.new_added_field_name] = new_sent
|
||||
return dataset
|
||||
|
||||
|
@ -13,3 +13,6 @@ class BaseModel(torch.nn.Module):
|
||||
def fit(self, train_data, dev_data=None, **train_args):
|
||||
trainer = Trainer(**train_args)
|
||||
trainer.train(self, train_data, dev_data)
|
||||
|
||||
def predict(self):
|
||||
pass
|
||||
|
@ -9,6 +9,7 @@ from torch.nn import functional as F
|
||||
from fastNLP.modules.utils import initial_parameter
|
||||
from fastNLP.modules.encoder.variational_rnn import VarLSTM
|
||||
from fastNLP.modules.dropout import TimestepDropout
|
||||
from fastNLP.models.base_model import BaseModel
|
||||
|
||||
def mst(scores):
|
||||
"""
|
||||
@ -113,7 +114,7 @@ def _find_cycle(vertices, edges):
|
||||
return [SCC for SCC in _SCCs if len(SCC) > 1]
|
||||
|
||||
|
||||
class GraphParser(nn.Module):
|
||||
class GraphParser(BaseModel):
|
||||
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding
|
||||
"""
|
||||
def __init__(self):
|
||||
@ -370,4 +371,20 @@ class BiaffineParser(GraphParser):
|
||||
label_nll = -(label_loss*float_mask).mean()
|
||||
return arc_nll + label_nll
|
||||
|
||||
def predict(self, word_seq, pos_seq, word_seq_origin_len):
|
||||
"""
|
||||
|
||||
:param word_seq:
|
||||
:param pos_seq:
|
||||
:param word_seq_origin_len:
|
||||
:return: head_pred: [B, L]
|
||||
label_pred: [B, L]
|
||||
seq_len: [B,]
|
||||
"""
|
||||
res = self(word_seq, pos_seq, word_seq_origin_len)
|
||||
output = {}
|
||||
output['head_pred'] = res.pop('head_pred')
|
||||
_, label_pred = res.pop('label_pred').max(2)
|
||||
output['label_pred'] = label_pred
|
||||
output['seq_len'] = word_seq_origin_len
|
||||
return output
|
||||
|
@ -30,11 +30,13 @@ class TestCase1(unittest.TestCase):
|
||||
for text, label in zip(texts, labels):
|
||||
x = TextField(text, is_target=False)
|
||||
y = LabelField(label, is_target=True)
|
||||
ins = Instance(text=x, label=y)
|
||||
ins = Instance(raw_text=x, label=y)
|
||||
data.append(ins)
|
||||
|
||||
# use vocabulary to index data
|
||||
data.index_field("text", vocab)
|
||||
# data.index_field("text", vocab)
|
||||
for ins in data:
|
||||
ins['text'] = [vocab.to_index(w) for w in ins['raw_text']]
|
||||
|
||||
# define naive sampler for batch class
|
||||
class SeqSampler:
|
||||
|
Loading…
Reference in New Issue
Block a user