修复Vocabulary在load的时候可能发生的bug

This commit is contained in:
yh_cc 2020-04-13 17:03:53 +08:00
parent 2dee67129a
commit 18747e632e
2 changed files with 29 additions and 7 deletions

View File

@ -519,11 +519,11 @@ class Vocabulary(object):
line = line.strip()
if line:
name, value = line.split()
if name == 'max_size':
vocab.max_size = int(value) if value!='None' else None
elif name == 'min_freq':
vocab.min_freq = int(value) if value!='None' else None
if name in ('max_size', 'min_freq'):
value = int(value) if value!='None' else None
setattr(vocab, name, value)
elif name in ('unknown', 'padding'):
value = value if value!='None' else None
setattr(vocab, name, value)
elif name == 'rebuild':
vocab.rebuild = True if value=='True' else False
@ -535,12 +535,12 @@ class Vocabulary(object):
for line in f:
line = line.strip()
if line:
parts = line.split()
parts = line.split('\t')
word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3])
if idx >= 0:
word2idx[word] = idx
word_counter[word] = count
if no_create_entry_counter:
if no_create_entry:
no_create_entry_counter[word] = count
word_counter = Counter(word_counter)

View File

@ -214,7 +214,29 @@ class TestOther(unittest.TestCase):
for idx in range(len(vocab)):
self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
self.assertEqual(vocab.unknown, new_vocab.unknown)
except:
# 测试vocab中包含None的padding和unk
vocab= Vocabulary(padding=None, unknown=None)
words = list('abcdefaddfdkjfe')
no_create_entry = list('12342331')
vocab.add_word_lst(words)
vocab.add_word_lst(no_create_entry, no_create_entry=True)
vocab.save(fp)
new_vocab = Vocabulary.load(fp)
for word, index in vocab:
self.assertEqual(new_vocab.to_index(word), index)
for word in no_create_entry:
self.assertTrue(new_vocab._is_word_no_create_entry(word))
for word in words:
self.assertFalse(new_vocab._is_word_no_create_entry(word))
for idx in range(len(vocab)):
self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
self.assertEqual(vocab.unknown, new_vocab.unknown)
finally:
import os
if os.path.exists(fp):
os.remove(fp)