forked from j-min/Adversarial_Video_Summary
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
26 lines (21 loc) · 919 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# -*- coding: utf-8 -*-
from tensorboardX import SummaryWriter
class TensorboardWriter(SummaryWriter):
    """Thin convenience wrapper around tensorboardX's SummaryWriter.

    Adds small helpers for logging model parameters, scalar losses and
    arbitrary histograms against a step counter.
    """

    def __init__(self, logdir, **kwargs):
        """
        Extended SummaryWriter class from tensorboard-pytorch (tensorboardX).
        https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/writer.py

        logdir: directory handed to SummaryWriter as its first positional
            argument.
        **kwargs: extra keyword arguments forwarded unchanged to
            SummaryWriter.__init__ (optional; existing callers need not pass
            any).

        After construction, ``self.logdir`` holds the directory reported by
        ``self.file_writer.get_logdir()``.
        """
        super(TensorboardWriter, self).__init__(logdir, **kwargs)
        # Relies on SummaryWriter.__init__ having created self.file_writer.
        # NOTE(review): true for the tensorboardX versions this was written
        # against; confirm when upgrading tensorboardX.
        self.logdir = self.file_writer.get_logdir()

    def update_parameters(self, module, step_i):
        """
        Log a histogram of every parameter of ``module`` at step ``step_i``.

        module: nn.Module (anything exposing ``named_parameters()``).
        step_i: step value recorded with each histogram.

        Each histogram is tagged with the parameter's name as yielded by
        ``named_parameters()``.
        """
        for name, param in module.named_parameters():
            # `.data` gives the raw values without autograd tracking, and
            # `.cpu()` is a no-op for tensors already on the CPU. The original
            # `.clone()` made a full extra copy of each parameter (on the GPU
            # for CUDA models) only to throw it away. add_histogram presumably
            # builds its summary from the values before returning, so no copy
            # is needed. NOTE(review): confirm if the tensorboardX version
            # changes.
            self.add_histogram(name, param.data.cpu().numpy(), step_i)

    def update_loss(self, loss, step_i, name='loss'):
        """
        Log ``loss`` as a scalar under tag ``name`` at step ``step_i``.
        """
        self.add_scalar(name, loss, step_i)

    def update_histogram(self, values, step_i, name='hist'):
        """
        Log ``values`` as a histogram under tag ``name`` at step ``step_i``.
        """
        self.add_histogram(name, values, step_i)