diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 0232f56f..525ac8ca 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -519,11 +519,11 @@ class Vocabulary(object): line = line.strip() if line: name, value = line.split() - if name == 'max_size': - vocab.max_size = int(value) if value!='None' else None - elif name == 'min_freq': - vocab.min_freq = int(value) if value!='None' else None + if name in ('max_size', 'min_freq'): + value = int(value) if value!='None' else None + setattr(vocab, name, value) elif name in ('unknown', 'padding'): + value = value if value!='None' else None setattr(vocab, name, value) elif name == 'rebuild': vocab.rebuild = True if value=='True' else False @@ -535,12 +535,12 @@ class Vocabulary(object): for line in f: line = line.strip() if line: - parts = line.split() + parts = line.split('\t') word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) if idx >= 0: word2idx[word] = idx word_counter[word] = count - if no_create_entry_counter: + if no_create_entry: no_create_entry_counter[word] = count word_counter = Counter(word_counter) diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py index 81a01092..2aa7b26a 100644 --- a/test/core/test_vocabulary.py +++ b/test/core/test_vocabulary.py @@ -214,7 +214,29 @@ class TestOther(unittest.TestCase): for idx in range(len(vocab)): self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx)) self.assertEqual(vocab.unknown, new_vocab.unknown) - except: + + # 测试vocab中包含None的padding和unk + vocab= Vocabulary(padding=None, unknown=None) + words = list('abcdefaddfdkjfe') + no_create_entry = list('12342331') + + vocab.add_word_lst(words) + vocab.add_word_lst(no_create_entry, no_create_entry=True) + vocab.save(fp) + + new_vocab = Vocabulary.load(fp) + + for word, index in vocab: + self.assertEqual(new_vocab.to_index(word), index) + for word in no_create_entry: + self.assertTrue(new_vocab._is_word_no_create_entry(word)) + for word in words: + self.assertFalse(new_vocab._is_word_no_create_entry(word)) + for idx in range(len(vocab)): + self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx)) + self.assertEqual(vocab.unknown, new_vocab.unknown) + + finally: import os if os.path.exists(fp): os.remove(fp)