pytorch查看网络模型变量以及对应的尺寸
今天看代码发现,自己对于网络中需要更新的参数并不是很熟悉,然后百度发现了这个方法,记录一下:
在自己定义的模型下面加入这一行就可以查看了:
for name, param in model.named_parameters(): print(name, , param.size())
运行结果如下所示:
gnn.k_weight torch.Size([192, 64]) gnn.node.0.weight torch.Size([64, 19]) gnn.node.0.bias torch.Size([64]) gnn.node.1.weight torch.Size([64]) gnn.node.1.bias torch.Size([64]) gnn.conv_params.0.weight torch.Size([64, 64]) gnn.conv_params.0.bias torch.Size([64]) gnn.conv_params.1.weight torch.Size([64, 64]) gnn.conv_params.1.bias torch.Size([64]) gnn.conv_params.2.weight torch.Size([64, 64]) gnn.conv_params.2.bias torch.Size([64]) gnn.bn2.0.weight torch.Size([64]) gnn.bn2.0.bias torch.Size([64]) gnn.bn2.1.weight torch.Size([64]) gnn.bn2.1.bias torch.Size([64]) gnn.bn2.2.weight torch.Size([64]) gnn.bn2.2.bias torch.Size([64]) gnn.bn3.weight torch.Size([64]) gnn.bn3.bias torch.Size([64]) gnn.graph_attn.0.weight torch.Size([64, 64]) gnn.graph_attn.0.bias torch.Size([64]) gnn.graph_attn.2.weight torch.Size([64]) gnn.graph_attn.2.bias torch.Size([64]) gnn.graph_attn.3.weight torch.Size([1, 64]) gnn.graph_attn.3.bias torch.Size([1]) gnn.graph_attn.4.weight torch.Size([1]) gnn.graph_attn.4.bias torch.Size([1]) mlp.m.0.weight torch.Size([128, 64]) mlp.m.0.bias torch.Size([128]) mlp.m.2.weight torch.Size([2, 128]) mlp.m.2.bias torch.Size([2])
努力加油a啊
上一篇:
JS实现多线程数据分片下载