mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
Merge remote-tracking branch 'origin/dev0.5.0' into batch
This commit is contained in:
commit
efe3574014
@ -30,7 +30,7 @@ class BertConfig:
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate = intermediate_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
|
@ -157,7 +157,9 @@ class StaticEmbedding(TokenEmbedding):
|
||||
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
|
||||
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
|
||||
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
|
||||
'cn': "tencent_cn-dab24577.tar.gz"
|
||||
'en-fasttext': "cc.en.300.vec-d53187b2.gz",
|
||||
'cn': "tencent_cn-dab24577.tar.gz",
|
||||
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz",
|
||||
}
|
||||
|
||||
# 得到cache_path
|
||||
|
@ -100,13 +100,14 @@ class TestIndexing(unittest.TestCase):
|
||||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
|
||||
|
||||
def test_iteration(self):
|
||||
vocab = Vocabulary()
|
||||
vocab = Vocabulary(padding=None, unknown=None)
|
||||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
|
||||
"works", "well", "in", "most", "cases", "scales", "well"]
|
||||
vocab.update(text)
|
||||
text = set(text)
|
||||
for word in vocab:
|
||||
for word, idx in vocab:
|
||||
self.assertTrue(word in text)
|
||||
self.assertTrue(idx < len(vocab))
|
||||
|
||||
|
||||
class TestOther(unittest.TestCase):
|
||||
|
@ -12,7 +12,6 @@ class TestCNNText(unittest.TestCase):
|
||||
model = CNNText(init_emb,
|
||||
NUM_CLS,
|
||||
kernel_nums=(1, 3, 5),
|
||||
kernel_sizes=(2, 2, 2),
|
||||
padding=0,
|
||||
kernel_sizes=(1, 3, 5),
|
||||
dropout=0.5)
|
||||
RUNNER.run_model_with_task(TEXT_CLS, model)
|
||||
|
@ -70,7 +70,7 @@ class TestTutorial(unittest.TestCase):
|
||||
break
|
||||
|
||||
from fastNLP.models import CNNText
|
||||
model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1)
|
||||
model = CNNText((len(vocab), 50), num_classes=5, dropout=0.1)
|
||||
|
||||
from fastNLP import Trainer
|
||||
from copy import deepcopy
|
||||
@ -143,7 +143,7 @@ class TestTutorial(unittest.TestCase):
|
||||
is_input=True)
|
||||
|
||||
from fastNLP.models import CNNText
|
||||
model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1)
|
||||
model = CNNText((len(vocab), 50), num_classes=5, dropout=0.1)
|
||||
|
||||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user