Merge branch 'dev' of https://github.com/choosewhatulike/fastNLP-private into dev
commit e93c6f0053
@@ -18,7 +18,7 @@ from fastNLP.api.processor import IndexerProcessor

 # TODO add pretrain urls
 model_urls = {
     "cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
-    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl",
+    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
     "parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
 }
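The eight hex digits after the dash in each filename look like a content checksum, which is what makes caching a download by filename safe. A minimal cached-download sketch under that assumption (the helper name and cache directory are illustrative, not fastNLP's actual loader):

import os
import urllib.request

def cached_download(url, cache_dir="~/.fastNLP"):
    # reuse an earlier download if the checksum-suffixed filename is already cached
    cache_dir = os.path.expanduser(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, url.split("/")[-1])  # e.g. cws_crf_1_11-457fc899.pkl
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    return path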
@@ -16,6 +16,10 @@ def chinese_word_segmentation():


 def pos_tagging():
     # input: a sequence that has already been word-segmented
     text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
     text = [text[0].split()]
     print(text)
     pos = POS(device='cpu')
     print(pos.predict(text))
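For reference, the call pattern the demo relies on: POS.predict consumes a list of token lists, which is why the pre-segmented string is split on whitespace first. A minimal sketch, assuming the same POS class used above (the output shape is assumed, not shown in this diff):

words = ['今天 天气 真 好'.split()]  # one sentence as one list of tokens
pos = POS(device='cpu')
tags = pos.predict(words)           # expected: one tag sequence per input sentence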
@@ -26,4 +30,4 @@ def syntactic_parsing():


 if __name__ == "__main__":
-    syntactic_parsing()
+    pos_tagging()
@@ -1,3 +1,11 @@
+import os
+
+import torch
+from tensorboardX import SummaryWriter
+
+from fastNLP.io.model_io import ModelSaver, ModelLoader
+
+
 class Callback(object):
     """An Interface for all callbacks.

@@ -7,6 +15,7 @@ class Callback(object):

     def __init__(self):
         super(Callback, self).__init__()
+        self.trainer = None  # reassigned inside Trainer

     def before_train(self):
         # before the main training loop
@@ -315,6 +324,144 @@ class ControlC(Callback):
             raise exception  # re-raise unrecognized errors


+class SmoothValue(object):
+    def __init__(self, beta: float):
+        self.beta, self.n, self.mov_avg = beta, 0, 0
+        self.smooth = None
+
+    def add_value(self, val: float) -> None:
+        "Add `val` to calculate updated smoothed value."
+        self.n += 1
+        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
+        self.smooth = self.mov_avg / (1 - self.beta ** self.n)
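SmoothValue is the bias-corrected exponential moving average familiar from Adam: dividing by 1 - beta**n undoes the pull toward the zero initialization of mov_avg. A quick numeric check of the class above:

sv = SmoothValue(beta=0.8)
sv.add_value(1.0)
print(sv.smooth)  # 0.2 / (1 - 0.8) = 1.0: the startup bias is fully corrected
sv.add_value(2.0)
print(sv.smooth)  # (0.8 * 0.2 + 0.2 * 2.0) / (1 - 0.8 ** 2) = 0.56 / 0.36 ≈ 1.556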
+
+
+class LRFinder(Callback):
+    def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
+        """Find the best learning rate during the first epoch and apply it from the second epoch on.
+
+        :param n_batch: number of iterations in one epoch
+        :param start_lr: lower bound of the learning rate
+        :param end_lr: upper bound of the learning rate
+        """
+        super(LRFinder, self).__init__()
+        self.start_lr, self.end_lr = start_lr, end_lr
+        self.num_it = n_batch
+        self.stop = False
+        self.best_loss = 0.
+        self.best_lr = None
+        self.loss_history = []
+        self.smooth_value = SmoothValue(0.8)
+        self.opt = None
+        scale = (self.end_lr - self.start_lr) / self.num_it
+
+        self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it))
+        self.find = None
+        self.loader = ModelLoader()
+
+    def before_epoch(self, cur_epoch, total_epoch):
+        if cur_epoch == 1:
+            self.opt = self.trainer.optimizer  # pytorch optimizer
+            self.opt.param_groups[0]["lr"] = self.start_lr
+            # save model
+            ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
+            self.find = True
+
+    def before_backward(self, loss, model):
+        if self.find:
+            if torch.isnan(loss) or self.stop is True:
+                self.stop = True
+                return
+            loss_val = loss.detach().cpu().data
+            self.loss_history.append(loss_val)
+            self.smooth_value.add_value(loss_val)
+            if self.best_loss == 0. or self.smooth_value.smooth < self.best_loss:
+                self.best_loss = self.smooth_value.smooth
+                self.best_lr = self.opt.param_groups[0]["lr"]
+
+    def after_batch(self, *args):
+        if self.find:
+            lr = next(self.lr_gen, None)
+            if lr is None or self.stop is True or self.loss_history[-1] > 4 * self.best_loss:
+                self.stop = True
+                return
+            self.opt.param_groups[0]["lr"] = lr
+            # self.loader.load_pytorch(self.trainer.model, "tmp")
+
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+        if cur_epoch == 1:
+            self.opt.param_groups[0]["lr"] = self.best_lr
+            self.find = False
+            # reset model
+            ModelLoader().load_pytorch(self.trainer.model, "tmp")
+            print("Model reset. \nFind best lr={}".format(self.best_lr))
+
+
+class TensorboardCallback(Callback):
+    """
+    Accepts one or more of the following strings as options:
+    - "model"
+    - "loss"
+    - "metric"
+    """
+
+    def __init__(self, *options):
+        super(TensorboardCallback, self).__init__()
+        args = {"model", "loss", "metric"}
+        for opt in options:
+            if opt not in args:
+                raise ValueError("Unrecognized argument {}. Expect one of {}".format(opt, args))
+        self.options = options
+        self._summary_writer = None
+        self.graph_added = False
+
+    def before_train(self):
+        save_dir = self.trainer.save_path
+        if save_dir is None:
+            path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
+        else:
+            path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
+        self._summary_writer = SummaryWriter(path)
+
+    def before_batch(self, batch_x, batch_y, indices):
+        if "model" in self.options and self.graph_added is False:
+            # tensorboardX has a major bug here; drawing the model graph is disabled for now
+            # from fastNLP.core.utils import _build_args
+            # inputs = _build_args(self.trainer.model, **batch_x)
+            # args = tuple([value for value in inputs.values()])
+            # args = args[0] if len(args) == 1 else args
+            # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
+            self.graph_added = True
+
+    def before_backward(self, loss, model):
+        if "loss" in self.options:
+            self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)
+
+        if "model" in self.options:
+            for name, param in self.trainer.model.named_parameters():
+                if param.requires_grad:
+                    self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.trainer.step)
+                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step)
+                    self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
+                                                    global_step=self.trainer.step)
+
+    def after_valid(self, eval_result, metric_key, optimizer):
+        if "metric" in self.options:
+            for name, metric in eval_result.items():
+                for metric_key, metric_val in metric.items():
+                    self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
+                                                    global_step=self.trainer.step)
+
+    def after_train(self, model):
+        self._summary_writer.close()
+        del self._summary_writer
+
+    def on_exception(self, exception, model):
+        if hasattr(self, "_summary_writer"):
+            self._summary_writer.close()
+            del self._summary_writer
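A minimal usage sketch, mirroring the test added at the end of this commit: pass the quantities to log as strings, then inspect the run with tensorboard --logdir <save_path>.

trainer = Trainer(data_set, model,
                  loss=BCELoss(pred="predict", target="y"),
                  metrics=AccuracyMetric(pred="predict", target="y"),
                  dev_data=data_set,
                  callbacks=[TensorboardCallback("loss", "metric")])
trainer.train()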
+
+
+if __name__ == "__main__":
+    manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
+    manager.before_train(10, 11, 12)
@@ -5,7 +5,6 @@ from datetime import timedelta

 import numpy as np
 import torch
-from tensorboardX import SummaryWriter
 from torch import nn

 try:

@@ -34,8 +33,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
-                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), prefetch=False, use_tqdm=True,
+                 use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model

@@ -59,12 +58,7 @@ class Trainer(object):

                 metric_key="-PPL"  # language model gets better as perplexity gets smaller
         :param BaseSampler sampler: method used to generate batch data.
-        :param num_workers: int, number of worker processes used to prepare data. Defaults to 0 (main thread only).
-            Experimental; if the DataSet is large and batches are quick to prepare, multiprocessing may not help.
-        :param pin_memory: bool, defaults to False. When True, page-locked memory is used, which raises memory usage;
-            with plenty of RAM it can speed things up, and data is then moved to GPU with non_blocking=True by default.
-        :param timeout: float, greater than 0, only meaningful when num_workers > 0. Raises an error if no batch is
-            obtained within this time, which helps detect stalls in batch production.
+        :param prefetch: bool, whether to use an extra process to produce batch data.
         :param bool use_tqdm: whether to use tqdm to show train progress.
         :param callbacks: List[Callback]. Callbacks that hook into training; e.g. early stopping and negative
             sampling can be implemented through the callback mechanism.
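The prefetch flag collapses the num_workers/pin_memory/timeout experiment above into a single extra producer process. A sketch of the general idea only, not fastNLP's implementation (it assumes fork-style multiprocessing so the child inherits the batch generator):

from multiprocessing import Process, Queue

def _produce(batches, queue):
    for batch in batches:
        queue.put(batch)  # blocks while the consumer falls behind
    queue.put(None)       # sentinel: iteration finished

def prefetch(batches, depth=2):
    queue = Queue(maxsize=depth)
    worker = Process(target=_produce, args=(batches, queue))
    worker.daemon = True
    worker.start()
    while True:
        batch = queue.get()
        if batch is None:
            break
        yield batch
    worker.join()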
@@ -126,9 +120,7 @@ class Trainer(object):
         self.best_dev_step = None
         self.best_dev_perf = None
         self.sampler = sampler
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.timeout = timeout
+        self.prefetch = prefetch
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

         if isinstance(optimizer, torch.optim.Optimizer):

@@ -195,21 +187,9 @@ class Trainer(object):
         self._model_device = self.model.parameters().__next__().device
         self._mode(self.model, is_test=False)

-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
         start_time = time.time()
         print("training epochs started " + self.start_time, flush=True)
-        if self.save_path is None:
-            class psudoSW:
-                def __getattr__(self, item):
-                    def pass_func(*args, **kwargs):
-                        pass
-
-                    return pass_func
-
-            self._summary_writer = psudoSW()
-        else:
-            path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
-            self._summary_writer = SummaryWriter(path)

         try:
             self.callback_manager.before_train()

@@ -232,8 +212,7 @@ class Trainer(object):
             else:
                 print("Fail to reload best model.")
         finally:
-            self._summary_writer.close()
-            del self._summary_writer
+            pass
         results['seconds'] = round(time.time() - start_time, 2)

         return results
@@ -250,8 +229,7 @@ class Trainer(object):
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
-                                  keep_process=True)
+                                  prefetch=self.prefetch, device=self._model_device)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping

@@ -260,8 +238,6 @@ class Trainer(object):
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.before_batch(batch_x, batch_y, indices)
-                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
-                                           non_blocking=self.pin_memory)  # pin_memory, use non_blocking.
                 prediction = self._data_forward(self.model, batch_x)

                 # edit prediction

@@ -279,12 +255,6 @@ class Trainer(object):
                 # lr scheduler; lr_finder; one_cycle
                 self.callback_manager.after_step(self.optimizer)

-                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
-                for name, param in self.model.named_parameters():
-                    if param.requires_grad:
-                        self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
-                        # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
-                        # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                 if (self.step+1) % self.print_every == 0:
                     if self.use_tqdm:
                         print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)

@@ -319,10 +289,7 @@ class Trainer(object):

     def _do_validation(self, epoch, step):
         res = self.tester.test()
-        for name, metric in res.items():
-            for metric_key, metric_val in metric.items():
-                self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
-                                                global_step=self.step)
+
         if self._better_eval_result(res):
             if self.save_path is not None:
                 self._save_model(self.model,
@@ -14,7 +14,7 @@ from fastNLP.core.metrics import SpanFPreRecMetric
 from fastNLP.core.trainer import Trainer
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.models.sequence_modeling import AdvSeqLabel
-from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader
+from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import ModelProcessor, Index2WordProcessor


@@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id):
     return embedding_tensor


-def train(train_data_path, dev_data_path, checkpoint=None):
+def train(train_data_path, dev_data_path, checkpoint=None, save=None):
     # load config
     train_param = ConfigSection()
     model_param = ConfigSection()

@@ -44,9 +44,9 @@ def train(train_data_path, dev_data_path, checkpoint=None):

     # Data Loader
     print("loading training set...")
-    dataset = ConllxDataLoader().load(train_data_path)
+    dataset = ConllxDataLoader().load(train_data_path, return_dataset=True)
     print("loading dev set...")
-    dev_data = ConllxDataLoader().load(dev_data_path)
+    dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True)
     print(dataset)
     print("================= dataset ready =====================")

@@ -54,9 +54,9 @@ def train(train_data_path, dev_data_path, checkpoint=None):
     dev_data.rename_field("tag", "truth")

     vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
-    tag_proc = VocabIndexerProcessor("truth")
+    tag_proc = VocabIndexerProcessor("truth", is_input=True)
     seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)
-    set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth")
+    set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len")

     vocab_proc(dataset)
     tag_proc(dataset)

@@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None):
                           target="truth",
                           seq_lens="word_seq_origin_len"),
                       dev_data=dev_data, metric_key="f",
-                      use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117")
+                      use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
     trainer.train(load_best_model=True)

     # save model & pipeline

@@ -102,12 +102,12 @@ def train(train_data_path, dev_data_path, checkpoint=None):

     pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
     save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
-    torch.save(save_dict, "model_pp_0117.pkl")
+    torch.save(save_dict, os.path.join(save, "model_pp.pkl"))
     print("pipeline saved")


 def run_test(test_path):
-    test_data = ZhConllPOSReader().load(test_path)
+    test_data = ConllxDataLoader().load(test_path, return_dataset=True)

     with open("model_pp_0117.pkl", "rb") as f:
         save_dict = torch.load(f)

@@ -157,7 +157,7 @@ if __name__ == "__main__":
         # resume training: python train_pos_tag.py -c -cp ./save/best_model.pkl
         if args.checkpoint is None:
             raise RuntimeError("Please provide the checkpoint. -cp ")
-        train(args.train, args.dev, args.checkpoint)
+        train(args.train, args.dev, args.checkpoint, save=args.save)
     else:
         # train from scratch: python train_pos_tag.py
-        train(args.train, args.dev)
+        train(args.train, args.dev, save=args.save)
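The new save keyword implies a matching command-line flag; its definition sits outside this diff, but it would presumably look like the following (hypothetical, including the flag name and default):

parser.add_argument("--save", type=str, default=None,
                    help="directory for the trained model and pipeline")  # hypothetical: not shown in this diff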
@@ -1,3 +1,4 @@
+import time
 import unittest

 import numpy as np

@@ -8,7 +9,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.dataset import construct_dataset
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
-import time


 def generate_fake_dataset(num_samples=1000):
     """

@@ -161,12 +162,13 @@ class TestCase1(unittest.TestCase):
         dataset = generate_fake_dataset(num_samples)

         batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
-        for batch_x, batch_y in batch:
-            time.sleep(pause_seconds)
+        # OOM happens here
+        # for batch_x, batch_y in batch:
+        #     time.sleep(pause_seconds)

         num_workers = 2
         batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
                       pin_memory=True)
-        for batch_x, batch_y in batch:
-            time.sleep(pause_seconds)

+        # OOM happens here
+        # for batch_x, batch_y in batch:
+        #     time.sleep(pause_seconds)
@@ -3,7 +3,9 @@ import unittest
 import numpy as np
 import torch

-from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC
+from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
+    LRFinder, \
+    TensorboardCallback
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss

@@ -52,7 +54,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=30,
+                          n_epochs=20,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),

@@ -67,7 +69,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=20,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.01),

@@ -83,7 +85,7 @@ class TestCallback(unittest.TestCase):
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=5,
                           batch_size=32,
                           print_every=50,
                           optimizer=optimizer,

@@ -98,7 +100,7 @@ class TestCallback(unittest.TestCase):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=50,
+                          n_epochs=5,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),

@@ -106,3 +108,31 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[ControlC(False)])
         trainer.train()
+
+    def test_LRFinder(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          callbacks=[LRFinder(len(data_set) // 32)])
+        trainer.train()
+
+    def test_TensorboardCallback(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[TensorboardCallback("loss", "metric")])
+        trainer.train()