[update] distributed trainer, add evaluation part

yunfan 2019-07-20 17:00:50 +08:00
parent 606d63a5a4
commit 329a18976f
3 changed files with 82 additions and 22 deletions
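In short: DistTrainer gains dev_data and metrics arguments and, on the master process only, registers a TesterCallback that runs evaluation at each validation point and at the end of training. A minimal usage sketch, assuming the top-level fastNLP exports used in this commit's tests and the fastNLP.core.dist_trainer module path (train_set, dev_set, and model stand for any compatible DataSet/model pair):

    from fastNLP import AccuracyMetric, BCELoss
    from fastNLP.core.dist_trainer import DistTrainer

    trainer = DistTrainer(
        train_set, model,
        loss=BCELoss(pred="predict", target="y"),
        dev_data=dev_set,                                    # new in this commit
        metrics=AccuracyMetric(pred="predict", target="y"),  # new in this commit
        validate_every=-1,  # negative: evaluate at the end of every epoch
    )
    trainer.train()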

View File

@@ -79,6 +79,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
+import logging
 
 try:
     import fitlog
@@ -167,7 +168,11 @@ class Callback(object):
     @property
     def disabled(self):
         return self._disabled
+
+    @property
+    def logger(self):
+        return getattr(self._trainer, 'logger', logging)
     
     def on_train_begin(self):
         """
         Called before the training process begins.
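The new logger property falls back to the logging module whenever the callback is not yet bound to a trainer. A hedged sketch of a callback using it (MyCallback is illustrative and not part of this commit):

    class MyCallback(Callback):
        def on_epoch_end(self):
            # resolves to the trainer's logger once the callback is bound,
            # otherwise to the plain logging module
            self.logger.info('finished epoch %s of %s', self.epoch, self.n_epochs)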
@@ -316,21 +321,27 @@ class CallbackManager(Callback):
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
+        self._env = env
         self.callbacks = []
-        if callbacks is not None:
-            if isinstance(callbacks, list):
-                if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                    self.callbacks.extend(callbacks)
-                else:
-                    obj = [not isinstance(cb, Callback) for cb in callbacks][0]
-                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
-            else:
-                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
-        
-        for env_name, env_val in env.items():
-            for callback in self.callbacks:
-                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        if callbacks:
+            self.callbacks += self.prepare_callbacks(callbacks)
+
+    def prepare_callbacks(self, callbacks):
+        if not callbacks:
+            return []
+        if isinstance(callbacks, list):
+            if all([isinstance(cb, Callback) for cb in callbacks]) is True:
+                self.callbacks.extend(callbacks)
+            else:
+                obj = [not isinstance(cb, Callback) for cb in callbacks][0]
+                raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        else:
+            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+        
+        for env_name, env_val in self._env.items():
+            for callback in callbacks:
+                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        return callbacks
     
     @_transfer
     def on_train_begin(self):
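Net effect of prepare_callbacks: reject anything that is not a list of Callback instances, bind every env entry onto each callback as an underscored attribute, and return the list. A hedged sketch (EchoCallback comes from this commit's tests; trainer is any trainer object):

    manager = CallbackManager(env={'trainer': trainer},
                              callbacks=[EchoCallback('all')])
    # setattr(cb, '_trainer', trainer) has run for each callback, so the
    # Callback.trainer and Callback.logger properties now resolve
    assert manager.callbacks[0].trainer is trainer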
@@ -391,11 +402,12 @@ class CallbackManager(Callback):
 class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
         super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
         is_master = env['trainer'].is_master
         self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
-        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
+        self.callbacks_all = self.prepare_callbacks(callbacks_all)
+        self.callbacks_master = self.prepare_callbacks(callbacks_master)
         self.callbacks = self.callbacks_all + self.callbacks_master
     
     def patch_callback(self, callbacks, disabled):
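Both groups are now prepared by the same helper; master-only callbacks are disabled on non-master ranks via patch_callback instead of being wrapped in a second CallbackManager. Usage as exercised by run3 in this commit's tests:

    mgr = DistCallbackManager(
        env={'trainer': trainer},  # must contain the trainer (asserted above)
        callbacks_all=[EchoCallback('callbacks_all')],        # every process
        callbacks_master=[EchoCallback('callbacks_master')])  # master rank only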
@@ -944,5 +956,21 @@ class EchoCallback(Callback):
 
 class TesterCallback(Callback):
     def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
-        self.tester = Tester(data, model)
+        # TODO add compare & save best
+        super(TesterCallback, self).__init__()
+        self.tester = Tester(data, model,
+                             metrics=metrics, batch_size=batch_size,
+                             num_workers=num_workers, verbose=0)
+        self.score = None
+
+    def on_validation(self):
+        cur_score = self.tester.test()
+        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
+            self.epoch, self.n_epochs, self.step, self.n_steps,
+            self.tester._format_eval_results(cur_score))
+        self.logger.info(eval_str)
+
+    def on_train_end(self):
+        self.logger.info('Evaluate on training ends.')
+        self.on_validation()
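DistTrainer builds this callback itself when dev_data and metrics are passed (next file), but it can also be registered by hand as a master-only callback; a hedged sketch reusing names from the tests:

    cb = TesterCallback(dev_set, model,
                        AccuracyMetric(pred="predict", target="y"),
                        batch_size=32)
    trainer = DistTrainer(train_set, model,
                          loss=BCELoss(pred="predict", target="y"),
                          callbacks_master=[cb])  # evaluate only on master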

View File

@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timedelta
 
 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException
+from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
@@ -39,10 +39,13 @@ def get_local_rank():
 class DistTrainer():
     """Distributed Trainer that supports distributed and mixed precision training
     """
     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
                  num_data_workers=1, drop_last=False,
+                 dev_data=None, metrics=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_every=-1, save_path=None, device='auto',
                  fp16='', backend=None, init_method=None):
@@ -107,6 +110,14 @@ class DistTrainer():
         self.data_iterator = self._get_data_iter(self.train_data)
         self.n_steps = self._get_n_steps()
         
+        # for evaluation, only run eval on master proc
+        if dev_data and metrics:
+            cb = TesterCallback(
+                dev_data, model, metrics,
+                batch_size=batch_size_per_gpu, num_workers=num_data_workers)
+            self.callback_manager.callbacks_master += \
+                self.callback_manager.prepare_callbacks([cb])
+
         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
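Passing dev_data and metrics is thus shorthand for registering a master-only TesterCallback; a hedged equivalence sketch using the defaults above (batch_size_per_gpu=8, num_data_workers=1):

    # roughly equivalent under this commit:
    DistTrainer(train_set, model, loss=loss, dev_data=dev_set, metrics=metric)
    DistTrainer(train_set, model, loss=loss,
                callbacks_master=[TesterCallback(dev_set, model, metric,
                                                 batch_size=8, num_workers=1)])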
@@ -261,9 +272,6 @@ class DistTrainer():
             if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                     (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                            self.n_steps)
-                self.logger.info(eval_str)
                 self.callback_manager.on_validation()
                 dist.barrier()
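The trigger condition is unchanged; only the log message moved into TesterCallback.on_validation. A positive validate_every evaluates every that many steps, while a negative value fires on the last batch of each epoch (when self.step % len(data_iterator) == 0). For example:

    # validate_every=100 -> evaluate at global steps 100, 200, 300, ...
    # validate_every=-1  -> evaluate once per epoch, after the final batch
    DistTrainer(train_set, model, loss=loss, dev_data=dev_set,
                metrics=metric, validate_every=100)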

View File

@@ -13,6 +13,7 @@ import os
 import subprocess
 from argparse import ArgumentParser
 from fastNLP.core.callback import EchoCallback
+from fastNLP import AccuracyMetric
 
 def prepare_fake_dataset():
     mean = np.array([-3, -3])
@@ -106,15 +107,36 @@ class TestDistTrainer(unittest.TestCase):
             shutil.rmtree(self.save_path)
     
     def run3(self):
         set_rng_seed(100)
         data_set, model = prepare_env()
         trainer = DistTrainer(
-            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
+            data_set, model, optimizer=None,
+            loss=BCELoss(pred="predict", target="y"),
             n_epochs=3, print_every=50,
             callbacks_all=[EchoCallback('callbacks_all')],
             callbacks_master=[EchoCallback('callbacks_master')]
         )
         trainer.train()
+
+    def run4(self):
+        set_rng_seed(100)
+        data_set, model = prepare_env()
+        train_set, dev_set = data_set.split(0.3)
+        model = NaiveClassifier(2, 1)
+        trainer = DistTrainer(
+            train_set, model, optimizer=SGD(lr=0.1),
+            loss=BCELoss(pred="predict", target="y"),
+            batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
+            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+        )
+        trainer.train()
+        """
+        # should run correctly
+        """
     
     def run_dist(self, run_id):
         if torch.cuda.is_available():
             ngpu = min(2, torch.cuda.device_count())
@@ -133,6 +155,8 @@ class TestDistTrainer(unittest.TestCase):
     def test_callback(self):
         self.run_dist(3)
     
+    def test_dev_data(self):
+        self.run_dist(4)
+
 if __name__ == '__main__':
     runner = TestDistTrainer()
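run_dist launches this script in worker subprocesses via torch.distributed.launch when GPUs are available (the body of run_dist and the rest of the __main__ block are truncated in this diff). An assumed manual invocation that would exercise the new run4 path:

    import subprocess, sys
    # assumed to mirror what run_dist builds; the '--test 4' flag (parsed by
    # the ArgumentParser imported above) is assumed to select run4 per worker
    subprocess.check_call([
        sys.executable, '-m', 'torch.distributed.launch',
        '--nproc_per_node=2', 'test_dist_trainer.py', '--test', '4'])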