Improve the error messages raised by the Trainer's check_code step

yh 2018-12-05 20:15:59 +08:00
parent f7c29b85d7
commit 1158556236
11 changed files with 172 additions and 439 deletions

View File

@@ -69,7 +69,7 @@ class DataSet(object):
             self.idx = idx

         def __getitem__(self, item):
-            assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
+            assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[self.idx])
             assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
             return self.dataset.field_arrays[item][self.idx]

View File

@@ -83,7 +83,8 @@ class FieldArray(object):
         elif isinstance(content, list):
             # content is a 1-D list
             if len(content) == 0:
-                raise RuntimeError("Cannot create FieldArray with an empty list.")
+                # the old error is not informative enough.
+                raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.")
             type_set = set([type(item) for item in content])
             if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES:
@@ -164,11 +165,13 @@ class FieldArray(object):
         # TODO when this FieldArray holds scalar content such as seq_length, padding is unnecessary; needs further discussion
         if not is_iterable(self.content[0]):
             array = np.array([self.content[i] for i in indices], dtype=self.dtype)
-        else:
+        elif self.dtype in (np.int64, np.float64):
             max_len = max([len(self.content[i]) for i in indices])
             array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
             for i, idx in enumerate(indices):
                 array[i][:len(self.content[idx])] = self.content[idx]
+        else:  # should only be str
+            array = np.array([self.content[i] for i in indices])
         return array

     def __len__(self):
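Taken together, the new branches split batch construction by content type: scalar fields pass through, numeric sequences get padded to the longest sequence in the batch, and str content is returned unpadded. A minimal standalone sketch of that logic (the function name and defaults are illustrative, not the FieldArray API):

```python
import numpy as np

def pad_batch(content, indices, padding_val=0, dtype=np.int64):
    # scalar fields (e.g. seq_len) need no padding
    if not hasattr(content[0], '__len__'):
        return np.array([content[i] for i in indices], dtype=dtype)
    # numeric sequences: pad to the longest sequence in the batch
    elif dtype in (np.int64, np.float64):
        max_len = max(len(content[i]) for i in indices)
        array = np.full((len(indices), max_len), padding_val, dtype=dtype)
        for i, idx in enumerate(indices):
            array[i][:len(content[idx])] = content[idx]
        return array
    # anything else (should only be str): plain object array, no padding
    else:
        return np.array([content[i] for i in indices])

print(pad_batch([[1, 2], [3, 4, 5]], [0, 1]))
# [[1 2 0]
#  [3 4 5]]
```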

View File

@@ -80,7 +80,7 @@ class LossBase(object):
         fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
             return fast_param
         return fast_param
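The fix above is a one-word change that is easy to miss: `target` was previously read from `pred_dict`. The corrected fast path in isolation (dict keys below are made up for illustration):

```python
def fast_param_map(param_map, pred_dict, target_dict):
    # sketch of the corrected fast path: with exactly one candidate on each
    # side, `target` is now read from target_dict instead of pred_dict
    fast_param = {}
    if len(param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
        fast_param['pred'] = list(pred_dict.values())[0]
        fast_param['target'] = list(target_dict.values())[0]
    return fast_param

print(fast_param_map({'pred': 'preds', 'target': 'label'},
                     pred_dict={'preds': [0.2, 0.8]},
                     target_dict={'label': [1]}))
# {'pred': [0.2, 0.8], 'target': [1]}
```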
@@ -134,10 +134,11 @@ class LossBase(object):
             # missing
             if not self._checked:
                 check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
-                # only check missing.
+                # replace missing.
                 missing = check_res.missing
                 replaced_missing = list(missing)
                 for idx, func_arg in enumerate(missing):
+                    # Don't delete `` in this information, nor add ``
                     replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                             f"in `{self.__class__.__name__}`)"
@@ -188,7 +189,7 @@ class CrossEntropyLoss(LossBase):
 class L1Loss(LossBase):
     def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.l1_loss(input=pred, target=target)
@@ -197,7 +198,7 @@ class L1Loss(LossBase):
 class BCELoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.binary_cross_entropy(input=pred, target=target)
@@ -205,7 +206,7 @@ class BCELoss(LossBase):
 class NLLLoss(LossBase):
     def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
-        self._init_param_map(input=pred, target=target)
+        self._init_param_map(pred=pred, target=target)

     def get_loss(self, pred, target):
         return F.nll_loss(input=pred, target=target)
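The three `_init_param_map` fixes all correct the same mistake: the keyword must name the `get_loss` parameter (`pred`), not the `input` parameter of the underlying `torch.nn.functional` call. A short usage sketch (the field names 'preds' and 'label' are hypothetical):

```python
from fastNLP.core.losses import L1Loss

# map the model's output key 'preds' to get_loss's `pred` argument and the
# dataset field 'label' to `target`; with the old `input=pred` mapping, the
# check machinery would report a bogus missing `pred` parameter
loss = L1Loss(pred='preds', target='label')
```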

View File

@@ -151,9 +151,11 @@ class MetricBase(object):
             if not self._checked:
                 check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
                 # only check missing.
+                # replace missing.
                 missing = check_res.missing
                 replaced_missing = list(missing)
                 for idx, func_arg in enumerate(missing):
+                    # Don't delete `` in this information, nor add ``
                     replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                             f"in `{self.__class__.__name__}`)"

View File

@@ -2,7 +2,7 @@ import os
 import time
 from datetime import datetime
 from datetime import timedelta
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm
 import torch
 from tensorboardX import SummaryWriter
@@ -23,7 +23,6 @@ from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.utils import _relocate_pbar

 class Trainer(object):
     """Main Training Loop
@@ -45,7 +44,7 @@ class Trainer(object):
     :param int validate_every: step interval to do next validation. Default: -1 (validate every epoch).
     :param DataSet dev_data: the validation data
     :param use_cuda:
-    :param str save_path: file path to save models
+    :param save_path: file path to save models
     :param Optimizer optimizer: an optimizer object
     :param int check_code_level: level of FastNLP code checker. -1: don't check; 0: ignore; 1: warning; 2: strict.
         `ignore` does not check unused fields; `warning` warns if some fields are not used; `strict` means
@@ -149,7 +148,7 @@ class Trainer(object):
         self._mode(self.model, is_test=False)

         self.start_time = str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
-        print("training epochs started " + self.start_time)
+        print("training epochs started " + self.start_time, flush=True)
         if self.save_path is None:
             class psudoSW:
                 def __getattr__(self, item):
@@ -172,12 +171,12 @@ class Trainer(object):
         del self._summary_writer

     def _tqdm_train(self):
+        self.step = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                               as_numpy=False)
         total_steps = data_iterator.num_batches*self.n_epochs
         epoch = 1
-        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', desc="Epoch {}/{}"
-                  .format(epoch, self.n_epochs), leave=False, dynamic_ncols=True) as pbar:
+        with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             ava_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -195,28 +194,26 @@ class Trainer(object):
                 # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                 # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                 if (self.step+1) % self.print_every == 0:
-                    pbar.update(self.print_every)
-                    pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss/self.print_every))
+                    pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every))
                     ava_loss = 0
+                pbar.update(1)
                 self.step += 1
                 if self.validate_every > 0 and self.step % self.validate_every == 0 \
                         and self.dev_data is not None:
                     eval_res = self._do_validation()
                     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                self.tester._format_eval_results(eval_res)
-                    pbar = _relocate_pbar(pbar, print_str=eval_str)
+                    pbar.write(eval_str)
             if self.validate_every < 0 and self.dev_data:
                 eval_res = self._do_validation()
                 eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                            self.tester._format_eval_results(eval_res)
-                pbar = _relocate_pbar(pbar, print_str=eval_str)
+                pbar.write(eval_str)
             if epoch != self.n_epochs:
                 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                       as_numpy=False)
         pbar.close()

     def _print_train(self):
         """
@@ -264,9 +261,6 @@ class Trainer(object):
                 self._do_validation()
             epoch += 1

     def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
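The switch from `_relocate_pbar` to `pbar.write` relies on tqdm's built-in ability to print a line above a live bar without corrupting it (hence the new `tqdm>=4.28.1` pin further down). A minimal sketch of the pattern:

```python
import time
from tqdm.autonotebook import tqdm

with tqdm(total=10, leave=False, dynamic_ncols=True) as pbar:
    for step in range(10):
        time.sleep(0.05)  # stand-in for a training step
        pbar.update(1)
        if (step + 1) % 5 == 0:
            # printed above the bar instead of breaking it
            pbar.write("Step:{}/10. eval results would go here".format(step + 1))
```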

View File

@@ -258,29 +258,48 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     if _unused_param:
         unuseds.append(f"\tunused param: {_unused_param}")  # output from predict or forward

+    module_name = ''
     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")
-        _miss_in_dataset = []
-        _miss_out_dataset = []
+        import re
+        mapped_missing = []
+        unmapped_missing = []
+        input_func_map = {}
         for _miss in check_res.missing:
+            fun_arg, module_name = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
             if '(' in _miss:
                 # if they are like 'SomeParam(assign to xxx)'
                 _miss = _miss.split('(')[0]
-            if _miss in dataset:
-                _miss_in_dataset.append(_miss)
+            input_func_map[_miss] = fun_arg
+            if fun_arg == _miss:
+                unmapped_missing.append(_miss)
             else:
-                _miss_out_dataset.append(_miss)
-        if _miss_in_dataset:
-            suggestions.append(f"You might need to set {_miss_in_dataset} as target(Right now "
-                               f"target is {list(target_dict.keys())}).")
-        if _miss_out_dataset:
-            _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                    f"target has {list(target_dict.keys())}) or output it "
-                    f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
-            # if _unused_field:
-            #     _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
-            suggestions.append(_tmp)
+                mapped_missing.append(_miss)
+
+        for _miss in mapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Check key assignment for `{input_func_map[_miss]}` when initializing {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
+                suggestions.append(_tmp)
+        for _miss in unmapped_missing:
+            if _miss in dataset:
+                suggestions.append(f"Set {_miss} as target.")
+            else:
+                _tmp = ''
+                if check_res.unused:
+                    _tmp = f"Specify your assignment for `{input_func_map[_miss]}` when initializing {module_name}."
+                if _tmp:
+                    _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
+                else:
+                    _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
+                suggestions.append(_tmp)

     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
@@ -297,17 +316,23 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
-                sugg_str += f'({idx+1}). {sugg}'
+                if idx > 0:
+                    sugg_str += '\t\t\t'
+                sugg_str += f'({idx+1}). {sugg}\n'
+            sugg_str = sugg_str[:-1]
         else:
             sugg_str += suggestions[0]
+        errs.append(f'\ttarget field: {list(target_dict.keys())}')
+        errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}')
         err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
         raise NameError(err_str)

     if check_res.unused:
         if check_level == WARNING_CHECK_LEVEL:
-            _unused_warn = f'{check_res.unused} is not used by {func_signature}.'
+            if not module_name:
+                module_name = func_signature.split('.')[0]
+            _unused_warn = f'{check_res.unused} is not used by {module_name}.'
             warnings.warn(message=_unused_warn)

 def _check_forward_error(forward_func, batch_x, dataset, check_level):
     check_res = _check_arg_dict_list(forward_func, batch_x)
     func_signature = get_func_signature(forward_func)
@@ -402,40 +427,3 @@ def seq_mask(seq_len, max_len):
     seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
     seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
     return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
-
-def _relocate_pbar(pbar: tqdm, print_str: str):
-    """
-    When using tqdm, you cannot print. If you print, the tqdm will duplicate. By using this function, print_str will
-    show above tqdm.
-    :param pbar: tqdm
-    :param print_str:
-    :return:
-    """
-    params = ['desc', 'total', 'leave', 'file', 'ncols', 'mininterval', 'maxinterval', 'miniters', 'ascii', 'disable',
-              'unit', 'unit_scale', 'dynamic_ncols', 'smoothing', 'bar_format', 'initial', 'position', 'postfix',
-              'unit_divisor', 'gui']
-    attr_map = {'file': 'fp', 'initial': 'n', 'position': 'pos'}
-    param_dict = {}
-    for param in params:
-        attr_name = param
-        if param in attr_map:
-            attr_name = attr_map[param]
-        value = getattr(pbar, attr_name)
-        if attr_name == 'pos':
-            value = abs(value)
-        param_dict[param] = value
-    pbar.close()
-    avg_time = pbar.avg_time
-    start_t = pbar.start_t
-    print(print_str)
-    pbar = tqdm(**param_dict)
-    pbar.start_t = start_t
-    pbar.avg_time = avg_time
-    pbar.sp(pbar.__repr__())
-    return pbar
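For reference, the surviving context lines fully define `seq_mask`; a standalone check of the mask it computes:

```python
import torch

seq_len = torch.tensor([2, 4]).view(-1, 1).long()             # [batch_size, 1]
seq_range = torch.arange(0, 4, dtype=torch.long).view(1, -1)  # [1, max_len]
print(torch.gt(seq_len, seq_range))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```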

View File

@@ -1,4 +1,4 @@
 numpy>=1.14.2
 torch>=0.4.0
 tensorboardX
-tqdm
+tqdm>=4.28.1

View File

@@ -142,9 +142,16 @@ class TestDataSet(unittest.TestCase):
         def split_sent(ins):
             return ins['raw_sentence'].split()

         dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
-        dataset.apply(split_sent, new_field_name='words')
+        dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
+        dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)

+    def test_add_field(self):
+        ds = DataSet({"x": [3, 4]})
+        ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
+        # ds.apply(lambda x: [x['x']] * 3, is_input=True, is_target=True, new_field_name='y')
+        print(ds)
+
     def test_save_load(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         ds.save("./my_ds.pkl")

View File

@@ -4,6 +4,64 @@ data_name = "pku_training.utf8"
 pickle_path = "data_for_tests"

+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+import time
+
+from fastNLP.core.utils import CheckError
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.losses import BCELoss
+from fastNLP.core.losses import CrossEntropyLoss
+from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.core.optimizer import SGD
+from fastNLP.core.tester import Tester
+from fastNLP.models.base_model import NaiveClassifier
+
+
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+
+
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=100, dtype=np.int64)
+    data = {'y': ys}
+    for arg in args:
+        data[arg] = np.random.randn(size, 5)
+    return DataSet(data=data)
+
+
 class TestTester(unittest.TestCase):
     def test_case_1(self):
-        pass
+        # check whether the error messages correctly alert the user
+        # extra parameters are passed in here to trigger `duplicated`
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                time.sleep(0.1)
+                # loss = F.cross_entropy(x, y)
+                return {'preds': x}
+
+        model = Model()
+        tester = Tester(
+            data=dataset,
+            model=model,
+            metrics=AccuracyMetric())
+        tester.test()

View File

@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch.nn.functional as F
 from torch import nn
+import time

 from fastNLP.core.utils import CheckError
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -212,8 +212,8 @@ class TrainerTestGround(unittest.TestCase):
         # extra parameters are passed in here to trigger `duplicated`
         dataset = prepare_fake_dataset2('x1', 'x_unused')
         dataset.rename_field('x_unused', 'x2')
-        dataset.set_input('x1', 'x2', 'y')
-        dataset.set_target('x1', 'x2')
+        dataset.set_input('x1', 'x2')
+        dataset.set_target('y', 'x1')

         class Model(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -222,8 +222,9 @@ class TrainerTestGround(unittest.TestCase):
                 x1 = self.fc(x1)
                 x2 = self.fc(x2)
                 x = x1 + x2
+                time.sleep(0.1)
                 # loss = F.cross_entropy(x, y)
-                return {'pred': x}
+                return {'preds': x}

         model = Model()
         trainer = Trainer(

View File

@@ -12,7 +12,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -34,17 +34,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "8529\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "from fastNLP import DataSet\n",
     "from fastNLP import Instance\n",
@@ -56,20 +48,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1}\n",
-       "{'raw_sentence': -LRB- Tries -RRB- to parody a genre that 's already a joke in the United States .,\n",
-       "'label': 1}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "# use the integer index [k] to fetch the k-th sample\n",
     "print(dataset[0])\n",
@@ -90,21 +71,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "{'raw_sentence': fake data,\n",
-        "'label': 0}"
-       ]
-      },
-      "execution_count": 6,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
     "# DataSet.append(Instance) adds a new sample\n",
     "dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
@@ -121,18 +90,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 7,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "# lowercase all text\n",
     "dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
@@ -141,18 +101,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 8,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "# convert label to int\n",
     "dataset.apply(lambda x: int(x['label']), new_field_name='label')\n",
@@ -161,28 +112,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 9,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "ename": "RuntimeError",
-      "evalue": "Cannot create FieldArray with an empty list.",
-      "output_type": "error",
-      "traceback": [
-       "RuntimeError                              Traceback (most recent call last)",
-       "<ipython-input-9-d70cf5545af4> in <module>(): dataset.apply(split_sent, new_field_name='words')",
-       "... (frames through dataset.py apply/add_field and fieldarray.py _type_detection omitted) ...",
-       "RuntimeError: Cannot create FieldArray with an empty list."
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "# split the sentence on whitespace\n",
     "def split_sent(ins):\n",
@@ -193,20 +125,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 17,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
-       "'label': 1,\n",
-       "'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n",
-       "'seq_len': 37}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "# add length information\n",
     "dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n",
@@ -223,17 +144,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 19,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "38\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "dataset.drop(lambda x: x['seq_len'] <= 3)\n",
     "print(len(dataset))"
@@ -250,7 +163,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 20,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -264,18 +177,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 21,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "27\n",
-       "11"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "# split into test and training sets\n",
     "\n",
@@ -296,20 +200,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 22,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n",
-       "'label': 2,\n",
-       "'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n",
-       "'seq_len': 25}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "from fastNLP import Vocabulary\n",
     "\n",
@@ -336,36 +229,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 23,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "data": {
-       "text/plain": [
-        "CNNText(\n",
-        "  (embed): Embedding(\n",
-        "    (embed): Embedding(32, 50, padding_idx=0)\n",
-        "    (dropout): Dropout(p=0.0)\n",
-        "  )\n",
-        "  (conv_pool): ConvMaxpool(\n",
-        "    (convs): ModuleList(\n",
-        "      (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n",
-        "      (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n",
-        "      (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n",
-        "    )\n",
-        "  )\n",
-        "  (dropout): Dropout(p=0.1)\n",
-        "  (fc): Linear(\n",
-        "    (linear): Linear(in_features=12, out_features=5, bias=True)\n",
-        "  )\n",
-        ")"
-       ]
-      },
-      "execution_count": 23,
-      "metadata": {},
-      "output_type": "execute_result"
-     }
-    ],
+    "outputs": [],
     "source": [
     "from fastNLP.models import CNNText\n",
     "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
@@ -432,7 +298,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 25,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -469,7 +335,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 26,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -492,7 +358,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 27,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -501,94 +367,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 30,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {"name": "stdout", "output_type": "stream",
-      "text": ["training epochs started 2018-12-04 22:51:24\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.407407\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.518519\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.481481\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.592593\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]}
-    ],
+    "outputs": [],
     "source": [
     "# instantiate a Trainer, pass in the model and data, and train\n",
     "# first fit on test_data\n",
@@ -604,101 +385,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 31,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {"name": "stdout", "output_type": "stream",
-      "text": ["training epochs started 2018-12-04 22:52:01\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.222222\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.259259\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.296296\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.259259\n"]},
-     {"name": "stderr", "output_type": "stream", "text": [" \r"]},
-     {"name": "stdout", "output_type": "stream",
-      "text": ["Train finished!\n"]}
-    ],
+    "outputs": [],
     "source": [
     "# train on train_data, validate on test_data\n",
     "trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n",
@@ -713,19 +402,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 33,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {"name": "stdout", "output_type": "stream",
-      "text": [
-       "[tester] \n",
-       "AccuracyMetric: acc=0.259259\n",
-       "{'AccuracyMetric': {'acc': 0.259259}}\n"
-      ]}
-    ],
+    "outputs": [],
     "source": [
     "# use Tester to evaluate on test_data\n",
     "from fastNLP import Tester\n",