Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-11-30 03:07:59 +08:00)

commit de856fb8eb (parent eb55856c78)

    update reproduction
```diff
@@ -101,9 +101,12 @@ class EmbedLoader(BaseLoader):
         """
         if vocab is None:
             raise RuntimeError("You must provide a vocabulary.")
-        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim))
+        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim), dtype=np.float32)
         hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
         with open(emb_file, "r", encoding="utf-8") as f:
+            startline = f.readline()
+            if len(startline.split()) > 2:
+                f.seek(0)
             for line in f:
                 word, vector = EmbedLoader.parse_glove_line(line)
                 if word in vocab:
```
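The interesting part of this hunk is the header sniff: word2vec-style text files open with a two-token `<count> <dim>` header, while GloVe files start directly with a vector line, so a first line with more than two tokens means there is no header and the file is rewound. Below is a minimal, self-contained sketch of the same loading loop; the function name `load_pretrained`, the plain split-based parsing, and the dict-style `vocab` lookup are illustrative assumptions, not fastNLP's API.

```python
import numpy as np

def load_pretrained(emb_file, vocab, emb_dim):
    # Zero-initialized matrix, float32, mirroring the diff above.
    embedding_matrix = np.zeros(shape=(len(vocab), emb_dim), dtype=np.float32)
    # hit_flags[i] == 1 once a pretrained vector was found for vocab index i.
    hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
    with open(emb_file, "r", encoding="utf-8") as f:
        startline = f.readline()
        if len(startline.split()) > 2:
            # More than two tokens: this is already a vector line (GloVe style,
            # no "<count> <dim>" header), so rewind to avoid skipping it.
            f.seek(0)
        for line in f:
            tokens = line.rstrip().split()
            word, vector = tokens[0], np.asarray(tokens[1:], dtype=np.float32)
            if word in vocab:
                idx = vocab[word]  # assumes a dict-like word -> index lookup
                embedding_matrix[idx] = vector
                hit_flags[idx] = 1
    return embedding_matrix, hit_flags
```

The `hit_flags` array can then tell the caller which vocabulary rows were never covered by the pretrained file and may need a different initialization.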
```diff
@@ -26,7 +26,7 @@ arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
 dropout = 0.3
-encoder="transformer"
+encoder="var-lstm"
 use_greedy_infer=false
 
 [optim]
```
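For context, this is an INI-style config file. A quick sketch of reading it with only the Python standard library follows; the section name `[model]` is an assumption (the hunk only shows that these keys precede the `[optim]` section), and the values decode cleanly with `json.loads` since they are written as JSON literals (`"var-lstm"`, `false`, `0.3`).

```python
import configparser
import json

# Inline sample mirroring the keys shown in the hunk; the [model] header is hypothetical.
SAMPLE = """
[model]
label_mlp_size = 100
num_label = -1
dropout = 0.3
encoder="var-lstm"
use_greedy_infer=false
"""

parser = configparser.ConfigParser()
parser.read_string(SAMPLE)

# configparser returns raw strings ('"var-lstm"', 'false', '0.3');
# json.loads turns them into str, bool, and float respectively.
encoder = json.loads(parser.get("model", "encoder"))                    # 'var-lstm'
use_greedy_infer = json.loads(parser.get("model", "use_greedy_infer"))  # False
dropout = json.loads(parser.get("model", "dropout"))                    # 0.3
```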
```diff
@@ -10,9 +10,13 @@ from fastNLP.core.trainer import Trainer
 from fastNLP.core.instance import Instance
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.tester import Tester
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.io.model_io import ModelLoader
 from fastNLP.io.embed_loader import EmbedLoader
+from fastNLP.io.model_io import ModelSaver
+from fastNLP.io.dataset_loader import ConllxDataLoader
+from fastNLP.api.processor import *
+from fastNLP.io.embed_loader import EmbedLoader
```
```diff
@@ -156,6 +160,8 @@ print('test len {}'.format(len(test_data)))
 def train(path):
     # test saving pipeline
     save_pipe(path)
+    embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
+    embed = torch.tensor(embed, dtype=torch.float32)
 
     # embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v)
     # embed = torch.tensor(embed, dtype=torch.float32)
```
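Once loaded, a matrix like this is typically copied into the model's embedding layer. A hedged sketch of that final step follows; the shapes and the `nn.Embedding` wiring are illustrative stand-ins, since the diff itself stops at building the tensor.

```python
import numpy as np
import torch
import torch.nn as nn

# Stand-ins for len(word_v) and model_args['word_emb_dim'] in the script.
vocab_size, emb_dim = 10000, 100
# Stand-in for the matrix returned by fast_load_embedding(...).
embed = np.random.rand(vocab_size, emb_dim).astype(np.float32)

# Same conversion as the diff, then copy the weights into an embedding layer.
embed_tensor = torch.tensor(embed, dtype=torch.float32)
embedding = nn.Embedding(vocab_size, emb_dim)
embedding.weight.data.copy_(embed_tensor)
```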