mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-02 04:07:35 +08:00)
update reproduction
commit de856fb8eb (parent eb55856c78)
@@ -101,9 +101,12 @@ class EmbedLoader(BaseLoader):
         """
         if vocab is None:
             raise RuntimeError("You must provide a vocabulary.")
-        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim))
+        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim), dtype=np.float32)
         hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
         with open(emb_file, "r", encoding="utf-8") as f:
+            startline = f.readline()
+            if len(startline.split()) > 2:
+                f.seek(0)
             for line in f:
                 word, vector = EmbedLoader.parse_glove_line(line)
                 if word in vocab:
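The three inserted lines let fast_load_embedding accept word2vec-style text files, whose first line is a "vocab_size dim" header, as well as GloVe-style files, which begin directly with a vector line; peeking at the first line's field count distinguishes the two. A minimal standalone sketch of the same idea, assuming a whitespace-separated text file and taking vocab to be a plain word-to-index dict (fastNLP's Vocabulary and parse_glove_line are not reproduced here):

import numpy as np

def load_text_embeddings(path, vocab, dim):
    """Fill a (len(vocab), dim) float32 matrix, skipping a word2vec header."""
    matrix = np.zeros((len(vocab), dim), dtype=np.float32)
    with open(path, "r", encoding="utf-8") as f:
        first = f.readline()
        # A vector line is "word v1 v2 ... v_dim" (more than 2 fields);
        # a word2vec header is "vocab_size dim" (exactly 2 fields).
        if len(first.split()) > 2:
            f.seek(0)  # no header: rewind so the first vector is not lost
        for line in f:
            parts = line.rstrip().split()
            word, values = parts[0], parts[1:]
            if word in vocab:
                matrix[vocab[word]] = np.asarray(values, dtype=np.float32)
    return matrix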
@@ -26,7 +26,7 @@ arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
 dropout = 0.3
-encoder="transformer"
+encoder="var-lstm"
 use_greedy_infer=false
 
 [optim]
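This config hunk switches the biaffine parser's encoder from a transformer to a variational LSTM. The run script below reads this file through fastNLP's own ConfigLoader/ConfigSection, whose call signature is not shown in this diff; as a hedged illustration only, the same key=value section could be read with the standard library, assuming the keys sit under a section named [model] (the section header lies outside this hunk) and a hypothetical file name:

from configparser import ConfigParser

cfg = ConfigParser()
cfg.read("parser.cfg")  # hypothetical path; the real one is not shown
# Values in this file are quoted (encoder="var-lstm"), so strip the quotes.
encoder = cfg.get("model", "encoder").strip('"')          # -> "var-lstm"
dropout = cfg.getfloat("model", "dropout")                # -> 0.3
use_greedy = cfg.getboolean("model", "use_greedy_infer")  # -> False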
@@ -10,9 +10,13 @@ from fastNLP.core.trainer import Trainer
 from fastNLP.core.instance import Instance
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.dataset import DataSet
 from fastNLP.core.tester import Tester
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.io.model_io import ModelLoader
+from fastNLP.io.embed_loader import EmbedLoader
+from fastNLP.io.model_io import ModelSaver
 from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import *
 from fastNLP.io.embed_loader import EmbedLoader
@@ -156,6 +160,8 @@ print('test len {}'.format(len(test_data)))
 def train(path):
     # test saving pipeline
     save_pipe(path)
+    embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
+    embed = torch.tensor(embed, dtype=torch.float32)
 
     # embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v)
     # embed = torch.tensor(embed, dtype=torch.float32)
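The uncommented pair loads pre-trained vectors for the word vocabulary and converts the returned numpy matrix to a float32 tensor. How that tensor then reaches the parser's embedding layer is outside this hunk; a minimal sketch of the usual PyTorch pattern, with the shapes here being assumptions for illustration:

import torch
import torch.nn as nn

# Stand-in for the matrix returned by fast_load_embedding: one row per
# vocabulary word; 100 is an assumed word_emb_dim.
embed = torch.randn(10000, 100, dtype=torch.float32)

# Build the layer from the pre-trained weights; freeze=False keeps the
# vectors trainable during parser training.
embedding_layer = nn.Embedding.from_pretrained(embed, freeze=False)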