对trainer中check code的报错信息进行了增强;将tester中的output修改为pred_dict

This commit is contained in:
yh 2018-12-03 12:12:48 +08:00
parent d4af19ec1f
commit 77f8ac77da
4 changed files with 103 additions and 36 deletions

View File

@ -96,7 +96,7 @@ class MetricBase(object):
will be conducted)
:param pred_dict: usually the output of forward or prediction function
:param target_dict: usually features set as target..
:param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`.
:param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`.
:return:
"""
if not callable(self.evaluate):
@ -148,8 +148,8 @@ class MetricBase(object):
missing = check_res.missing
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \
f"in `{get_func_signature(self.evaluate)}`)"
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"
check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused,

View File

@ -51,19 +51,18 @@ class Tester(object):
# turn on the testing mode; clean up the history
network = self._model
self._mode(network, is_test=True)
output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
eval_results = {}
try:
with torch.no_grad():
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self._predict_func, batch_x)
if not isinstance(prediction, dict):
pred_dict = self._data_forward(self._predict_func, batch_x)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
f"must be `dict`, got {type(prediction)}.")
f"must be `dict`, got {type(pred_dict)}.")
for metric in self.metrics:
metric(prediction, batch_y)
metric(pred_dict, batch_y)
for metric in self.metrics:
eval_result = metric.get_metric()
if not isinstance(eval_result, dict):
@ -74,7 +73,8 @@ class Tester(object):
except CheckError as e:
prev_func_signature = get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=truths, check_level=0)
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0)
if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results)))

View File

@ -311,14 +311,14 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
batch_x=batch_x, check_level=check_level)
refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
pred_dict = model(**refined_batch_x)
func_signature = get_func_signature(model.forward)
if not isinstance(output, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.")
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")
# loss check
try:
loss = losser(output, batch_y)
loss = losser(pred_dict, batch_y)
# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
@ -333,8 +333,8 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
except CheckError as e:
pre_func_signature = get_func_signature(model.forward)
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=batch_y,
check_level=check_level)
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=dataset, check_level=check_level)
model.zero_grad()
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
break

View File

@ -229,29 +229,72 @@ WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2
def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
                         pred_dict:dict, target_dict:dict, dataset, check_level=0):
    """
    Turn the argument-check result of a loss/metric call into a readable error or warning.

    :param prev_func_signature: signature string of the function that produced ``pred_dict``
    :param func_signature: signature string of the function whose arguments were checked
    :param check_res: check result exposing ``varargs``, ``missing``, ``unused`` and ``duplicated``
    :param pred_dict: dict output by the previous function; its keys are listed in suggestions
    :param target_dict: dict of target fields; used to tell unused fields from unused params
    :param dataset: container tested with ``in`` to decide whether a missing name already exists in it
        (presumably a DataSet supporting field-name membership -- confirm against the caller)
    :param check_level: ``STRICT_CHECK_LEVEL`` turns unused entries into errors,
        ``WARNING_CHECK_LEVEL`` only warns about them, any other value ignores them
    :raises NameError: if any problem was collected
    """
    errs = []
    unuseds = []
    _unused_field = []
    _unused_param = []
    suggestions = []

    if check_res.varargs:
        errs.append(f"\tvarargs: *{check_res.varargs}")
        suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")

    if check_res.unused:
        # Names present in the targets are reported as fields, everything else as params.
        for _name in check_res.unused:
            if _name in target_dict:
                _unused_field.append(_name)
            else:
                _unused_param.append(_name)
        # Store plain strings: these entries are later merged into `errs` and passed to
        # str.join, which rejects nested lists with a TypeError.
        if _unused_field:
            unuseds.append(f"\tunused field: {_unused_field}")
        if _unused_param:
            unuseds.append(f"\tunused param: {_unused_param}")

    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}")
        _miss_in_dataset = []
        _miss_out_dataset = []
        for _miss in check_res.missing:
            if '(' in _miss:
                # if they are like 'SomeParam(assign to xxx)'
                _miss = _miss.split('(')[0]
            if _miss in dataset:
                _miss_in_dataset.append(_miss)
            else:
                _miss_out_dataset.append(_miss)
        if _miss_in_dataset:
            suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now "
                               f"target is {list(target_dict.keys())}).")
        if _miss_out_dataset:
            _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
                    f"target is {list(target_dict.keys())}) or output it "
                    f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
            if _unused_field:
                _tmp += "You can use DataSet.rename_field() to rename the field in `unused field:`. "
            suggestions.append(_tmp)

    if check_res.duplicated:
        errs.append(f"\tduplicated param: {check_res.duplicated}.")
        suggestions.append(f"Delete {check_res.duplicated} in the output of "
                           f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")

    if check_level == STRICT_CHECK_LEVEL:
        errs.extend(unuseds)

    if len(errs)>0:
        errs.insert(0, f'The following problems occurred when calling {func_signature}')
        err_str = '\n' + '\n'.join(errs)
        # A strict-mode failure caused only by unused entries has no suggestion at all,
        # so the suggestion section is added only when there is something to say.
        if len(suggestions)>1:
            sugg_str = ''.join(f'({idx}). {sugg}' for idx, sugg in enumerate(suggestions, start=1))
            err_str += '\n\tSuggestion: ' + sugg_str
        elif suggestions:
            err_str += '\n\tSuggestion: ' + suggestions[0]
        raise NameError(err_str)

    if check_res.unused:
        if check_level == WARNING_CHECK_LEVEL:
            _unused_warn = f'{check_res.unused} is not used by {func_signature}.'
            warnings.warn(message=_unused_warn)
@ -260,21 +303,45 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
func_signature = get_func_signature(forward_func)
errs = []
suggestions = []
_unused = []
if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
errs.append(f"\tvarargs: {check_res.varargs}")
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}. "
f"Please set {check_res.missing} as input.")
errs.append(f"\tmissing param: {check_res.missing}")
_miss_in_dataset = []
_miss_out_dataset = []
for _miss in check_res.missing:
if _miss in dataset:
_miss_in_dataset.append(_miss)
else:
_miss_out_dataset.append(_miss)
if _miss_in_dataset:
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
if _miss_out_dataset:
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
if check_res.unused:
_tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \
f"rename the field in `unused field:`."
suggestions.append(_tmp)
if check_res.unused:
_unused = [f"\tunused param: {check_res.unused}"]
_unused = [f"\tunused field: {check_res.unused}"]
if check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused)
if len(errs)>0:
errs.insert(0, f'The following problems occurred when calling {func_signature}')
raise NameError('\n'.join(errs))
sugg_str = ""
if len(suggestions)>1:
for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx+1}). {sugg}'
else:
sugg_str += suggestions[0]
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
raise NameError(err_str)
if _unused:
if check_level == WARNING_CHECK_LEVEL:
_unused_warn = _unused[0] + f' in {func_signature}.'