mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-02 12:17:35 +08:00)

commit 1158556236 (parent f7c29b85d7)

1. Improve the error messages raised during the trainer's checkcode step.
[fastNLP/core/dataset.py]
@@ -69,7 +69,7 @@ class DataSet(object):
             self.idx = idx

         def __getitem__(self, item):
-            assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
+            assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[self.idx])
             assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
             return self.dataset.field_arrays[item][self.idx]

[fastNLP/core/fieldarray.py]
@@ -83,7 +83,8 @@ class FieldArray(object):
         elif isinstance(content, list):
             # content is a 1-D list
             if len(content) == 0:
-                raise RuntimeError("Cannot create FieldArray with an empty list.")
+                # the old error is not informative enough.
+                raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.")
             type_set = set([type(item) for item in content])

             if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES:
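Reviewer note: this is the error the tutorial notebook below used to hit when DataSet.apply produced an empty word list from a blank sentence (see the cleared RuntimeError traceback further down). A minimal sketch of the check in isolation, using a hypothetical stand-alone helper rather than the real FieldArray constructor:

    # Hypothetical stand-alone version of the improved empty-list check.
    def check_field_content(content):
        if isinstance(content, list) and len(content) == 0:
            raise RuntimeError("Cannot create FieldArray with an empty list. "
                               "Or one element in the list is empty.")
        return content

    check_field_content(['a', 'sentence'])  # passes
    try:
        check_field_content([])             # what an empty split() result triggers
    except RuntimeError as e:
        print(e)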
@@ -164,11 +165,13 @@ class FieldArray(object):
         # TODO when this fieldArray holds single-value content such as seq_length, padding is unnecessary; needs further discussion
         if not is_iterable(self.content[0]):
             array = np.array([self.content[i] for i in indices], dtype=self.dtype)
-        else:
+        elif self.dtype in (np.int64, np.float64):
             max_len = max([len(self.content[i]) for i in indices])
             array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
             for i, idx in enumerate(indices):
                 array[i][:len(self.content[idx])] = self.content[idx]
+        else:  # should only be str
+            array = np.array([self.content[i] for i in indices])
         return array

     def __len__(self):
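Reviewer note: after this change, only int64/float64 fields are length-padded; string fields come back as an unpadded object array. A minimal sketch of the numeric branch, with stand-alone names (content, padding_val) assumed for illustration:

    import numpy as np

    content = [[1, 2, 3], [4, 5]]   # a variable-length int field for two instances
    indices = [0, 1]
    padding_val = 0

    # Pad every selected row out to the batch's max length with padding_val.
    max_len = max(len(content[i]) for i in indices)
    array = np.full((len(indices), max_len), padding_val, dtype=np.int64)
    for i, idx in enumerate(indices):
        array[i][:len(content[idx])] = content[idx]
    print(array)  # [[1 2 3] [4 5 0]]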
[fastNLP/core/losses.py]
@@ -80,7 +80,7 @@ class LossBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param

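Reviewer note: the one-word fix above matters; previously fast_param['target'] was read from pred_dict, so the fast path computed the loss of the prediction against itself. A minimal sketch of the corrected pairing, assuming single-entry dicts as in the guard condition:

    # Corrected fast-path pairing: pred comes from pred_dict, target from target_dict.
    pred_dict = {'preds': [0.9, 0.1]}
    target_dict = {'label': [1.0, 0.0]}

    fast_param = {}
    if len(pred_dict) == 1 and len(target_dict) == 1:
        fast_param['pred'] = list(pred_dict.values())[0]
        fast_param['target'] = list(target_dict.values())[0]  # was pred_dict before the fix
    print(fast_param)  # {'pred': [0.9, 0.1], 'target': [1.0, 0.0]}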
@@ -134,10 +134,11 @@ class LossBase(object):
             # missing
             if not self._checked:
                 check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
-                # only check missing.
+                # replace missing.
                 missing = check_res.missing
                 replaced_missing = list(missing)
                 for idx, func_arg in enumerate(missing):
+                    # Don't delete `` in this information, nor add ``
                     replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                             f"in `{self.__class__.__name__}`)"

@@ -188,7 +189,7 @@ class CrossEntropyLoss(LossBase):
 class L1Loss(LossBase):
     def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.l1_loss(input=pred, target=target)
@@ -197,7 +198,7 @@ class L1Loss(LossBase):
 class BCELoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.binary_cross_entropy(input=pred, target=target)
@@ -205,7 +206,7 @@ class BCELoss(LossBase):
 class NLLLoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.nll_loss(input=pred, target=target)
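Reviewer note: in the three losses above, the keyword passed to _init_param_map must name a parameter of get_loss (pred), not torch's input; with the old input=pred mapping the checker would report pred as missing. A minimal sketch of the mapping idea, with a hypothetical stand-in for the real machinery:

    # param_map maps get_loss argument names -> keys in the model/dataset dicts.
    def get_loss(pred, target):
        return sum((p - t) ** 2 for p, t in zip(pred, target))

    param_map = {'pred': 'preds', 'target': 'label'}  # must use 'pred', not 'input'
    pred_dict = {'preds': [0.9, 0.1]}
    target_dict = {'label': [1.0, 0.0]}

    merged = {**pred_dict, **target_dict}
    kwargs = {arg: merged[key] for arg, key in param_map.items()}
    print(get_loss(**kwargs))  # ~0.02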
[fastNLP/core/metrics.py]
@@ -151,9 +151,11 @@ class MetricBase(object):
         if not self._checked:
             check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
             # only check missing.
+            # replace missing.
             missing = check_res.missing
             replaced_missing = list(missing)
             for idx, func_arg in enumerate(missing):
+                # Don't delete `` in this information, nor add ``
                 replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                         f"in `{self.__class__.__name__}`)"

[fastNLP/core/trainer.py]
@@ -2,7 +2,7 @@ import os
 import time
 from datetime import datetime
 from datetime import timedelta
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm

 import torch
 from tensorboardX import SummaryWriter
@@ -23,7 +23,6 @@ from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.utils import _relocate_pbar

 class Trainer(object):
     """Main Training Loop
@@ -45,7 +44,7 @@ class Trainer(object):
     :param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
     :param DataSet dev_data: the validation data
     :param use_cuda:
-    :param str save_path: file path to save models
+    :param save_path: file path to save models
     :param Optimizer optimizer: an optimizer object
     :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.
         `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means
@@ -149,7 +148,7 @@ class Trainer(object):
         self._mode(self.model, is_test=False)

         self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
-        print("training epochs started " + self.start_time)
+        print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
                 def __getattr__(self, item):
@@ -172,12 +171,12 @@ class Trainer(object):
         del self._summary_writer

     def _tqdm_train(self):
+        self.step = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                               as_numpy=False)
         total_steps = data_iterator.num_batches*self.n_epochs
         epoch = 1
-        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', desc="Epoch {}/{}"
-                .format(epoch, self.n_epochs), leave=False, dynamic_ncols=True) as pbar:
+        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             ava_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -195,28 +194,26 @@ class Trainer(object):
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                     if (self.step+1) % self.print_every == 0:
-                        pbar.update(self.print_every)
-                        pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss/self.print_every))
+                        pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every))
                         ava_loss = 0
+                    pbar.update(1)
                     self.step += 1
                     if self.validate_every > 0 and self.step % self.validate_every == 0 \
                             and self.dev_data is not None:
                         eval_res = self._do_validation()
                         eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                    self.tester._format_eval_results(eval_res)
-                        pbar = _relocate_pbar(pbar, print_str=eval_str)
+                        pbar.write(eval_str)
                 if self.validate_every < 0 and self.dev_data:
                     eval_res = self._do_validation()
                     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                self.tester._format_eval_results(eval_res)
-                    pbar = _relocate_pbar(pbar, print_str=eval_str)
+                    pbar.write(eval_str)
                 if epoch!=self.n_epochs:
                     data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                           as_numpy=False)
             pbar.close()

     def _print_train(self):
         """

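Reviewer note: replacing the home-grown _relocate_pbar (deleted further down) with tqdm's own pbar.write() is the idiomatic way to print above a live progress bar, and tqdm.autonotebook picks a console or notebook bar automatically; both assume a recent tqdm, hence the requirements bump below. A minimal sketch:

    import time
    from tqdm.autonotebook import tqdm

    total_steps = 10
    with tqdm(total=total_steps, leave=False, dynamic_ncols=True) as pbar:
        for step in range(total_steps):
            time.sleep(0.05)          # stand-in for one training step
            pbar.update(1)
            if (step + 1) % 5 == 0:
                # write() prints above the bar without corrupting its rendering
                pbar.write(f"step {step + 1}: validation would run here")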
@@ -264,9 +261,6 @@ class Trainer(object):
             self._do_validation()
         epoch += 1

-
-
-
     def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
[fastNLP/core/utils.py]
@@ -258,28 +258,47 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     if _unused_param:
         unuseds.append(f"\tunused param: {_unused_param}")  # output from predict or forward

+    module_name = ''
     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")
-        _miss_in_dataset = []
-        _miss_out_dataset = []
+        import re
+        mapped_missing = []
+        unmapped_missing = []
+        input_func_map = {}
         for _miss in check_res.missing:
+            fun_arg, module_name = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
             if '(' in _miss:
                 # if they are like 'SomeParam(assign to xxx)'
                 _miss = _miss.split('(')[0]
-            if _miss in dataset:
-                _miss_in_dataset.append(_miss)
+            input_func_map[_miss] = fun_arg
+            if fun_arg == _miss:
+                unmapped_missing.append(_miss)
             else:
-                _miss_out_dataset.append(_miss)
+                mapped_missing.append(_miss)

-        if _miss_in_dataset:
-            suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now "
-                               f"target is {list(target_dict.keys())}).")
-        if _miss_out_dataset:
-            _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                    f"target has {list(target_dict.keys())}) or output it "
-                    f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
-            # if _unused_field:
-            #     _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
+        for _miss in mapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Check key assignment for `{input_func_map[_miss]}` when initialize {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
+                suggestions.append(_tmp)
+        for _miss in unmapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Specify your assignment for `{input_func_map[_miss]}` when initialize {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
                 suggestions.append(_tmp)

     if check_res.duplicated:
@@ -297,17 +316,23 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
-                sugg_str += f'({idx+1}). {sugg}'
+                if idx>0:
+                    sugg_str += '\t\t\t'
+                sugg_str += f'({idx+1}). {sugg}\n'
+            sugg_str = sugg_str[:-1]
         else:
             sugg_str += suggestions[0]
+        errs.append(f'\ttarget field: {list(target_dict.keys())}')
+        errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}')
         err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
         raise NameError(err_str)
     if check_res.unused:
         if check_level == WARNING_CHECK_LEVEL:
-            _unused_warn = f'{check_res.unused} is not used by {func_signature}.'
+            if not module_name:
+                module_name = func_signature.split('.')[0]
+            _unused_warn = f'{check_res.unused} is not used by {module_name}.'
             warnings.warn(message=_unused_warn)


 def _check_forward_error(forward_func, batch_x, dataset, check_level):
     check_res = _check_arg_dict_list(forward_func, batch_x)
     func_signature = get_func_signature(forward_func)
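Reviewer note: the new suggestion logic above depends on the exact backtick format produced by the losses and metrics earlier in this commit ("mapped_key(assign to `arg` in `ClassName`)"); that is why those hunks say not to add or delete backticks. A sketch of what the regex extracts, under that formatting assumption:

    import re

    # Missing-parameter strings look like:
    #   "<mapped_key>(assign to `<func_arg>` in `<ClassName>`)"
    # The lookaround pattern pulls out exactly the two backtick-quoted names.
    _miss = "output(assign to `pred` in `CrossEntropyLoss`)"
    fun_arg, module_name = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
    print(fun_arg, module_name)  # pred CrossEntropyLoss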
@@ -402,40 +427,3 @@ def seq_mask(seq_len, max_len):
     seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
     seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
     return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
-
-
-def _relocate_pbar(pbar:tqdm, print_str:str):
-    """
-
-    When using tqdm, you cannot print. If you print, the tqdm will duplicate. By using this function, print_str will
-    show above tqdm.
-    :param pbar: tqdm
-    :param print_str:
-    :return:
-    """
-
-    params = ['desc', 'total', 'leave', 'file', 'ncols', 'mininterval', 'maxinterval', 'miniters', 'ascii', 'disable',
-              'unit', 'unit_scale', 'dynamic_ncols', 'smoothing', 'bar_format', 'initial', 'position', 'postfix', 'unit_divisor',
-              'gui']
-
-    attr_map = {'file': 'fp', 'initial':'n', 'position':'pos'}
-
-    param_dict = {}
-    for param in params:
-        attr_name = param
-        if param in attr_map:
-            attr_name = attr_map[param]
-        value = getattr(pbar, attr_name)
-        if attr_name == 'pos':
-            value = abs(value)
-        param_dict[param] = value
-
-    pbar.close()
-    avg_time = pbar.avg_time
-    start_t = pbar.start_t
-    print(print_str)
-    pbar = tqdm(**param_dict)
-    pbar.start_t = start_t
-    pbar.avg_time = avg_time
-    pbar.sp(pbar.__repr__())
-    return pbar
[requirements.txt]
@@ -1,4 +1,4 @@
 numpy>=1.14.2
 torch>=0.4.0
 tensorboardX
-tqdm
+tqdm>=4.28.1
[test: TestDataSet]
@@ -142,9 +142,16 @@ class TestDataSet(unittest.TestCase):
         def split_sent(ins):
             return ins['raw_sentence'].split()
         dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
-        dataset.apply(split_sent, new_field_name='words')
+        dataset.drop(lambda x:len(x['raw_sentence'].split())==0)
+        dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)

+    def test_add_field(self):
+        ds = DataSet({"x": [3, 4]})
+        ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
+        # ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y')
+        print(ds)
+
     def test_save_load(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         ds.save("./my_ds.pkl")
[test: TestTester]
@@ -4,6 +4,64 @@ data_name = "pku_training.utf8"
 pickle_path = "data_for_tests"


+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+import time
+from fastNLP.core.utils import CheckError
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.losses import BCELoss
+from fastNLP.core.losses import CrossEntropyLoss
+from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.core.optimizer import SGD
+from fastNLP.core.tester import Tester
+from fastNLP.models.base_model import NaiveClassifier
+
+
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+
+
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=100, dtype=np.int64)
+    data = {'y': ys}
+    for arg in args:
+        data[arg] = np.random.randn(size, 5)
+    return DataSet(data=data)
+
+
 class TestTester(unittest.TestCase):
     def test_case_1(self):
-        pass
+        # check whether the error messages correctly alert the user
+        # pass in redundant parameters here so that they show up as duplicates
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                time.sleep(0.1)
+                # loss = F.cross_entropy(x, y)
+                return {'preds': x}
+
+        model = Model()
+        tester = Tester(
+            data=dataset,
+            model=model,
+            metrics=AccuracyMetric())
+        tester.test()
[test: TrainerTestGround]
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch.nn.functional as F
 from torch import nn
+import time
 from fastNLP.core.utils import CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -212,8 +212,8 @@ class TrainerTestGround(unittest.TestCase):
         # pass in redundant parameters here so that they show up as duplicates
         dataset = prepare_fake_dataset2('x1', 'x_unused')
         dataset.rename_field('x_unused', 'x2')
-        dataset.set_input('x1', 'x2', 'y')
-        dataset.set_target('x1', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
         class Model(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -222,8 +222,9 @@ class TrainerTestGround(unittest.TestCase):
                 x1 = self.fc(x1)
                 x2 = self.fc(x2)
                 x = x1 + x2
+                time.sleep(0.1)
                 # loss = F.cross_entropy(x, y)
-                return {'pred': x}
+                return {'preds': x}

         model = Model()
         trainer = Trainer(
[tutorial notebook (.ipynb)]
@@ -12,7 +12,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -34,17 +34,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "8529\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "from fastNLP import DataSet\n",
      "from fastNLP import Instance\n",
@@ -56,20 +48,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1}\n",
-       "{'raw_sentence': -LRB- Tries -RRB- to parody a genre that 's already a joke in the United States .,\n",
-       "'label': 1}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "# use an integer index [k] to fetch the k-th sample\n",
      "print(dataset[0])\n",
@@ -90,21 +71,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "{'raw_sentence': fake data,\n",
-        "'label': 0}"
-       ]
-      },
-      "execution_count": 6,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
      "# DataSet.append(Instance) adds a new sample\n",
      "dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
@@ -121,18 +90,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 7,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "# convert all text to lowercase\n",
      "dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
@@ -141,18 +101,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 8,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "# convert label to int\n",
      "dataset.apply(lambda x: int(x['label']), new_field_name='label')\n",
@@ -161,28 +112,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 9,
+    "execution_count": null,
    "metadata": {},
-    "outputs": [
-     {
-      "ename": "RuntimeError",
-      "evalue": "Cannot create FieldArray with an empty list.",
-      "output_type": "error",
-      "traceback": [
-       [ten ANSI-escaped traceback frames through DataSet.apply, DataSet.add_field and FieldArray._type_detection, ending with:]
-       "\u001b[0;31mRuntimeError\u001b[0m: Cannot create FieldArray with an empty list."
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "# split the sentence on whitespace\n",
      "def split_sent(ins):\n",
@@ -193,20 +125,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 17,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1,\n",
-       "'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n",
-       "'seq_len': 37}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "# add length information\n",
      "dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n",
@@ -223,17 +144,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 19,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "38\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "dataset.drop(lambda x: x['seq_len'] <= 3)\n",
      "print(len(dataset))"
@@ -250,7 +163,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 20,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -264,18 +177,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 21,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "27\n",
-       "11"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "# split off a test set and a training set\n",
      "\n",
@@ -296,20 +200,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 22,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n",
-       "'label': 2,\n",
-       "'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n",
-       "'seq_len': 25}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "from fastNLP import Vocabulary\n",
      "\n",
@@ -336,36 +229,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 23,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "CNNText(\n",
-        "  (embed): Embedding(\n",
-        "    (embed): Embedding(32, 50, padding_idx=0)\n",
-        "    (dropout): Dropout(p=0.0)\n",
-        "  )\n",
-        "  (conv_pool): ConvMaxpool(\n",
-        "    (convs): ModuleList(\n",
-        "      (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n",
-        "      (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n",
-        "      (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n",
-        "    )\n",
-        "  )\n",
-        "  (dropout): Dropout(p=0.1)\n",
-        "  (fc): Linear(\n",
-        "    (linear): Linear(in_features=12, out_features=5, bias=True)\n",
-        "  )\n",
-        ")"
-       ]
-      },
-      "execution_count": 23,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
      "from fastNLP.models import CNNText\n",
      "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 25,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -469,7 +335,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 26,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -492,7 +358,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 27,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@@ -501,94 +367,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 30,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     [twelve interleaved stdout/stderr stream outputs cleared; the stdout text was:]
-     "training epochs started 2018-12-04 22:51:24",
-     "Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296",
-     "Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.407407",
-     "Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.518519",
-     "Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.481481",
-     "Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.592593"
-    ],
+    "outputs": [],
     "source": [
      "# instantiate a Trainer, pass in the model and the data, then train\n",
      "# fit on test_data first\n",
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 31,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"training epochs started 2018-12-04 22:52:01\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.222222\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.259259\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.296296\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.259259\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" \r"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Train finished!\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# 用train_data训练,在test_data验证\n",
|
"# 用train_data训练,在test_data验证\n",
|
||||||
"trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n",
|
"trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n",
|
||||||
@ -713,19 +402,9 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 33,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[tester] \n",
|
|
||||||
"AccuracyMetric: acc=0.259259\n",
|
|
||||||
"{'AccuracyMetric': {'acc': 0.259259}}\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# 调用Tester在test_data上评价效果\n",
|
"# 调用Tester在test_data上评价效果\n",
|
||||||
"from fastNLP import Tester\n",
|
"from fastNLP import Tester\n",
|
||||||
|