mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 12:17:35 +08:00
Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0
This commit is contained in:
commit
8142bad87a
@ -23,6 +23,7 @@ from .utils import _get_func_signature
|
||||
from .utils import seq_len_to_mask
|
||||
from .vocabulary import Vocabulary
|
||||
from abc import abstractmethod
|
||||
import warnings
|
||||
|
||||
|
||||
class MetricBase(object):
|
||||
@ -492,6 +493,30 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
|
||||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
|
||||
|
||||
|
||||
def _check_tag_vocab_and_encoding_type(vocab:Vocabulary, encoding_type:str):
|
||||
"""
|
||||
检查vocab中的tag是否与encoding_type是匹配的
|
||||
|
||||
:param vocab: target的Vocabulary
|
||||
:param encoding_type: bio, bmes, bioes, bmeso
|
||||
:return:
|
||||
"""
|
||||
tag_set = set()
|
||||
for tag, idx in vocab:
|
||||
if idx in (vocab.unknown_idx, vocab.padding_idx):
|
||||
continue
|
||||
tag = tag[:1].lower()
|
||||
tag_set.add(tag)
|
||||
tags = encoding_type
|
||||
for tag in tag_set:
|
||||
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \
|
||||
f"encoding_type."
|
||||
tags = tags.replace(tag, '') # 删除该值
|
||||
if tags: # 如果不为空,说明出现了未使用的tag
|
||||
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
|
||||
"encoding_type.")
|
||||
|
||||
|
||||
class SpanFPreRecMetric(MetricBase):
|
||||
r"""
|
||||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric`
|
||||
@ -546,6 +571,7 @@ class SpanFPreRecMetric(MetricBase):
|
||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
|
||||
|
||||
self.encoding_type = encoding_type
|
||||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type)
|
||||
if self.encoding_type == 'bmes':
|
||||
self.tag_to_span_func = _bmes_tag_to_spans
|
||||
elif self.encoding_type == 'bio':
|
||||
|
@ -338,6 +338,41 @@ class SpanF1PreRecMetric(unittest.TestCase):
|
||||
for key, value in expected_metric.items():
|
||||
self.assertAlmostEqual(value, metric_value[key], places=5)
|
||||
|
||||
def test_encoding_type(self):
|
||||
# 检查传入的tag_vocab与encoding_type不符合时,是否会报错
|
||||
vocabs = {}
|
||||
import random
|
||||
from itertools import product
|
||||
for encoding_type in ['bio', 'bioes', 'bmeso']:
|
||||
vocab = Vocabulary(unknown=None, padding=None)
|
||||
for i in range(random.randint(10, 100)):
|
||||
label = str(random.randint(1, 10))
|
||||
for tag in encoding_type:
|
||||
if tag!='o':
|
||||
vocab.add_word(f'{tag}-{label}')
|
||||
else:
|
||||
vocab.add_word('o')
|
||||
vocabs[encoding_type] = vocab
|
||||
for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']):
|
||||
with self.subTest(e1=e1, e2=e2):
|
||||
if e1==e2:
|
||||
metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2)
|
||||
else:
|
||||
s2 = set(e2)
|
||||
s2.update(set(e1))
|
||||
if s2==set(e2):
|
||||
continue
|
||||
with self.assertRaises(AssertionError):
|
||||
metric = SpanFPreRecMetric(vocabs[e1], encoding_type=e2)
|
||||
for encoding_type in ['bio', 'bioes', 'bmeso']:
|
||||
with self.assertRaises(AssertionError):
|
||||
metric = SpanFPreRecMetric(vocabs[encoding_type], encoding_type='bmes')
|
||||
|
||||
with self.assertWarns(Warning):
|
||||
vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes'))
|
||||
metric = SpanFPreRecMetric(vocab, encoding_type='bmeso')
|
||||
vocab = Vocabulary().add_word_lst(list('bmes'))
|
||||
metric = SpanFPreRecMetric(vocab, encoding_type='bmeso')
|
||||
|
||||
class TestUsefulFunctions(unittest.TestCase):
|
||||
# 测试metrics.py中一些看上去挺有用的函数
|
||||
|
Loading…
Reference in New Issue
Block a user