Merge pull request #60 from KuNyaa/master

add tensorboardX for loss visualization
This commit is contained in:
Coet 2018-09-06 09:56:14 +08:00 committed by GitHub
commit 49ad966c5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 0 deletions

View File

@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
- numpy>=1.14.2 - numpy>=1.14.2
- torch==0.4.0 - torch==0.4.0
- torchvision>=0.1.8 - torchvision>=0.1.8
- tensorboardX
## Resources ## Resources
@ -47,6 +48,11 @@ conda install pytorch torchvision -c pytorch
pip3 install torch torchvision pip3 install torch torchvision
``` ```
### TensorboardX Installation
```shell
pip3 install tensorboardX
```
## Project Structure ## Project Structure

View File

@ -4,6 +4,8 @@ import time
from datetime import timedelta from datetime import timedelta
import torch import torch
import tensorboardX
from tensorboardX import SummaryWriter
from fastNLP.core.action import Action from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier from fastNLP.core.action import RandomSampler, Batchifier
@ -86,6 +88,8 @@ class BaseTrainer(object):
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
self._optimizer = None self._optimizer = None
self._optimizer_proto = default_args["optimizer"] self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
def train(self, network, train_data, dev_data=None): def train(self, network, train_data, dev_data=None):
"""General Training Procedure """General Training Procedure
@ -160,6 +164,11 @@ class BaseTrainer(object):
loss = self.get_loss(prediction, batch_y) loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss) self.grad_backward(loss)
self.update() self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
if not self._graph_summaried:
self._summary_writer.add_graph(network, batch_x)
self._graph_summaried = True
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time() end = time.time()

View File

@ -1,3 +1,4 @@
numpy>=1.14.2 numpy>=1.14.2
torch==0.4.0 torch==0.4.0
torchvision>=0.1.8 torchvision>=0.1.8
tensorboardX