mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-11-30 03:07:59 +08:00
Merge pull request #60 from KuNyaa/master
add tensorboardX for loss visualization
This commit is contained in:
commit
49ad966c5f
@ -16,6 +16,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
|
|||||||
- numpy>=1.14.2
|
- numpy>=1.14.2
|
||||||
- torch==0.4.0
|
- torch==0.4.0
|
||||||
- torchvision>=0.1.8
|
- torchvision>=0.1.8
|
||||||
|
- tensorboardX
|
||||||
|
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
@ -47,6 +48,11 @@ conda install pytorch torchvision -c pytorch
|
|||||||
pip3 install torch torchvision
|
pip3 install torch torchvision
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### TensorboardX Installation
|
||||||
|
|
||||||
|
```shell
|
||||||
|
pip3 install tensorboardX
|
||||||
|
```
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|
||||||
|
@ -4,6 +4,8 @@ import time
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import tensorboardX
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from fastNLP.core.action import Action
|
from fastNLP.core.action import Action
|
||||||
from fastNLP.core.action import RandomSampler, Batchifier
|
from fastNLP.core.action import RandomSampler, Batchifier
|
||||||
@ -86,6 +88,8 @@ class BaseTrainer(object):
|
|||||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
|
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
|
||||||
self._optimizer = None
|
self._optimizer = None
|
||||||
self._optimizer_proto = default_args["optimizer"]
|
self._optimizer_proto = default_args["optimizer"]
|
||||||
|
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
|
||||||
|
self._graph_summaried = False
|
||||||
|
|
||||||
def train(self, network, train_data, dev_data=None):
|
def train(self, network, train_data, dev_data=None):
|
||||||
"""General Training Procedure
|
"""General Training Procedure
|
||||||
@ -160,6 +164,11 @@ class BaseTrainer(object):
|
|||||||
loss = self.get_loss(prediction, batch_y)
|
loss = self.get_loss(prediction, batch_y)
|
||||||
self.grad_backward(loss)
|
self.grad_backward(loss)
|
||||||
self.update()
|
self.update()
|
||||||
|
self._summary_writer.add_scalar("loss", loss.item(), global_step=step)
|
||||||
|
|
||||||
|
if not self._graph_summaried:
|
||||||
|
self._summary_writer.add_graph(network, batch_x)
|
||||||
|
self._graph_summaried = True
|
||||||
|
|
||||||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
|
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
numpy>=1.14.2
|
numpy>=1.14.2
|
||||||
torch==0.4.0
|
torch==0.4.0
|
||||||
torchvision>=0.1.8
|
torchvision>=0.1.8
|
||||||
|
tensorboardX
|
||||||
|
Loading…
Reference in New Issue
Block a user