mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
修复Vocabulary在load的时候可能发生的bug
This commit is contained in:
parent
2dee67129a
commit
18747e632e
@ -519,11 +519,11 @@ class Vocabulary(object):
|
||||
line = line.strip()
|
||||
if line:
|
||||
name, value = line.split()
|
||||
if name == 'max_size':
|
||||
vocab.max_size = int(value) if value!='None' else None
|
||||
elif name == 'min_freq':
|
||||
vocab.min_freq = int(value) if value!='None' else None
|
||||
if name in ('max_size', 'min_freq'):
|
||||
value = int(value) if value!='None' else None
|
||||
setattr(vocab, name, value)
|
||||
elif name in ('unknown', 'padding'):
|
||||
value = value if value!='None' else None
|
||||
setattr(vocab, name, value)
|
||||
elif name == 'rebuild':
|
||||
vocab.rebuild = True if value=='True' else False
|
||||
@ -535,12 +535,12 @@ class Vocabulary(object):
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
parts = line.split()
|
||||
parts = line.split('\t')
|
||||
word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3])
|
||||
if idx >= 0:
|
||||
word2idx[word] = idx
|
||||
word_counter[word] = count
|
||||
if no_create_entry_counter:
|
||||
if no_create_entry:
|
||||
no_create_entry_counter[word] = count
|
||||
|
||||
word_counter = Counter(word_counter)
|
||||
|
@ -214,7 +214,29 @@ class TestOther(unittest.TestCase):
|
||||
for idx in range(len(vocab)):
|
||||
self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
|
||||
self.assertEqual(vocab.unknown, new_vocab.unknown)
|
||||
except:
|
||||
|
||||
# 测试vocab中包含None的padding和unk
|
||||
vocab= Vocabulary(padding=None, unknown=None)
|
||||
words = list('abcdefaddfdkjfe')
|
||||
no_create_entry = list('12342331')
|
||||
|
||||
vocab.add_word_lst(words)
|
||||
vocab.add_word_lst(no_create_entry, no_create_entry=True)
|
||||
vocab.save(fp)
|
||||
|
||||
new_vocab = Vocabulary.load(fp)
|
||||
|
||||
for word, index in vocab:
|
||||
self.assertEqual(new_vocab.to_index(word), index)
|
||||
for word in no_create_entry:
|
||||
self.assertTrue(new_vocab._is_word_no_create_entry(word))
|
||||
for word in words:
|
||||
self.assertFalse(new_vocab._is_word_no_create_entry(word))
|
||||
for idx in range(len(vocab)):
|
||||
self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
|
||||
self.assertEqual(vocab.unknown, new_vocab.unknown)
|
||||
|
||||
finally:
|
||||
import os
|
||||
if os.path.exists(fp):
|
||||
os.remove(fp)
|
||||
|
Loading…
Reference in New Issue
Block a user