mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-02 20:27:35 +08:00
[add] logger in trainer
This commit is contained in:
parent
89142d9dc5
commit
287019450e
@ -164,7 +164,7 @@ class Callback(object):
|
||||
|
||||
@property
|
||||
def is_master(self):
|
||||
return self._trainer.is_master()
|
||||
return self._trainer.is_master
|
||||
|
||||
@property
|
||||
def disabled(self):
|
||||
@ -172,7 +172,7 @@ class Callback(object):
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return getattr(self._trainer, 'logger', logging)
|
||||
return getattr(self._trainer, 'logger', logging.getLogger(__name__))
|
||||
|
||||
def on_train_begin(self):
|
||||
"""
|
||||
@ -405,11 +405,11 @@ class DistCallbackManager(CallbackManager):
|
||||
def __init__(self, env, callbacks_all=None, callbacks_master=None):
|
||||
super(DistCallbackManager, self).__init__(env)
|
||||
assert 'trainer' in env
|
||||
is_master = env['trainer'].is_master
|
||||
self.patch_callback(callbacks_master, disabled=not is_master)
|
||||
self.callbacks_all = self.prepare_callbacks(callbacks_all)
|
||||
self.callbacks_master = self.prepare_callbacks(callbacks_master)
|
||||
self.callbacks = self.callbacks_all + self.callbacks_master
|
||||
self._trainer = env['trainer']
|
||||
self.callbacks_master = []
|
||||
self.callbacks_all = []
|
||||
self.add_callback(callbacks_all, master=False)
|
||||
self.add_callback(callbacks_master, master=True)
|
||||
|
||||
def patch_callback(self, callbacks, disabled):
|
||||
if not callbacks:
|
||||
@ -419,6 +419,14 @@ class DistCallbackManager(CallbackManager):
|
||||
for cb in callbacks:
|
||||
cb._disabled = disabled
|
||||
|
||||
def add_callback(self, cb, master=False):
|
||||
if master:
|
||||
self.patch_callback(cb, not self.is_master)
|
||||
self.callbacks_master += self.prepare_callbacks(cb)
|
||||
else:
|
||||
self.callbacks_all += self.prepare_callbacks(cb)
|
||||
self.callbacks = self.callbacks_all + self.callbacks_master
|
||||
|
||||
|
||||
class GradientClipCallback(Callback):
|
||||
"""
|
||||
@ -1048,15 +1056,26 @@ class TesterCallback(Callback):
|
||||
self.score = cur_score
|
||||
return cur_score, is_better
|
||||
|
||||
def _get_score(self, metric_dict, key):
|
||||
for metric in metric_dict.items():
|
||||
if key in metric:
|
||||
return metric[key]
|
||||
return None
|
||||
|
||||
def compare_better(self, a):
|
||||
if self.score is None:
|
||||
return True
|
||||
if self.metric_key is None:
|
||||
self.metric_key = list(list(self.score.values())[0].keys())[0]
|
||||
k = self.metric_key
|
||||
is_increase = self.score[k] <= a[k] # if equal, prefer more recent results
|
||||
score = self._get_score(self.score, k)
|
||||
new_score = self._get_score(a, k)
|
||||
if score is None or new_score is None:
|
||||
return False
|
||||
if self.increase_better:
|
||||
return is_increase
|
||||
return score <= new_score
|
||||
else:
|
||||
return not is_increase
|
||||
return score >= new_score
|
||||
|
||||
def on_train_end(self):
|
||||
self.logger.info('Evaluate on training ends.')
|
||||
|
@ -22,6 +22,7 @@ from .optimizer import Optimizer
|
||||
from .utils import _build_args
|
||||
from .utils import _move_dict_value_to_device
|
||||
from .utils import _get_func_signature
|
||||
from ..io.logger import initLogger
|
||||
from pkg_resources import parse_version
|
||||
|
||||
__all__ = [
|
||||
@ -40,7 +41,7 @@ def get_local_rank():
|
||||
if 'local_rank' in args and args.local_rank:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank) # for multiple calls for this function
|
||||
return args.local_rank
|
||||
raise RuntimeError('Please use "python -m torch.distributed.launch train_script.py')
|
||||
raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py')
|
||||
|
||||
|
||||
class DistTrainer():
|
||||
@ -50,7 +51,7 @@ class DistTrainer():
|
||||
def __init__(self, train_data, model, optimizer=None, loss=None,
|
||||
callbacks_all=None, callbacks_master=None,
|
||||
batch_size_per_gpu=8, n_epochs=1,
|
||||
num_data_workers=1, drop_last=False,
|
||||
num_workers=1, drop_last=False,
|
||||
dev_data=None, metrics=None, metric_key=None,
|
||||
update_every=1, print_every=10, validate_every=-1,
|
||||
log_path=None,
|
||||
@ -78,7 +79,7 @@ class DistTrainer():
|
||||
self.train_data = train_data
|
||||
self.batch_size_per_gpu = int(batch_size_per_gpu)
|
||||
self.n_epochs = int(n_epochs)
|
||||
self.num_data_workers = int(num_data_workers)
|
||||
self.num_data_workers = int(num_workers)
|
||||
self.drop_last = drop_last
|
||||
self.update_every = int(update_every)
|
||||
self.print_every = int(print_every)
|
||||
@ -127,9 +128,8 @@ class DistTrainer():
|
||||
if dev_data and metrics:
|
||||
cb = TesterCallback(
|
||||
dev_data, model, metrics,
|
||||
batch_size=batch_size_per_gpu, num_workers=num_data_workers)
|
||||
self.callback_manager.callbacks_master += \
|
||||
self.callback_manager.prepare_callbacks([cb])
|
||||
batch_size=batch_size_per_gpu, num_workers=num_workers)
|
||||
self.callback_manager.add_callback([cb], master=True)
|
||||
|
||||
# Setup logging
|
||||
dist.barrier()
|
||||
@ -140,10 +140,7 @@ class DistTrainer():
|
||||
self.cp_save_path = None
|
||||
|
||||
# use INFO in the master, WARN for others
|
||||
logging.basicConfig(filename=log_path,
|
||||
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%m/%d/%Y %H:%M:%S',
|
||||
level=logging.INFO if self.is_master else logging.WARN)
|
||||
initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger.info("Setup Distributed Trainer")
|
||||
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
|
||||
@ -284,18 +281,8 @@ class DistTrainer():
|
||||
|
||||
self.callback_manager.on_batch_end()
|
||||
|
||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
|
||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)):
|
||||
self.callback_manager.on_valid_begin()
|
||||
eval_res = self.callback_manager.on_validation()
|
||||
eval_res = list(filter(lambda x: x is not None, eval_res))
|
||||
if len(eval_res):
|
||||
eval_res, is_better = list(zip(*eval_res))
|
||||
else:
|
||||
eval_res, is_better = None, None
|
||||
self.callback_manager.on_valid_end(
|
||||
eval_res, self.metric_key, self.optimizer, is_better)
|
||||
dist.barrier()
|
||||
if (self.validate_every > 0 and self.step % self.validate_every == 0):
|
||||
self._do_validation()
|
||||
|
||||
if self.cp_save_path and \
|
||||
self.save_every > 0 and \
|
||||
@ -303,6 +290,9 @@ class DistTrainer():
|
||||
self.save_check_point()
|
||||
|
||||
# ================= mini-batch end ==================== #
|
||||
if self.validate_every < 0:
|
||||
self._do_validation()
|
||||
|
||||
if self.save_every < 0 and self.cp_save_path:
|
||||
self.save_check_point()
|
||||
# lr decay; early stopping
|
||||
@ -351,5 +341,17 @@ class DistTrainer():
|
||||
model_to_save = model_to_save.state_dict()
|
||||
torch.save(model_to_save, path)
|
||||
|
||||
def _do_validation(self):
|
||||
self.callback_manager.on_valid_begin()
|
||||
eval_res = self.callback_manager.on_validation()
|
||||
eval_res = list(filter(lambda x: x is not None, eval_res))
|
||||
if len(eval_res):
|
||||
eval_res, is_better = list(zip(*eval_res))
|
||||
else:
|
||||
eval_res, is_better = None, None
|
||||
self.callback_manager.on_valid_end(
|
||||
eval_res, self.metric_key, self.optimizer, is_better)
|
||||
dist.barrier()
|
||||
|
||||
def close(self):
|
||||
dist.destroy_process_group()
|
||||
|
@ -353,6 +353,8 @@ from .utils import _get_func_signature
|
||||
from .utils import _get_model_device
|
||||
from .utils import _move_model_to_device
|
||||
from ._parallel_utils import _model_contains_inner_module
|
||||
from ..io.logger import initLogger
|
||||
import logging
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
@ -547,6 +549,12 @@ class Trainer(object):
|
||||
else:
|
||||
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
|
||||
|
||||
log_path = None
|
||||
if save_path is not None:
|
||||
log_path = os.path.join(os.path.dirname(save_path), 'log')
|
||||
initLogger(log_path)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
self.use_tqdm = use_tqdm
|
||||
self.pbar = None
|
||||
self.print_every = abs(self.print_every)
|
||||
@ -588,7 +596,7 @@ class Trainer(object):
|
||||
"""
|
||||
results = {}
|
||||
if self.n_epochs <= 0:
|
||||
print(f"training epoch is {self.n_epochs}, nothing was done.")
|
||||
self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.")
|
||||
results['seconds'] = 0.
|
||||
return results
|
||||
try:
|
||||
@ -597,7 +605,7 @@ class Trainer(object):
|
||||
self._load_best_model = load_best_model
|
||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
|
||||
start_time = time.time()
|
||||
print("training epochs started " + self.start_time, flush=True)
|
||||
self.logger.info("training epochs started " + self.start_time)
|
||||
|
||||
try:
|
||||
self.callback_manager.on_train_begin()
|
||||
@ -613,7 +621,7 @@ class Trainer(object):
|
||||
raise e
|
||||
|
||||
if self.dev_data is not None and self.best_dev_perf is not None:
|
||||
print(
|
||||
self.logger.info(
|
||||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
|
||||
self.tester._format_eval_results(self.best_dev_perf), )
|
||||
results['best_eval'] = self.best_dev_perf
|
||||
@ -623,9 +631,9 @@ class Trainer(object):
|
||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
|
||||
load_succeed = self._load_model(self.model, model_name)
|
||||
if load_succeed:
|
||||
print("Reloaded the best model.")
|
||||
self.logger.info("Reloaded the best model.")
|
||||
else:
|
||||
print("Fail to reload best model.")
|
||||
self.logger.info("Fail to reload best model.")
|
||||
finally:
|
||||
pass
|
||||
results['seconds'] = round(time.time() - start_time, 2)
|
||||
@ -825,12 +833,12 @@ class Trainer(object):
|
||||
self.best_metric_indicator = indicator_val
|
||||
else:
|
||||
if self.increase_better is True:
|
||||
if indicator_val > self.best_metric_indicator:
|
||||
if indicator_val >= self.best_metric_indicator:
|
||||
self.best_metric_indicator = indicator_val
|
||||
else:
|
||||
is_better = False
|
||||
else:
|
||||
if indicator_val < self.best_metric_indicator:
|
||||
if indicator_val <= self.best_metric_indicator:
|
||||
self.best_metric_indicator = indicator_val
|
||||
else:
|
||||
is_better = False
|
||||
|
@ -17,6 +17,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
|
||||
'varargs'])
|
||||
@ -659,15 +660,14 @@ class _pseudo_tqdm:
|
||||
"""
|
||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
self.logger = logging.getLogger()
|
||||
|
||||
def write(self, info):
|
||||
print(info)
|
||||
self.logger.info(info)
|
||||
|
||||
def set_postfix_str(self, info):
|
||||
print(info)
|
||||
self.logger.info(info)
|
||||
|
||||
def __getattr__(self, item):
|
||||
def pass_func(*args, **kwargs):
|
||||
|
@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fastNLP.embeddings.utils import get_embeddings
|
||||
from fastNLP.core import Const as C
|
||||
|
||||
|
||||
@ -64,7 +63,8 @@ class RegionEmbedding(nn.Module):
|
||||
kernel_sizes = [5, 9]
|
||||
assert isinstance(
|
||||
kernel_sizes, list), 'kernel_sizes should be List(int)'
|
||||
self.embed = get_embeddings(init_embed)
|
||||
# self.embed = nn.Embedding.from_pretrained(torch.tensor(init_embed).float(), freeze=False)
|
||||
self.embed = init_embed
|
||||
try:
|
||||
embed_dim = self.embed.embedding_dim
|
||||
except Exception:
|
||||
|
@ -13,10 +13,11 @@ from fastNLP.core.sampler import BucketSampler
|
||||
from fastNLP.core import LRScheduler
|
||||
from fastNLP.core.const import Const as C
|
||||
from fastNLP.core.vocabulary import VocabularyOption
|
||||
from fastNLP.core.dist_trainer import DistTrainer
|
||||
from utils.util_init import set_rng_seeds
|
||||
import os
|
||||
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
|
||||
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
|
||||
# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
|
||||
# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
|
||||
|
||||
@ -64,27 +65,28 @@ def load_data():
|
||||
ds.apply_field(len, C.INPUT, C.INPUT_LEN)
|
||||
ds.set_input(C.INPUT, C.INPUT_LEN)
|
||||
ds.set_target(C.TARGET)
|
||||
embedding = StaticEmbedding(
|
||||
datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad,
|
||||
normalize=False
|
||||
)
|
||||
return datainfo, embedding
|
||||
|
||||
return datainfo
|
||||
|
||||
|
||||
datainfo, embedding = load_data()
|
||||
datainfo = load_data()
|
||||
embedding = StaticEmbedding(
|
||||
datainfo.vocabs['words'], model_dir_or_name='en-glove-6b-100d', requires_grad=ops.embedding_grad,
|
||||
normalize=False)
|
||||
embedding.embedding.weight.data /= embedding.embedding.weight.data.std()
|
||||
print(embedding.embedding.weight.mean(), embedding.embedding.weight.std())
|
||||
print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.std())
|
||||
|
||||
# 2.或直接复用fastNLP的模型
|
||||
|
||||
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
|
||||
|
||||
datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
|
||||
datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
|
||||
print(datainfo)
|
||||
print(datainfo.datasets['train'][0])
|
||||
|
||||
model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
|
||||
embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
|
||||
print(model)
|
||||
# print(model)
|
||||
|
||||
# 3. 声明loss,metric,optimizer
|
||||
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
|
||||
@ -109,13 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
print(device)
|
||||
|
||||
# 4.定义train方法
|
||||
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
|
||||
sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
|
||||
metrics=[metric],
|
||||
dev_data=datainfo.datasets['test'], device=device,
|
||||
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
|
||||
n_epochs=ops.train_epoch, num_workers=4)
|
||||
|
||||
# trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
|
||||
# sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
|
||||
# metrics=[metric],
|
||||
# dev_data=datainfo.datasets['test'], device=device,
|
||||
# check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
|
||||
# n_epochs=ops.train_epoch, num_workers=4)
|
||||
trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
|
||||
metrics=[metric],
|
||||
dev_data=datainfo.datasets['test'], device='cuda',
|
||||
batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
|
||||
n_epochs=ops.train_epoch, num_workers=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user