快捷搜索: 王者荣耀 脱发

Pytorch基础 - 3. torch.utils.tensorboard

1. 简介

Tensorboard是Tensorflow的可视化工具,常用来可视化网络的损失函数,网络结构,图像等。后来将Tensorboard集成到了PyTorch中,常使用torch.utils.tensorboard来进行导入。官网地址:

2. 基本步骤

(1) 首先执行如下代码,具体含义写在注释里

from torch.utils.tensorboard import SummaryWriter

if __name__ == __main__:
    # 新建实例, log_dir为生成文件的存储地址, 不写参数默认是./run/文件夹下
    writer = SummaryWriter(log_dir=events存储地址)
    # 调用对象的方法,给文件写入数据
    writer.add_scalar(tag="show1", scalar_value=loss, global_step=epoch)
    writer.add_scalars(main_tag="show2", tag_scalar_dict={fun1: None, fun2: None}, global_step=epoch)
    writer.add_graph(model=model, input_to_model=input)
    # 关闭writer
    writer.close()

(2) 执行完上述代码,会在设置的log_dir路径下生成一个以events.out开头的文件,如下所示

(3) 执行如下命令,运行该文件

tensorboard --logdir=revents存储地址

(4) 运行结束后复制生成的地址并在浏览器中打开

3. 示例1 - 可视化单条曲线

from torch.utils.tensorboard import SummaryWriter

if __name__ == __main__:
    writer = SummaryWriter(log_dir=/home/Test/log_dir)
    for i in range(100):
        writer.add_scalar(tag="y=2x", scalar_value=2 * i, global_step=i)
    writer.close()

其中参数的具体含义如下:

4. 示例2 - 可视化多条曲线

from torch.utils.tensorboard import SummaryWriter

if __name__ == __main__:
    writer = SummaryWriter(log_dir=/home/Test/log_dir)
    for x in range(100):
        writer.add_scalars(multi_funcs, {2x: 2 * x, 3x: 3 * x, 4x: 4 * x}, x)
    writer.close()

其中参数的具体含义如下:

5. 示例3 - 可视化网络结构

新建一个MLP网络,通过add_graph来保存网络结构

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        out = self.model(x)
        return out


if __name__ == __main__:
    writer = SummaryWriter(log_dir=/home/TenbodTest/log_dir)
    model = MLP()
    input = torch.rand(32, 1, 28, 28).view(-1, 28 * 28)
    writer.add_graph(model, input)
    writer.close()

其中参数的具体含义如下:

可视化结果如下:

经验分享 程序员 微信小程序 职场和发展