Mirror of https://gitee.com/fastnlp/fastNLP.git (synced 2024-12-03 04:37:37 +08:00)
[add] logger in trainer
commit 287019450e
parent 89142d9dc5
@@ -164,7 +164,7 @@ class Callback(object):
 
     @property
     def is_master(self):
-        return self._trainer.is_master()
+        return self._trainer.is_master
 
     @property
     def disabled(self):
@@ -172,7 +172,7 @@ class Callback(object):
 
     @property
    def logger(self):
-        return getattr(self._trainer, 'logger', logging)
+        return getattr(self._trainer, 'logger', logging.getLogger(__name__))
 
     def on_train_begin(self):
        """
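Note (not part of the diff): with this change a callback's logger property resolves to the trainer's logger when a trainer is attached, and otherwise to a plain module-level logger, so self.logger.info(...) is always safe to call. A minimal, self-contained sketch of that fallback pattern; the class and attribute names below are illustrative only:

# Hypothetical sketch of the getattr-based fallback used above; _Cb and its
# _trainer attribute are made up for illustration, they are not fastNLP classes.
import logging

logging.basicConfig(level=logging.INFO)

class _Cb:
    _trainer = None  # no trainer attached yet

    @property
    def logger(self):
        # prefer the trainer's logger, fall back to a module-level logger
        return getattr(self._trainer, 'logger', logging.getLogger(__name__))

cb = _Cb()
cb.logger.info("no trainer attached, using the module-level logger")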
@@ -405,11 +405,11 @@ class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
         super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
-        is_master = env['trainer'].is_master
-        self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = self.prepare_callbacks(callbacks_all)
-        self.callbacks_master = self.prepare_callbacks(callbacks_master)
-        self.callbacks = self.callbacks_all + self.callbacks_master
+        self._trainer = env['trainer']
+        self.callbacks_master = []
+        self.callbacks_all = []
+        self.add_callback(callbacks_all, master=False)
+        self.add_callback(callbacks_master, master=True)
 
     def patch_callback(self, callbacks, disabled):
         if not callbacks:
@@ -419,6 +419,14 @@ class DistCallbackManager(CallbackManager):
         for cb in callbacks:
             cb._disabled = disabled
 
+    def add_callback(self, cb, master=False):
+        if master:
+            self.patch_callback(cb, not self.is_master)
+            self.callbacks_master += self.prepare_callbacks(cb)
+        else:
+            self.callbacks_all += self.prepare_callbacks(cb)
+        self.callbacks = self.callbacks_all + self.callbacks_master
+
 
 class GradientClipCallback(Callback):
     """
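Note (not part of the diff): add_callback centralizes registration — callbacks added with master=True are patched so they are disabled on non-master processes, while master=False callbacks run everywhere. A hedged usage sketch; the trainer object and the two callback subclasses below are assumed for illustration and are not defined by this commit:

# Hypothetical usage sketch; assumes `trainer` is an already-constructed DistTrainer
# whose callback_manager is a DistCallbackManager.
from fastNLP.core.callback import Callback

class EveryProcessCb(Callback):
    def on_epoch_end(self):
        self.logger.info("epoch finished on this process")

class MasterOnlyCb(Callback):
    def on_train_end(self):
        self.logger.info("training finished (logged on the master only)")

trainer.callback_manager.add_callback([EveryProcessCb()], master=False)
trainer.callback_manager.add_callback([MasterOnlyCb()], master=True)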
@@ -1048,15 +1056,26 @@ class TesterCallback(Callback):
             self.score = cur_score
         return cur_score, is_better
 
+    def _get_score(self, metric_dict, key):
+        for metric in metric_dict.items():
+            if key in metric:
+                return metric[key]
+        return None
+
     def compare_better(self, a):
         if self.score is None:
             return True
+        if self.metric_key is None:
+            self.metric_key = list(list(self.score.values())[0].keys())[0]
         k = self.metric_key
-        is_increase = self.score[k] <= a[k]  # if equal, prefer more recent results
+        score = self._get_score(self.score, k)
+        new_score = self._get_score(a, k)
+        if score is None or new_score is None:
+            return False
         if self.increase_better:
-            return is_increase
+            return score <= new_score
         else:
-            return not is_increase
+            return score >= new_score
 
     def on_train_end(self):
         self.logger.info('Evaluate on training ends.')
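Note (not part of the diff): evaluation results arrive as a nested dict keyed by metric name, and compare_better now looks the metric key up inside those per-metric results instead of indexing the outer dict directly. A simplified, self-contained illustration of that comparison; the metric names and values below are made up and this is not the fastNLP method itself:

# Hypothetical standalone illustration of the new comparison logic.
old_res = {'AccuracyMetric': {'acc': 0.91}}
new_res = {'AccuracyMetric': {'acc': 0.93}}

def get_score(metric_dict, key):
    # look the key up inside each per-metric result dict
    for res in metric_dict.values():
        if key in res:
            return res[key]
    return None

increase_better = True
old_v, new_v = get_score(old_res, 'acc'), get_score(new_res, 'acc')
is_better = (old_v <= new_v) if increase_better else (old_v >= new_v)
print(is_better)  # True: 0.93 improves on (or ties) 0.91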
@@ -22,6 +22,7 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
+from ..io.logger import initLogger
 from pkg_resources import parse_version
 
 __all__ = [
@@ -40,7 +41,7 @@ def get_local_rank():
     if 'local_rank' in args and args.local_rank:
         os.environ['LOCAL_RANK'] = str(args.local_rank)  # for multiple calls for this function
         return args.local_rank
-    raise RuntimeError('Please use "python -m torch.distributed.launch train_script.py')
+    raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py')
 
 
 class DistTrainer():
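Note (not part of the diff): the updated error message spells out the expected launch command, python -m torch.distributed.launch --nproc_per_node=N train_script.py, which starts one worker per GPU and passes each worker its --local_rank. A small, hypothetical sketch of how a launched worker sees that rank (argument handling only, no distributed setup, and not the fastNLP function itself):

# Hypothetical sketch mirroring the idea of get_local_rank(), without importing fastNLP.
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=None)  # injected by torch.distributed.launch
args, _ = parser.parse_known_args()
if args.local_rank is not None:
    os.environ['LOCAL_RANK'] = str(args.local_rank)  # cache it for later calls
print("this worker's local rank:", os.environ.get('LOCAL_RANK', '0'))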
@@ -50,7 +51,7 @@ class DistTrainer():
     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
-                 num_data_workers=1, drop_last=False,
+                 num_workers=1, drop_last=False,
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  log_path=None,
@@ -78,7 +79,7 @@ class DistTrainer():
         self.train_data = train_data
         self.batch_size_per_gpu = int(batch_size_per_gpu)
         self.n_epochs = int(n_epochs)
-        self.num_data_workers = int(num_data_workers)
+        self.num_data_workers = int(num_workers)
         self.drop_last = drop_last
         self.update_every = int(update_every)
         self.print_every = int(print_every)
@@ -127,9 +128,8 @@ class DistTrainer():
         if dev_data and metrics:
             cb = TesterCallback(
                 dev_data, model, metrics,
-                batch_size=batch_size_per_gpu, num_workers=num_data_workers)
-            self.callback_manager.callbacks_master += \
-                self.callback_manager.prepare_callbacks([cb])
+                batch_size=batch_size_per_gpu, num_workers=num_workers)
+            self.callback_manager.add_callback([cb], master=True)
 
         # Setup logging
         dist.barrier()
@@ -140,10 +140,7 @@ class DistTrainer():
             self.cp_save_path = None
 
         # use INFO in the master, WARN for others
-        logging.basicConfig(filename=log_path,
-                            format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                            datefmt='%m/%d/%Y %H:%M:%S',
-                            level=logging.INFO if self.is_master else logging.WARN)
+        initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
         self.logger = logging.getLogger(__name__)
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
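Note (not part of the diff): the per-trainer logging.basicConfig(...) call is replaced by the shared initLogger helper imported from ..io.logger; the diff only shows its call sites, initLogger(log_path, level=...) and initLogger(log_path). A minimal sketch of what such a helper might do, assuming it configures the root logger with the same format as the removed basicConfig call plus an optional file handler — the actual fastNLP implementation may differ:

# Hypothetical sketch of an initLogger-style helper; not the fastNLP code.
import logging

def init_logger(log_path=None, level=logging.INFO):
    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S')
    root = logging.getLogger()
    root.setLevel(level)
    handlers = [logging.StreamHandler()]
    if log_path is not None:
        handlers.append(logging.FileHandler(log_path))  # also write log records to the given file
    for h in handlers:
        h.setFormatter(fmt)
        root.addHandler(h)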
@@ -284,18 +281,8 @@ class DistTrainer():
 
                 self.callback_manager.on_batch_end()
 
-                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
-                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                    self.callback_manager.on_valid_begin()
-                    eval_res = self.callback_manager.on_validation()
-                    eval_res = list(filter(lambda x: x is not None, eval_res))
-                    if len(eval_res):
-                        eval_res, is_better = list(zip(*eval_res))
-                    else:
-                        eval_res, is_better = None, None
-                    self.callback_manager.on_valid_end(
-                        eval_res, self.metric_key, self.optimizer, is_better)
-                    dist.barrier()
+                if (self.validate_every > 0 and self.step % self.validate_every == 0):
+                    self._do_validation()
 
                 if self.cp_save_path and \
                         self.save_every > 0 and \
@@ -303,6 +290,9 @@ class DistTrainer():
                     self.save_check_point()
 
             # ================= mini-batch end ==================== #
+            if self.validate_every < 0:
+                self._do_validation()
+
             if self.save_every < 0 and self.cp_save_path:
                 self.save_check_point()
             # lr decay; early stopping
@@ -351,5 +341,17 @@ class DistTrainer():
             model_to_save = model_to_save.state_dict()
         torch.save(model_to_save, path)
 
+    def _do_validation(self):
+        self.callback_manager.on_valid_begin()
+        eval_res = self.callback_manager.on_validation()
+        eval_res = list(filter(lambda x: x is not None, eval_res))
+        if len(eval_res):
+            eval_res, is_better = list(zip(*eval_res))
+        else:
+            eval_res, is_better = None, None
+        self.callback_manager.on_valid_end(
+            eval_res, self.metric_key, self.optimizer, is_better)
+        dist.barrier()
+
     def close(self):
         dist.destroy_process_group()
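Note (not part of the diff): the inline validation block is factored out into _do_validation, and the schedule is simplified — a positive validate_every validates every that many steps inside the batch loop, while a negative value (the default, -1, per the __init__ signature above) now validates once at the end of each epoch. A tiny, hypothetical illustration of the schedule only, with no real training involved:

# Hypothetical illustration of the validation schedule.
validate_every = -1      # negative: validate at the end of each epoch
steps_per_epoch = 4

for epoch in range(1, 3):
    for step in range(1, steps_per_epoch + 1):
        if validate_every > 0 and step % validate_every == 0:
            print("validate at epoch", epoch, "step", step)
    if validate_every < 0:
        print("validate at end of epoch", epoch)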
@@ -353,6 +353,8 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _model_contains_inner_module
+from ..io.logger import initLogger
+import logging
 
 
 class Trainer(object):
@@ -547,6 +549,12 @@ class Trainer(object):
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
 
+        log_path = None
+        if save_path is not None:
+            log_path = os.path.join(os.path.dirname(save_path), 'log')
+        initLogger(log_path)
+        self.logger = logging.getLogger(__name__)
+
         self.use_tqdm = use_tqdm
         self.pbar = None
         self.print_every = abs(self.print_every)
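Note (not part of the diff): the single-process Trainer now derives its log location from save_path, placing a log directory next to the checkpoint location, and keeps a self.logger for the print calls converted below. A small, self-contained illustration of that path derivation; the example save_path value is assumed:

# Hypothetical illustration of how log_path is derived from save_path.
import os

save_path = './checkpoints/best_model'   # assumed example value
log_path = None
if save_path is not None:
    log_path = os.path.join(os.path.dirname(save_path), 'log')
print(log_path)                          # ./checkpoints/log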
@@ -588,7 +596,7 @@ class Trainer(object):
         """
         results = {}
         if self.n_epochs <= 0:
-            print(f"training epoch is {self.n_epochs}, nothing was done.")
+            self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.")
             results['seconds'] = 0.
             return results
         try:
@@ -597,7 +605,7 @@ class Trainer(object):
             self._load_best_model = load_best_model
             self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
             start_time = time.time()
-            print("training epochs started " + self.start_time, flush=True)
+            self.logger.info("training epochs started " + self.start_time)
 
             try:
                 self.callback_manager.on_train_begin()
@@ -613,7 +621,7 @@ class Trainer(object):
                     raise e
 
             if self.dev_data is not None and self.best_dev_perf is not None:
-                print(
+                self.logger.info(
                     "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                     self.tester._format_eval_results(self.best_dev_perf), )
                 results['best_eval'] = self.best_dev_perf
@@ -623,9 +631,9 @@ class Trainer(object):
                     model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
                     load_succeed = self._load_model(self.model, model_name)
                     if load_succeed:
-                        print("Reloaded the best model.")
+                        self.logger.info("Reloaded the best model.")
                     else:
-                        print("Fail to reload best model.")
+                        self.logger.info("Fail to reload best model.")
         finally:
             pass
         results['seconds'] = round(time.time() - start_time, 2)
@@ -825,12 +833,12 @@ class Trainer(object):
                 self.best_metric_indicator = indicator_val
             else:
                 if self.increase_better is True:
-                    if indicator_val > self.best_metric_indicator:
+                    if indicator_val >= self.best_metric_indicator:
                         self.best_metric_indicator = indicator_val
                     else:
                         is_better = False
                 else:
-                    if indicator_val < self.best_metric_indicator:
+                    if indicator_val <= self.best_metric_indicator:
                         self.best_metric_indicator = indicator_val
                     else:
                         is_better = False
@@ -17,6 +17,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from typing import List
+import logging
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
@@ -659,15 +660,14 @@ class _pseudo_tqdm:
     """
     当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
     """
-
     def __init__(self, **kwargs):
-        pass
+        self.logger = logging.getLogger()
 
     def write(self, info):
-        print(info)
+        self.logger.info(info)
 
     def set_postfix_str(self, info):
-        print(info)
+        self.logger.info(info)
 
     def __getattr__(self, item):
         def pass_func(*args, **kwargs):
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from fastNLP.embeddings.utils import get_embeddings
 from fastNLP.core import Const as C
 
 
@@ -64,7 +63,8 @@ class RegionEmbedding(nn.Module):
             kernel_sizes = [5, 9]
         assert isinstance(
             kernel_sizes, list), 'kernel_sizes should be List(int)'
-        self.embed = get_embeddings(init_embed)
+        # self.embed = nn.Embedding.from_pretrained(torch.tensor(init_embed).float(), freeze=False)
+        self.embed = init_embed
         try:
             embed_dim = self.embed.embedding_dim
         except Exception:
@@ -13,10 +13,11 @@ from fastNLP.core.sampler import BucketSampler
 from fastNLP.core import LRScheduler
 from fastNLP.core.const import Const as C
 from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP.core.dist_trainer import DistTrainer
 from utils.util_init import set_rng_seeds
 import os
-os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
-os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
 
@@ -64,27 +65,28 @@ def load_data():
         ds.apply_field(len, C.INPUT, C.INPUT_LEN)
         ds.set_input(C.INPUT, C.INPUT_LEN)
         ds.set_target(C.TARGET)
-    embedding = StaticEmbedding(
-        datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad,
-        normalize=False
-    )
-    return datainfo, embedding
+    return datainfo
 
 
-datainfo, embedding = load_data()
+datainfo = load_data()
+embedding = StaticEmbedding(
+    datainfo.vocabs['words'], model_dir_or_name='en-glove-6b-100d', requires_grad=ops.embedding_grad,
+    normalize=False)
 embedding.embedding.weight.data /= embedding.embedding.weight.data.std()
-print(embedding.embedding.weight.mean(), embedding.embedding.weight.std())
+print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.std())
 
 # 2.或直接复用fastNLP的模型
 
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
+datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
+datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
 print(datainfo)
 print(datainfo.datasets['train'][0])
 
 model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
               embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
-print(model)
+# print(model)
 
 # 3. 声明loss,metric,optimizer
 loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
@@ -109,13 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 print(device)
 
 # 4.定义train方法
-trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
-                  metrics=[metric],
-                  dev_data=datainfo.datasets['test'], device=device,
-                  check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
-                  n_epochs=ops.train_epoch, num_workers=4)
+# trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+#                   sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
+#                   metrics=[metric],
+#                   dev_data=datainfo.datasets['test'], device=device,
+#                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+#                   n_epochs=ops.train_epoch, num_workers=4)
+trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                      metrics=[metric],
+                      dev_data=datainfo.datasets['test'], device='cuda',
+                      batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
+                      n_epochs=ops.train_epoch, num_workers=4)
 
 
 if __name__ == "__main__":