mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
增加对ConfusionMatrix的测试用例
This commit is contained in:
parent
b3ace23d11
commit
885c74022c
@ -34,8 +34,6 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
|
||||
'varargs'])
|
||||
|
||||
|
||||
|
||||
|
||||
class ConfusionMatrix:
|
||||
"""a dict can provide Confusion Matrix"""
|
||||
def __init__(self, vocab=None, print_ratio=False):
|
||||
@ -83,7 +81,7 @@ class ConfusionMatrix:
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清除一些值,等待再次新加入
|
||||
清空ConfusionMatrix,等待再次新加入
|
||||
:return:
|
||||
"""
|
||||
self.confusiondict = {}
|
||||
@ -102,11 +100,6 @@ class ConfusionMatrix:
|
||||
set(self.targetcount.keys()).union(set(
|
||||
self.predcount.keys()))))
|
||||
lenth = len(totallabel)
|
||||
# namedict key :idx value:word/idx
|
||||
namedict = dict([
|
||||
(k, str(k if self.vocab == None else self.vocab.to_word(k)))
|
||||
for k in totallabel
|
||||
])
|
||||
|
||||
for label, idx in zip(totallabel, range(lenth)):
|
||||
idx2row[
|
||||
@ -116,7 +109,6 @@ class ConfusionMatrix:
|
||||
output = []
|
||||
for i in row2idx.keys(): # 第i行
|
||||
p = row2idx[i]
|
||||
h = namedict[p]
|
||||
l = [0 for _ in range(lenth)]
|
||||
if self.confusiondict.get(p, None):
|
||||
for t, c in self.confusiondict[p].items():
|
||||
@ -141,7 +133,7 @@ class ConfusionMatrix:
|
||||
tmp = tmp * 100
|
||||
elif dim == 1:
|
||||
tmp = np.array(result).T
|
||||
mp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
|
||||
tmp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
|
||||
tmp = tmp.T * 100
|
||||
tmp = np.around(tmp, decimals=2)
|
||||
return tmp.tolist()
|
||||
@ -172,7 +164,6 @@ class ConfusionMatrix:
|
||||
row2idx[
|
||||
idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
|
||||
# 这里打印东西
|
||||
col_lenths = []
|
||||
out = str()
|
||||
output = []
|
||||
# 表头
|
||||
|
@ -288,3 +288,28 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
self.assertSequenceEqual(convert_tags, iob2bioes(tags))
|
||||
|
||||
class TestConfusionMatrix(unittest.TestCase):
|
||||
def test1(self):
|
||||
# 测试能否正常打印
|
||||
from fastNLP import Vocabulary
|
||||
from fastNLP.core.utils import ConfusionMatrix
|
||||
import numpy as np
|
||||
vocab = Vocabulary(unknown=None, padding=None)
|
||||
vocab.add_word_lst(list('abcdef'))
|
||||
confusion_matrix = ConfusionMatrix(vocab)
|
||||
for _ in range(3):
|
||||
length = np.random.randint(1, 5)
|
||||
pred = np.random.randint(0, 3, size=(length,))
|
||||
target = np.random.randint(0, 3, size=(length,))
|
||||
confusion_matrix.add_pred_target(pred, target)
|
||||
print(confusion_matrix)
|
||||
|
||||
# 测试print_ratio
|
||||
confusion_matrix = ConfusionMatrix(vocab, print_ratio=True)
|
||||
for _ in range(3):
|
||||
length = np.random.randint(1, 5)
|
||||
pred = np.random.randint(0, 3, size=(length,))
|
||||
target = np.random.randint(0, 3, size=(length,))
|
||||
confusion_matrix.add_pred_target(pred, target)
|
||||
print(confusion_matrix)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user