Merge pull request #60 from KuNyaa/master

add tensorboardX for loss visualization
This commit is contained in:
Coet 2018-09-06 09:56:14 +08:00 committed by GitHub
commit 49ad966c5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 0 deletions

View File

@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
- numpy>=1.14.2
- torch==0.4.0
- torchvision>=0.1.8
- tensorboardX
## Resources
@ -47,6 +48,11 @@ conda install pytorch torchvision -c pytorch
pip3 install torch torchvision
```
### TensorboardX Installation
```shell
pip3 install tensorboardX
```
## Project Structure

View File

@ -4,6 +4,8 @@ import time
from datetime import timedelta
import torch
import tensorboardX
from tensorboardX import SummaryWriter
from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
@ -86,6 +88,8 @@ class BaseTrainer(object):
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
self._optimizer = None
self._optimizer_proto = default_args["optimizer"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
def train(self, network, train_data, dev_data=None):
"""General Training Procedure
@ -160,6 +164,11 @@ class BaseTrainer(object):
loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
if not self._graph_summaried:
self._summary_writer.add_graph(network, batch_x)
self._graph_summaried = True
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
end = time.time()

View File

@ -1,3 +1,4 @@
numpy>=1.14.2
torch==0.4.0
torchvision>=0.1.8
tensorboardX