Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-05 05:38:31 +08:00)

Commit ba7b17661c: Merge branch 'trainer' of https://github.com/FengZiYjun/fastNLP into check
fastNLP/core/optimizer.py
@@ -2,61 +2,28 @@ import torch


 class Optimizer(object):
-    """Wrapper of optimizer from framework
-
-    1. Adam: lr (float), weight_decay (float)
-    2. AdaGrad
-    3. RMSProp
-    4. SGD: lr (float), momentum (float)
-
-    """
-
-    def __init__(self, optimizer_name, **kwargs):
-        """
-        :param optimizer_name: str, the name of the optimizer
-        :param kwargs: the arguments
-
-        """
-        self.optim_name = optimizer_name
-        self.kwargs = kwargs
-
-    @property
-    def name(self):
-        """The name of the optimizer.
-
-        :return: str
-        """
-        return self.optim_name
-
-    @property
-    def params(self):
-        """The arguments used to create the optimizer.
-
-        :return: dict of (str, *)
-        """
-        return self.kwargs
-
-    def construct_from_pytorch(self, model_params):
-        """Construct an optimizer from framework over given model parameters."""
-        if self.optim_name in ["SGD", "sgd"]:
-            if "lr" in self.kwargs:
-                if "momentum" not in self.kwargs:
-                    self.kwargs["momentum"] = 0
-                optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
-            else:
-                raise ValueError("requires learning rate for SGD optimizer")
-
-        elif self.optim_name in ["adam", "Adam"]:
-            if "lr" in self.kwargs:
-                if "weight_decay" not in self.kwargs:
-                    self.kwargs["weight_decay"] = 0
-                optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
-                                             weight_decay=self.kwargs["weight_decay"])
-            else:
-                raise ValueError("requires learning rate for Adam optimizer")
-
-        else:
-            raise NotImplementedError
-
-        return optimizer
+    def __init__(self, model_params, **kwargs):
+        if model_params is not None and not isinstance(model_params, torch.Tensor):
+            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        self.model_params = model_params
+        self.settings = kwargs
+
+
+class SGD(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
+
+    def construct_from_pytorch(self, model_params):
+        """Construct an optimizer from framework over given model parameters."""
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.SGD(self.model_params, **self.settings)
+
+
+class Adam(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
+
+    def construct_from_pytorch(self, model_params):
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.Adam(self.model_params, **self.settings)
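For orientation, a minimal usage sketch of the refactored wrappers (illustrative only, not part of the commit; the nn.Linear model and both call sites are assumptions):

# Sketch: hyperparameters live in the wrapper; binding to model parameters
# is deferred until construct_from_pytorch materializes the torch optimizer.
import torch
from torch import nn

from fastNLP.core.optimizer import SGD, Adam

net = nn.Linear(4, 2)

sgd_wrapper = SGD(lr=0.01, momentum=0.9)
torch_sgd = sgd_wrapper.construct_from_pytorch(net.parameters())
assert isinstance(torch_sgd, torch.optim.SGD)

adam_wrapper = Adam(lr=0.001, weight_decay=0.8)
torch_adam = adam_wrapper.construct_from_pytorch(net.parameters())
assert isinstance(torch_adam, torch.optim.Adam)

This split is why the Trainer default below can become a bare Adam(lr=0.01, weight_decay=0) with no model parameters attached yet.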
fastNLP/core/trainer.py
@@ -1,39 +1,38 @@
 import itertools
 import os
 import time
 import warnings
 from collections import defaultdict
 from datetime import datetime
 from datetime import timedelta

 import torch
-from torch import nn
 from tensorboardX import SummaryWriter
+from torch import nn

 from fastNLP.core.batch import Batch
-from fastNLP.core.optimizer import Optimizer
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.dataset import DataSet
-
-from fastNLP.core.losses import LossBase
-from fastNLP.core.metrics import MetricBase
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
-from fastNLP.core.utils import CheckError


 class Trainer(object):
     """Main Training Loop

     """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
+                 validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
+                 optimizer=Adam(lr=0.01, weight_decay=0), need_check_code=True,
                  metric_key=None,
                  **kwargs):
         super(Trainer, self).__init__()
@@ -50,6 +49,13 @@ class Trainer(object):

         # prepare evaluate
         metrics = _prepare_metrics(metrics)

+        # parse metric_key
+        # increase_better is True. It means the exp result gets better if the indicator increases.
+        # It is true by default.
+        self.increase_better = False if metric_key[0] == "-" else True
+        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+
         # prepare loss
         losser = _prepare_losser(losser)
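A side note on the parsing just added: a leading "-" marks a metric where smaller is better, a leading "+" or no prefix means larger is better. A self-contained sketch of the same two lines follows (the standalone function name is hypothetical); note that metric_key[0], as written in the diff, would raise a TypeError if metric_key is left at its None default.

# Sketch of the metric_key sign-prefix convention introduced above.
def parse_metric_key(metric_key):
    increase_better = False if metric_key[0] == "-" else True
    key = metric_key[1:] if metric_key[0] in ("+", "-") else metric_key
    return increase_better, key

assert parse_metric_key("+acc") == (True, "acc")
assert parse_metric_key("-loss") == (False, "loss")
assert parse_metric_key("f1") == (True, "f1")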
@@ -67,12 +73,10 @@ class Trainer(object):
         self.save_path = save_path
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
-        self._best_accuracy = 0
+        self.best_metric_indicator = None

         self._model_device = model.parameters().__next__().device

-        # TODO self._best_accuracy cannot represent the case where there are multiple metrics
-
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
@@ -102,7 +106,7 @@ class Trainer(object):
         if torch.cuda.is_available() and self.use_cuda:
             self.model = self.model.cuda()

-        self.mode(self.model, is_test=False)
+        self._mode(self.model, is_test=False)

         start = time.time()
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
@@ -112,7 +116,9 @@ class Trainer(object):
             def __getattr__(self, item):
                 def pass_func(*args, **kwargs):
                     pass
+
                 return pass_func
+
             self._summary_writer = psudoSW()
         else:
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
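The psudoSW stub above silences logging when there is no save_path: every attribute lookup returns a do-nothing function. A self-contained sketch of the pattern (class and variable names here are illustrative, not the commit's):

# Sketch: a no-op stand-in for tensorboardX's SummaryWriter.
class PseudoSummaryWriter:
    def __getattr__(self, item):
        def pass_func(*args, **kwargs):
            pass  # swallow any method call

        return pass_func

writer = PseudoSummaryWriter()
writer.add_scalar("loss", 0.5, global_step=1)  # silently ignored
writer.close()  # also a no-op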
@@ -121,19 +127,20 @@
             epoch = 1
             while epoch <= self.n_epochs:

-                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
+                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
+                                      as_numpy=False)

-                self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
+                self._train_epoch(data_iterator, self.model, epoch, start)

                 # validate_every override validation at end of epochs
                 if self.dev_data and self.validate_every <= 0:
-                    self.do_validation()
+                    self._do_validation()
                 epoch += 1
         finally:
             self._summary_writer.close()
             del self._summary_writer

-    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, start):
         """Training process in one epoch.

         kwargs should contain:
@@ -144,10 +151,10 @@ class Trainer(object):
         for batch_x, batch_y in data_iterator:
             # TODO There may be a problem here: if the user changes the device of `prediction` inside the model, this will break.
             _move_dict_value_to_device(self._model_device, batch_x, batch_y)
-            prediction = self.data_forward(model, batch_x)
-            loss = self.get_loss(prediction, batch_y)
-            self.grad_backward(loss)
-            self.update()
+            prediction = self._data_forward(model, batch_x)
+            loss = self._compute_loss(prediction, batch_y)
+            self._grad_backward(loss)
+            self._update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
             for name, param in self.model.named_parameters():
                 if param.requires_grad:
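The renamed helpers keep the classic per-batch order: forward, loss, zero-grad plus backward, optimizer step, then scalar logging. A plain-PyTorch sketch of one such step (model, loss function, and data are illustrative stand-ins for fastNLP's helpers):

# Sketch of one training step in the order the loop above performs it.
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

batch_x = torch.randn(8, 4)
batch_y = torch.randint(0, 2, (8,))

prediction = model(batch_x)          # _data_forward
loss = loss_fn(prediction, batch_y)  # _compute_loss
model.zero_grad()                    # _grad_backward zeroes grads, then...
loss.backward()                      # ...back-propagates
optimizer.step()                     # _update
print(loss.item())                   # the scalar that add_scalar("loss", ...) logs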
@@ -162,18 +169,19 @@ class Trainer(object):
                 print(print_output)

             if self.validate_every > 0 and self.step % self.validate_every == 0:
-                self.do_validation()
+                self._do_validation()

             self.step += 1

-    def do_validation(self):
+    def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-        if self.save_path is not None and self.best_eval_result(res):
-            self.save_model(self.model, 'best_model_' + self.start_time)
+        if self.save_path is not None and self._better_eval_result(res):
+            self._save_model(self.model,
+                             "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))

-    def mode(self, model, is_test=False):
+    def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.

         :param model: a PyTorch model
@@ -185,20 +193,20 @@ class Trainer(object):
         else:
             model.train()

-    def update(self):
+    def _update(self):
         """Perform weight update on a model.

         """
         self.optimizer.step()

-    def data_forward(self, network, x):
+    def _data_forward(self, network, x):
         x = _build_args(network.forward, **x)
         y = network(**x)
         if not isinstance(y, dict):
             raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y

-    def grad_backward(self, loss):
+    def _grad_backward(self, loss):
         """Compute gradient with the chain rule.

         :param loss: a scalar where back-prop starts
@@ -208,7 +216,7 @@ class Trainer(object):
         self.model.zero_grad()
         loss.backward()

-    def get_loss(self, predict, truth):
+    def _compute_loss(self, predict, truth):
         """Compute loss given prediction and ground truth.

         :param predict: prediction dict, produced by model.forward
@@ -217,34 +225,59 @@ class Trainer(object):
         """
         return self.losser(predict, truth)

-    def save_model(self, model, model_name, only_param=False):
+    def _save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
         if only_param:
             torch.save(model.state_dict(), model_name)
         else:
             torch.save(model, model_name)

-    def best_eval_result(self, metrics):
+    def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.

-        :return: bool, True means current results on dev set is the best.
+        :return bool value: True means current results on dev set is the best.
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics

         if isinstance(metrics, dict):
             if len(metrics) == 1:
-                accuracy = list(metrics.values())[0]
+                # only single metric, just use it
+                metric_dict = list(metrics.values())[0]
+                metrics_name = list(metrics.keys())[0]
             else:
-                accuracy = metrics[self.eval_sort_key]
-        else:
-            accuracy = metrics
+                metrics_name = self.metrics[0].__class__.__name__
+                if metrics_name not in metrics:
+                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+                metric_dict = metrics[metrics_name]

-        if accuracy > self._best_accuracy:
-            self._best_accuracy = accuracy
-            return True
-        else:
-            return False
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and self.metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if self.metric_key not in metric_dict:
+                raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[self.metric_key]
+
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better


 DEFAULT_CHECK_BATCH_SIZE = 2
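Stripped of the dict unpacking, the new comparison reduces to one rule: keep a single best indicator and flip the comparison direction when increase_better is False. A minimal sketch under that reading (standalone function, not fastNLP API):

# Sketch of the best-indicator bookkeeping above. `best` plays the role of
# self.best_metric_indicator; returns (is_better, new_best).
def better_eval_result(indicator_val, best, increase_better=True):
    if best is None:  # first-time validation always counts as best
        return True, indicator_val
    improved = indicator_val > best if increase_better else indicator_val < best
    return improved, indicator_val if improved else best

assert better_eval_result(0.8, None) == (True, 0.8)
assert better_eval_result(0.7, 0.8) == (False, 0.8)
assert better_eval_result(0.3, 0.5, increase_better=False) == (True, 0.3)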
@@ -254,6 +287,7 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2

+
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=WARNING_CHECK_LEVEL):
@@ -264,7 +298,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     for batch_count, (batch_x, batch_y) in enumerate(batch):
         _move_dict_value_to_device(model_devcie, batch_x, batch_y)
         # forward check
-        if batch_count==0:
+        if batch_count == 0:
             _check_forward_error(model_func=model.forward, check_level=check_level,
                                  batch_x=batch_x)
@@ -285,17 +319,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         if batch_count == 0:
             if not isinstance(loss, torch.Tensor):
                 raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
-                f"but got `{type(loss)}`.")
-            if len(loss.size())!=0:
+                                f"but got `{type(loss)}`.")
+            if len(loss.size()) != 0:
                 raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
                                  f"should be torch.size([])")
         loss.backward()
         model.zero_grad()
-        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break

     if dev_data is not None:
-        tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
+        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)
         tester.test()
@@ -305,18 +339,18 @@ def _check_forward_error(model_func, check_level, batch_x):
     _missing = ''
     _unused = ''
     func_signature = get_func_signature(model_func)
-    if len(check_res['missing'])!=0:
+    if len(check_res['missing']) != 0:
         _missing = "Function {} misses {}, only provided with {}, " \
                    ".\n".format(func_signature, check_res.missing,
-                   list(batch_x.keys()))
-    if len(check_res['unused'])!=0:
+                                list(batch_x.keys()))
+    if len(check_res['unused']) != 0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
             _unused = "{} is not used ".format(check_res.unused)
         _unused += "in function {}.\n".format(func_signature)
     if _missing:
-        if len(_unused)>0 and STRICT_CHECK_LEVEL:
+        if len(_unused) > 0 and STRICT_CHECK_LEVEL:
             _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
         else:
             _error_str = _missing
@@ -329,38 +363,40 @@ def _check_forward_error(model_func, check_level, batch_x):
     elif check_level == WARNING_CHECK_LEVEL:
         warnings.warn(message=_unused)


-def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
+def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
+    check_res = _check_arg_dict_list(func, [output, batch_y])
     _missing = ''
     _unused = ''
     _duplicated = ''
     func_signature = get_func_signature(func)
     prev_func_signature = get_func_signature(prev_func)
-    if len(check_res.missing)>0:
+    if len(check_res.missing) > 0:
         _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
                    "{}(from target in Dataset)." \
-                   .format(func_signature, check_res.missing,
-                   list(output.keys()), prev_func_signature,
-                   list(batch_y.keys()))
-    if len(check_res.unused)>0:
+            .format(func_signature, check_res.missing,
+                    list(output.keys()), prev_func_signature,
+                    list(batch_y.keys()))
+    if len(check_res.unused) > 0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
             _unused = "{} is not used ".format(check_res.unused)
         _unused += "in function {}.\n".format(func_signature)
-    if len(check_res.duplicated)>0:
+    if len(check_res.duplicated) > 0:
         if len(check_res.duplicated) > 1:
             _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
                           "them in {} at the same time.".format(check_res.duplicated,
                                                                 func_signature,
                                                                 check_res.duplicated,
                                                                 prev_func_signature)
         else:
             _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
                           "it in {} at the same time.".format(check_res.duplicated,
                                                               func_signature,
                                                               check_res.duplicated,
                                                               prev_func_signature)
-    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
+    _number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0)
     if _number_errs > 0:
         _error_strs = []
         if _number_errs > 1:
test/core/test_optimizer.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import unittest
+
+import torch
+
+from fastNLP.core.optimizer import SGD
+
+
+class TestOptim(unittest.TestCase):
+    def test_case(self):
+        optim = SGD(torch.LongTensor(10))
+        print(optim.__dict__)
+
+        optim_2 = SGD(lr=0.001)
+        print(optim_2.__dict__)
+
+        optim_2 = SGD(lr=0.002, momentum=0.989)
+        print(optim_2.__dict__)
+
+    def test_case_2(self):
+        with self.assertRaises(RuntimeError):
+            _ = SGD(0.001)
test/core/test_trainer.py
@@ -4,3 +4,4 @@ import unittest
 class TestTrainer(unittest.TestCase):
     def test_case_1(self):
         pass
+