mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
[update] distributed trainer, add evaluation part
This commit is contained in:
parent 606d63a5a4
commit 329a18976f
fastNLP/core/callback.py

@@ -79,6 +79,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
+from .tester import Tester
 import logging
 
 try:
     import fitlog
@@ -167,7 +168,11 @@ class Callback(object):
     @property
     def disabled(self):
         return self._disabled
 
+    @property
+    def logger(self):
+        return getattr(self._trainer, 'logger', logging)
+
     def on_train_begin(self):
         """
         Called before the training process begins.
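The new logger property lets any callback log through its bound trainer and fall back to the stdlib logging module before binding. A minimal sketch, assuming fastNLP is installed (EpochLogger and its message are illustrative, not part of the commit):

from fastNLP.core.callback import Callback

class EpochLogger(Callback):
    """Illustrative callback: logs via the trainer's logger when bound,
    via the stdlib logging module otherwise."""
    def on_epoch_end(self):
        # self.logger resolves to getattr(self._trainer, 'logger', logging)
        self.logger.info('finished epoch %d', self.epoch)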
@@ -316,21 +321,27 @@ class CallbackManager(Callback):
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
 
         self._env = env
         self.callbacks = []
-        if callbacks is not None:
-            if isinstance(callbacks, list):
-                if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                    self.callbacks.extend(callbacks)
-                else:
-                    obj = [not isinstance(cb, Callback) for cb in callbacks][0]
-                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
-            else:
-                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
-
-        for env_name, env_val in env.items():
-            for callback in self.callbacks:
-                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        if callbacks:
+            self.callbacks += self.prepare_callbacks(callbacks)
+
+    def prepare_callbacks(self, callbacks):
+        if not callbacks:
+            return []
+        if isinstance(callbacks, list):
+            if all([isinstance(cb, Callback) for cb in callbacks]) is True:
+                self.callbacks.extend(callbacks)
+            else:
+                obj = [not isinstance(cb, Callback) for cb in callbacks][0]
+                raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        else:
+            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+
+        for env_name, env_val in self._env.items():
+            for callback in callbacks:
+                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        return callbacks
 
     @_transfer
     def on_train_begin(self):
@@ -391,11 +402,12 @@ class CallbackManager(Callback):
 
 class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
         super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
         is_master = env['trainer'].is_master
+        self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
-        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
+        self.callbacks_all = self.prepare_callbacks(callbacks_all)
+        self.callbacks_master = self.prepare_callbacks(callbacks_master)
         self.callbacks = self.callbacks_all + self.callbacks_master
 
     def patch_callback(self, callbacks, disabled):
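With prepare_callbacks() factored out, DistCallbackManager no longer builds throwaway CallbackManager instances just to validate and bind its callback lists. A minimal sketch of the binding behavior, assuming fastNLP is installed (PrintCallback and the SimpleNamespace stand-in are illustrative, not part of the commit):

from types import SimpleNamespace
from fastNLP.core.callback import Callback, CallbackManager

class PrintCallback(Callback):
    def on_train_begin(self):
        print('bound trainer:', self._trainer)

# prepare_callbacks() type-checks the list and copies every env entry onto
# each callback as an underscore-prefixed attribute (here: cb._trainer).
dummy_trainer = SimpleNamespace(logger=None)  # stand-in for a real Trainer
manager = CallbackManager(env={'trainer': dummy_trainer},
                          callbacks=[PrintCallback()])
assert manager.callbacks[0]._trainer is dummy_trainer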
@@ -944,5 +956,21 @@ class EchoCallback(Callback):
 
 
+class TesterCallback(Callback):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
+        # TODO add compare & save best
+        super(TesterCallback, self).__init__()
+        self.tester = Tester(data, model,
+                             metrics=metrics, batch_size=batch_size,
+                             num_workers=num_workers, verbose=0)
+        self.score = None
+
+    def on_validation(self):
+        cur_score = self.tester.test()
+        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
+            self.epoch, self.n_epochs, self.step, self.n_steps,
+            self.tester._format_eval_results(cur_score))
+        self.logger.info(eval_str)
+
+    def on_train_end(self):
+        self.logger.info('Evaluate on training ends.')
+        self.on_validation()
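TesterCallback wraps a fastNLP Tester so evaluation runs inside the training loop. For comparison, a minimal sketch of the standalone equivalent of on_validation(), assuming a trained model and a dev_data DataSet are already in scope (both names are placeholders):

from fastNLP import AccuracyMetric, Tester

# What TesterCallback.on_validation() automates: run the Tester quietly and
# format its result dict (e.g. {'AccuracyMetric': {'acc': 0.9}}) for the log.
tester = Tester(dev_data, model,
                metrics=AccuracyMetric(pred="predict", target="y"),
                batch_size=16, verbose=0)
results = tester.test()
print(tester._format_eval_results(results))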
fastNLP/core/dist_trainer.py

@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timedelta
 
 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException
+from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
@@ -39,10 +39,13 @@ def get_local_rank():
 
 
 class DistTrainer():
     """Distributed Trainer that supports distributed and mixed precision training
     """
     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
                  num_data_workers=1, drop_last=False,
+                 dev_data=None, metrics=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_every=-1, save_path=None, device='auto',
                  fp16='', backend=None, init_method=None):
@@ -107,6 +110,14 @@ class DistTrainer():
         self.data_iterator = self._get_data_iter(self.train_data)
         self.n_steps = self._get_n_steps()
 
+        # for evaluation, only run eval on master proc
+        if dev_data and metrics:
+            cb = TesterCallback(
+                dev_data, model, metrics,
+                batch_size=batch_size_per_gpu, num_workers=num_data_workers)
+            self.callback_manager.callbacks_master += \
+                self.callback_manager.prepare_callbacks([cb])
+
         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
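With this wiring in place, passing dev_data and metrics to DistTrainer is all it takes to get periodic evaluation on the master process. A usage sketch, assuming a prepared train/dev split and model (train_set, dev_set, and model are placeholders; the import path for DistTrainer is inferred from the package layout; see run4 in the tests below for the concrete version):

from fastNLP import AccuracyMetric, BCELoss
from fastNLP.core.dist_trainer import DistTrainer

trainer = DistTrainer(
    train_set, model, optimizer=None,
    loss=BCELoss(pred="predict", target="y"),
    dev_data=dev_set, metrics=AccuracyMetric(pred="predict", target="y"),
    validate_every=-1)  # validate_every=-1: evaluate once per epoch
trainer.train()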
@@ -261,9 +272,6 @@ class DistTrainer():
 
                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                     (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                                self.n_steps)
-                    self.logger.info(eval_str)
                     self.callback_manager.on_validation()
                     dist.barrier()
 
test_dist_trainer.py

@@ -13,6 +13,7 @@ import os
 import subprocess
 from argparse import ArgumentParser
 
 from fastNLP.core.callback import EchoCallback
+from fastNLP import AccuracyMetric
 
 def prepare_fake_dataset():
     mean = np.array([-3, -3])
@@ -106,15 +107,36 @@ class TestDistTrainer(unittest.TestCase):
             shutil.rmtree(self.save_path)
 
     def run3(self):
         set_rng_seed(100)
         data_set, model = prepare_env()
         trainer = DistTrainer(
-            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
+            data_set, model, optimizer=None,
+            loss=BCELoss(pred="predict", target="y"),
             n_epochs=3, print_every=50,
             callbacks_all=[EchoCallback('callbacks_all')],
             callbacks_master=[EchoCallback('callbacks_master')]
         )
         trainer.train()
 
+    def run4(self):
+        set_rng_seed(100)
+        data_set, model = prepare_env()
+
+        train_set, dev_set = data_set.split(0.3)
+
+        model = NaiveClassifier(2, 1)
+
+        trainer = DistTrainer(
+            train_set, model, optimizer=SGD(lr=0.1),
+            loss=BCELoss(pred="predict", target="y"),
+            batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
+            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+        )
+        trainer.train()
+        """
+        # Should run correctly
+        """
+
     def run_dist(self, run_id):
         if torch.cuda.is_available():
             ngpu = min(2, torch.cuda.device_count())
@@ -133,6 +155,8 @@ class TestDistTrainer(unittest.TestCase):
     def test_callback(self):
         self.run_dist(3)
 
+    def test_dev_data(self):
+        self.run_dist(4)
 
 if __name__ == '__main__':
     runner = TestDistTrainer()