增加对ConfusionMatrix的测试用例

This commit is contained in:
yh_cc 2020-03-24 16:09:00 +08:00
parent b3ace23d11
commit 885c74022c
2 changed files with 27 additions and 11 deletions

View File

@ -34,8 +34,6 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
'varargs'])
class ConfusionMatrix:
"""a dict can provide Confusion Matrix"""
def __init__(self, vocab=None, print_ratio=False):
@ -83,7 +81,7 @@ class ConfusionMatrix:
def clear(self):
"""
除一些值等待再次新加入
空ConfusionMatrix等待再次新加入
:return:
"""
self.confusiondict = {}
@ -102,11 +100,6 @@ class ConfusionMatrix:
set(self.targetcount.keys()).union(set(
self.predcount.keys()))))
lenth = len(totallabel)
# namedict key :idx value:word/idx
namedict = dict([
(k, str(k if self.vocab == None else self.vocab.to_word(k)))
for k in totallabel
])
for label, idx in zip(totallabel, range(lenth)):
idx2row[
@ -116,7 +109,6 @@ class ConfusionMatrix:
output = []
for i in row2idx.keys(): # 第i行
p = row2idx[i]
h = namedict[p]
l = [0 for _ in range(lenth)]
if self.confusiondict.get(p, None):
for t, c in self.confusiondict[p].items():
@ -141,7 +133,7 @@ class ConfusionMatrix:
tmp = tmp * 100
elif dim == 1:
tmp = np.array(result).T
mp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
tmp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
tmp = tmp.T * 100
tmp = np.around(tmp, decimals=2)
return tmp.tolist()
@ -172,7 +164,6 @@ class ConfusionMatrix:
row2idx[
idx] = label # 建立一个临时字典value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
# 这里打印东西
col_lenths = []
out = str()
output = []
# 表头

View File

@ -288,3 +288,28 @@ class TestUtils(unittest.TestCase):
self.assertSequenceEqual(convert_tags, iob2bioes(tags))
class TestConfusionMatrix(unittest.TestCase):
def test1(self):
# 测试能否正常打印
from fastNLP import Vocabulary
from fastNLP.core.utils import ConfusionMatrix
import numpy as np
vocab = Vocabulary(unknown=None, padding=None)
vocab.add_word_lst(list('abcdef'))
confusion_matrix = ConfusionMatrix(vocab)
for _ in range(3):
length = np.random.randint(1, 5)
pred = np.random.randint(0, 3, size=(length,))
target = np.random.randint(0, 3, size=(length,))
confusion_matrix.add_pred_target(pred, target)
print(confusion_matrix)
# 测试print_ratio
confusion_matrix = ConfusionMatrix(vocab, print_ratio=True)
for _ in range(3):
length = np.random.randint(1, 5)
pred = np.random.randint(0, 3, size=(length,))
target = np.random.randint(0, 3, size=(length,))
confusion_matrix.add_pred_target(pred, target)
print(confusion_matrix)