update reproduction

This commit is contained in:
yunfan 2019-01-15 15:33:39 +08:00
parent eb55856c78
commit de856fb8eb
3 changed files with 11 additions and 2 deletions

View File

@@ -101,9 +101,12 @@ class EmbedLoader(BaseLoader):
""" """
if vocab is None: if vocab is None:
raise RuntimeError("You must provide a vocabulary.") raise RuntimeError("You must provide a vocabulary.")
embedding_matrix = np.zeros(shape=(len(vocab), emb_dim)) embedding_matrix = np.zeros(shape=(len(vocab), emb_dim), dtype=np.float32)
hit_flags = np.zeros(shape=(len(vocab),), dtype=int) hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
with open(emb_file, "r", encoding="utf-8") as f: with open(emb_file, "r", encoding="utf-8") as f:
startline = f.readline()
if len(startline.split()) > 2:
f.seek(0)
for line in f: for line in f:
word, vector = EmbedLoader.parse_glove_line(line) word, vector = EmbedLoader.parse_glove_line(line)
if word in vocab: if word in vocab:

View File

@@ -26,7 +26,7 @@ arc_mlp_size = 500
label_mlp_size = 100 label_mlp_size = 100
num_label = -1 num_label = -1
dropout = 0.3 dropout = 0.3
encoder="transformer" encoder="var-lstm"
use_greedy_infer=false use_greedy_infer=false
[optim] [optim]

View File

@@ -10,9 +10,13 @@ from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline from fastNLP.api.pipeline import Pipeline
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester
from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader from fastNLP.io.model_io import ModelLoader
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.io.model_io import ModelSaver
from fastNLP.io.dataset_loader import ConllxDataLoader from fastNLP.io.dataset_loader import ConllxDataLoader
from fastNLP.api.processor import * from fastNLP.api.processor import *
from fastNLP.io.embed_loader import EmbedLoader from fastNLP.io.embed_loader import EmbedLoader
@@ -156,6 +160,8 @@ print('test len {}'.format(len(test_data)))
def train(path): def train(path):
# test saving pipeline # test saving pipeline
save_pipe(path) save_pipe(path)
embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
embed = torch.tensor(embed, dtype=torch.float32)
# embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v) # embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v)
# embed = torch.tensor(embed, dtype=torch.float32) # embed = torch.tensor(embed, dtype=torch.float32)