Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-04 13:17:51 +08:00)
[new] add seperator for conll loader (#293)
* add ConfusionMatrix, ConfusionMatrixMetric
* add confusionmatrix to utils
* add ConfusionMatrixmetric
* add ConfusionMatrixMetric
* init for test
* begin test
* test finish
* doc finish
* revised confusion
* revised two
* revise two
* add sep for conll loader
* with remote
* withdraw some update
* finish merge
* update test
* update test
* to avoid none situation
parent ae7b916355
commit 4e95989e97
@@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase):
 print_ratio=False
 ):
 r"""
+
 :param vocab: vocab词表类,要求有to_word()方法。
 :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
 :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
@@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase):
 def evaluate(self, pred, target, seq_len=None):
 r"""
 evaluate函数将针对一个批次的预测结果做评价指标的累计
-
 :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
 torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
 :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
@@ -62,6 +62,7 @@ class ConfusionMatrix:
 target = [2,2,1]
 confusion.add_pred_target(pred, target)
 print(confusion)
+
 target 1 2 3 all
 pred
 1 0 1 0 1
@@ -157,7 +158,6 @@ class ConfusionMatrix:
 (k, str(k if self.vocab == None else self.vocab.to_word(k)))
 for k in totallabel
 ])
-
 for label, idx in zip(totallabel, range(lenth)):
 idx2row[
 label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
@@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True):
 yield line_idx, _res


-def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
+def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
 r"""
 Construct a generator to read conll items.

 :param path: file path
 :param encoding: file's encoding, default: utf-8
+:param sep: seperator
 :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
 :param dropna: weather to ignore and drop invalid data,
 :if False, raise ValueError when reading invalid data. default: True
@@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
 sample = []
 start = next(f).strip()
 if start != '':
-sample.append(start.split())
+sample.append(start.split(sep)) if sep else sample.append(start.split())
 for line_idx, line in enumerate(f, 1):
 line = line.strip()
 if line == '':
@@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
 elif line.startswith('#'):
 continue
 else:
-sample.append(line.split())
+sample.append(line.split(sep)) if sep else sample.append(line.split())
 if len(sample) > 0:
 try:
 res = parse_conll(sample)
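The fallback added above keeps the old behaviour when no separator is given: with sep=None the line is split on runs of whitespace, while an explicit separator is handed straight to str.split(sep), which also preserves empty fields. A minimal plain-Python sketch of that branching (illustration only; the helper name split_conll_line is made up, not code from this commit):

def split_conll_line(line, sep=None):
    # Mirrors the change in _read_conll: explicit sep -> str.split(sep),
    # otherwise fall back to whitespace splitting.
    return line.split(sep) if sep else line.split()

print(split_conll_line("word1\tB-ORG"))        # ['word1', 'B-ORG']  (any whitespace)
print(split_conll_line("word1\tB-ORG", "\t"))  # ['word1', 'B-ORG']  (tabs only)
print(split_conll_line("a  b", " "))           # ['a', '', 'b']  (explicit sep keeps empty fields)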
@@ -55,10 +55,11 @@ class ConllLoader(Loader):

 """

-def __init__(self, headers, indexes=None, dropna=True):
+def __init__(self, headers, sep=None, indexes=None, dropna=True):
 r"""

 :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
+:param list sep: 指定分隔符,默认为制表符
 :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
 :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
 """
@@ -68,6 +69,7 @@ class ConllLoader(Loader):
 'invalid headers: {}, should be list of strings'.format(headers))
 self.headers = headers
 self.dropna = dropna
+self.sep=sep
 if indexes is None:
 self.indexes = list(range(len(self.headers)))
 else:
@@ -83,7 +85,7 @@ class ConllLoader(Loader):
 :return: DataSet
 """
 ds = DataSet()
-for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
+for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna):
 ins = {h: data[i] for i, h in enumerate(self.headers)}
 ds.append(Instance(**ins))
 return ds
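A hedged usage sketch of the extended loader; the headers and path below are placeholders rather than anything shipped with fastNLP, and leaving sep=None (the default) keeps the previous whitespace splitting:

from fastNLP.io.loader.conll import ConllLoader

# Placeholder headers and path for a two-column, tab-separated CoNLL-style corpus.
loader = ConllLoader(headers=['raw_words', 'target'], sep='\t')
data_bundle = loader.load('path/to/my_conll_data')  # returns a DataBundle; the path is illustrative
print(data_bundle)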
@@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result):
 return allen_result


+
 class TestConfusionMatrixMetric(unittest.TestCase):
 def test_ConfusionMatrixMetric1(self):
 pred_dict = {"pred": torch.zeros(4,3)}
@@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):

 def test_ConfusionMatrixMetric2(self):
 # (2) with corrupted size
+
 with self.assertRaises(Exception):
 pred_dict = {"pred": torch.zeros(4, 3, 2)}
 target_dict = {'target': torch.zeros(4)}
@@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase):

 print(metric.get_metric())

-
 def test_ConfusionMatrixMetric4(self):
 # (4) check reset
 metric = ConfusionMatrixMetric()
@@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):

 def test_ConfusionMatrixMetric5(self):
 # (5) check numpy array is not acceptable
+
 with self.assertRaises(Exception):
 metric = ConfusionMatrixMetric()
 pred_dict = {"pred": np.zeros((4, 3, 2))}
@@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
 metric(pred_dict=pred_dict, target_dict=target_dict)
 print(metric.get_metric())

+
 def test_duplicate(self):
 # 0.4.1的潜在bug,不能出现形参重复的情况
 metric = ConfusionMatrixMetric(pred='predictions', target='targets')
@@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
 metric(pred_dict=pred_dict, target_dict=target_dict)
 print(metric.get_metric())

+
 def test_seq_len(self):
 N = 256
 seq_len = torch.zeros(N).long()
@@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
 print(metric.get_metric())


+
 class TestAccuracyMetric(unittest.TestCase):
 def test_AccuracyMetric1(self):
 # (1) only input, targets passed
@@ -2,7 +2,7 @@
 import unittest
 import os
 from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \
-Conll2003Loader
+Conll2003Loader, ConllLoader


 class TestMSRANER(unittest.TestCase):
@@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase):
 db = Conll2003Loader().load('test/data_for_tests/io/conll2003')
 print(db)

+class TestConllLoader(unittest.TestCase):
+    def test_sep(self):
+        headers = [
+            'raw_words', 'ner',
+        ]
+        db = ConllLoader(headers = headers,sep="\n").load('test/data_for_tests/io/MSRA_NER')
+        print(db)
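A side note on the sep="\n" used in test_sep above: _read_conll strips each line before splitting, so splitting on "\n" simply returns the whole stripped line as a single field. A plain-Python illustration (not part of the commit):

line = "word1\tB-ORG".strip()
print(line.split("\n"))  # ['word1\tB-ORG']    -- the whole line as one field
print(line.split("\t"))  # ['word1', 'B-ORG']  -- one field per tab-separated column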