mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-05 05:38:31 +08:00
对trainer中check code的报错信息进行了增强;将tester中的output修改为pred_dict
This commit is contained in:
parent
d4af19ec1f
commit
77f8ac77da
@ -96,7 +96,7 @@ class MetricBase(object):
|
||||
will be conducted)
|
||||
:param pred_dict: usually the output of forward or prediction function
|
||||
:param target_dict: usually features set as target.
|
||||
:param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`.
|
||||
:param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`.
|
||||
:return:
|
||||
"""
|
||||
if not callable(self.evaluate):
|
||||
@ -148,8 +148,8 @@ class MetricBase(object):
|
||||
missing = check_res.missing
|
||||
replaced_missing = list(missing)
|
||||
for idx, func_arg in enumerate(missing):
|
||||
replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \
|
||||
f"in `{get_func_signature(self.evaluate)}`)"
|
||||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
|
||||
f"in `{self.__class__.__name__}`)"
|
||||
|
||||
check_res = CheckRes(missing=replaced_missing,
|
||||
unused=check_res.unused,
|
||||
|
@ -51,19 +51,18 @@ class Tester(object):
|
||||
# turn on the testing mode; clean up the history
|
||||
network = self._model
|
||||
self._mode(network, is_test=True)
|
||||
output, truths = defaultdict(list), defaultdict(list)
|
||||
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
|
||||
eval_results = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for batch_x, batch_y in data_iterator:
|
||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
|
||||
prediction = self._data_forward(self._predict_func, batch_x)
|
||||
if not isinstance(prediction, dict):
|
||||
pred_dict = self._data_forward(self._predict_func, batch_x)
|
||||
if not isinstance(pred_dict, dict):
|
||||
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
|
||||
f"must be `dict`, got {type(prediction)}.")
|
||||
f"must be `dict`, got {type(pred_dict)}.")
|
||||
for metric in self.metrics:
|
||||
metric(prediction, batch_y)
|
||||
metric(pred_dict, batch_y)
|
||||
for metric in self.metrics:
|
||||
eval_result = metric.get_metric()
|
||||
if not isinstance(eval_result, dict):
|
||||
@ -74,7 +73,8 @@ class Tester(object):
|
||||
except CheckError as e:
|
||||
prev_func_signature = get_func_signature(self._predict_func)
|
||||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
|
||||
check_res=e.check_res, output=output, batch_y=truths, check_level=0)
|
||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
|
||||
dataset=self.data, check_level=0)
|
||||
|
||||
if self.verbose >= 1:
|
||||
print("[tester] \n{}".format(self._format_eval_results(eval_results)))
|
||||
|
@ -311,14 +311,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
|
||||
batch_x=batch_x, check_level=check_level)
|
||||
|
||||
refined_batch_x = _build_args(model.forward, **batch_x)
|
||||
output = model(**refined_batch_x)
|
||||
pred_dict = model(**refined_batch_x)
|
||||
func_signature = get_func_signature(model.forward)
|
||||
if not isinstance(output, dict):
|
||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")
|
||||
if not isinstance(pred_dict, dict):
|
||||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")
|
||||
|
||||
# loss check
|
||||
try:
|
||||
loss = losser(output, batch_y)
|
||||
loss = losser(pred_dict, batch_y)
|
||||
# check loss output
|
||||
if batch_count == 0:
|
||||
if not isinstance(loss, torch.Tensor):
|
||||
@ -333,8 +333,8 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
|
||||
except CheckError as e:
|
||||
pre_func_signature = get_func_signature(model.forward)
|
||||
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
|
||||
check_res=e.check_res, output=output, batch_y=batch_y,
|
||||
check_level=check_level)
|
||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
|
||||
dataset=dataset, check_level=check_level)
|
||||
model.zero_grad()
|
||||
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
|
||||
break
|
||||
|
@ -229,29 +229,72 @@ WARNING_CHECK_LEVEL = 1
|
||||
STRICT_CHECK_LEVEL = 2
|
||||
|
||||
def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
|
||||
output:dict, batch_y:dict, check_level=0):
|
||||
pred_dict:dict, target_dict:dict, dataset, check_level=0):
|
||||
errs = []
|
||||
_unused = []
|
||||
unuseds = []
|
||||
_unused_field = []
|
||||
_unused_param = []
|
||||
suggestions = []
|
||||
if check_res.varargs:
|
||||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
|
||||
f"please delete it.)")
|
||||
if check_res.missing:
|
||||
errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
|
||||
f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
|
||||
if check_res.duplicated:
|
||||
errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
|
||||
f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ")
|
||||
errs.append(f"\tvarargs: *{check_res.varargs}")
|
||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
|
||||
|
||||
if check_res.unused:
|
||||
_unused = [f"\tunused param: {check_res.unused}"]
|
||||
if check_level == STRICT_CHECK_LEVEL:
|
||||
errs.extend(_unused)
|
||||
for _unused in check_res.unused:
|
||||
if _unused in target_dict:
|
||||
_unused_field.append(_unused)
|
||||
else:
|
||||
_unused_param.append(_unused)
|
||||
if _unused_field:
|
||||
unuseds.append([f"\tunused field: {_unused_field}"])
|
||||
if _unused_param:
|
||||
unuseds.append([f"\tunused param: {_unused_param}"])
|
||||
|
||||
if check_res.missing:
|
||||
errs.append(f"\tmissing param: {check_res.missing}")
|
||||
_miss_in_dataset = []
|
||||
_miss_out_dataset = []
|
||||
for _miss in check_res.missing:
|
||||
if '(' in _miss:
|
||||
# if they are like 'SomeParam(assign to xxx)'
|
||||
_miss = _miss.split('(')[0]
|
||||
if _miss in dataset:
|
||||
_miss_in_dataset.append(_miss)
|
||||
else:
|
||||
_miss_out_dataset.append(_miss)
|
||||
|
||||
if _miss_in_dataset:
|
||||
suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now "
|
||||
f"target is {list(target_dict.keys())}).")
|
||||
if _miss_out_dataset:
|
||||
_tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
|
||||
f"target is {list(target_dict.keys())}) or output it "
|
||||
f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
|
||||
if _unused_field:
|
||||
_tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
|
||||
suggestions.append(_tmp)
|
||||
|
||||
if check_res.duplicated:
|
||||
errs.append(f"\tduplicated param: {check_res.duplicated}.")
|
||||
suggestions.append(f"Delete {check_res.duplicated} in the output of "
|
||||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
|
||||
|
||||
if check_level == STRICT_CHECK_LEVEL:
|
||||
errs.extend(unuseds)
|
||||
|
||||
if len(errs)>0:
|
||||
errs.insert(0, f'The following problems occurred when calling {func_signature}')
|
||||
raise NameError('\n'.join(errs))
|
||||
if _unused:
|
||||
sugg_str = ""
|
||||
if len(suggestions)>1:
|
||||
for idx, sugg in enumerate(suggestions):
|
||||
sugg_str += f'({idx+1}). {sugg}'
|
||||
else:
|
||||
sugg_str += suggestions[0]
|
||||
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
|
||||
raise NameError(err_str)
|
||||
if check_res.unused:
|
||||
if check_level == WARNING_CHECK_LEVEL:
|
||||
_unused_warn = _unused[0] + f' in {func_signature}.'
|
||||
_unused_warn = f'{check_res.unused} is not used by {func_signature}.'
|
||||
warnings.warn(message=_unused_warn)
|
||||
|
||||
|
||||
@ -260,21 +303,45 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
|
||||
func_signature = get_func_signature(forward_func)
|
||||
|
||||
errs = []
|
||||
suggestions = []
|
||||
_unused = []
|
||||
|
||||
if check_res.varargs:
|
||||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
|
||||
errs.append(f"\tvarargs: {check_res.varargs}")
|
||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
|
||||
if check_res.missing:
|
||||
errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}. "
|
||||
f"Please set {check_res.missing} as input.")
|
||||
errs.append(f"\tmissing param: {check_res.missing}")
|
||||
_miss_in_dataset = []
|
||||
_miss_out_dataset = []
|
||||
for _miss in check_res.missing:
|
||||
if _miss in dataset:
|
||||
_miss_in_dataset.append(_miss)
|
||||
else:
|
||||
_miss_out_dataset.append(_miss)
|
||||
if _miss_in_dataset:
|
||||
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
|
||||
if _miss_out_dataset:
|
||||
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
|
||||
if check_res.unused:
|
||||
_tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \
|
||||
f"rename the field in `unused field:`."
|
||||
suggestions.append(_tmp)
|
||||
|
||||
if check_res.unused:
|
||||
_unused = [f"\tunused param: {check_res.unused}"]
|
||||
_unused = [f"\tunused field: {check_res.unused}"]
|
||||
if check_level == STRICT_CHECK_LEVEL:
|
||||
errs.extend(_unused)
|
||||
|
||||
if len(errs)>0:
|
||||
errs.insert(0, f'The following problems occurred when calling {func_signature}')
|
||||
raise NameError('\n'.join(errs))
|
||||
sugg_str = ""
|
||||
if len(suggestions)>1:
|
||||
for idx, sugg in enumerate(suggestions):
|
||||
sugg_str += f'({idx+1}). {sugg}'
|
||||
else:
|
||||
sugg_str += suggestions[0]
|
||||
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
|
||||
raise NameError(err_str)
|
||||
if _unused:
|
||||
if check_level == WARNING_CHECK_LEVEL:
|
||||
_unused_warn = _unused[0] + f' in {func_signature}.'
|
||||
|
Loading…
Reference in New Issue
Block a user