mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-29 18:59:01 +08:00
Merge pull request #60 from KuNyaa/master
add tensorboardX for loss visualization
This commit is contained in:
commit
49ad966c5f
@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
|
||||
- numpy>=1.14.2
|
||||
- torch==0.4.0
|
||||
- torchvision>=0.1.8
|
||||
- tensorboardX
|
||||
|
||||
|
||||
## Resources
|
||||
@ -47,6 +48,11 @@ conda install pytorch torchvision -c pytorch
|
||||
pip3 install torch torchvision
|
||||
```
|
||||
|
||||
### TensorboardX Installation
|
||||
|
||||
```shell
|
||||
pip3 install tensorboardX
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
|
@ -4,6 +4,8 @@ import time
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
import tensorboardX
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from fastNLP.core.action import Action
|
||||
from fastNLP.core.action import RandomSampler, Batchifier
|
||||
@ -86,6 +88,8 @@ class BaseTrainer(object):
|
||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
|
||||
self._optimizer = None
|
||||
self._optimizer_proto = default_args["optimizer"]
|
||||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
|
||||
self._graph_summaried = False
|
||||
|
||||
def train(self, network, train_data, dev_data=None):
|
||||
"""General Training Procedure
|
||||
@ -160,6 +164,11 @@ class BaseTrainer(object):
|
||||
loss = self.get_loss(prediction, batch_y)
|
||||
self.grad_backward(loss)
|
||||
self.update()
|
||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
|
||||
|
||||
if not self._graph_summaried:
|
||||
self._summary_writer.add_graph(network, batch_x)
|
||||
self._graph_summaried = True
|
||||
|
||||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
|
||||
end = time.time()
|
||||
|
@ -1,3 +1,4 @@
|
||||
numpy>=1.14.2
|
||||
torch==0.4.0
|
||||
torchvision>=0.1.8
|
||||
tensorboardX
|
||||
|
Loading…
Reference in New Issue
Block a user