Netron查看模型结构
- 参照模型
- 安装Netron
- 写netron代码
- 运行查看结果
- 需要关注的地方
- 2024年4月27日14:32:30----0.9.2
参照模型
以pytorch官网的tutorial为观察对象,链接是https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
模型代码如下
import torch.nn as nn
import torch.nn.functional as F
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size, hidden_size)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.h2o = nn.Linear(hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
output = self.h2o(hidden)
output = self.softmax(output)
return output, hidden
def initHidden(self):
return torch.zeros(1, self.hidden_size)
n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)
安装Netron
pip install netron即可
其他安装方式参考链接
https://blog.csdn.net/m0_49963403/article/details/136242313
写netron代码
随便找一个地方打个点,如sample方法中
import netron
max_length = 20
# Sample from a category and starting letter
def sample(category, start_letter='A'):
with torch.no_grad(): # no need to track history in sampling
category_tensor = categoryTensor(category)
input = inputTensor(start_letter)
hidden = rnn.initHidden()
output_name = start_letter
for i in range(max_length):
# print("category_tensor",category_tensor.size())
# print("input[0]",input[0].size())
# print("hidden",hidden.size())
output, hidden = rnn(category_tensor, input[0], hidden)
torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx') #导出 .onnx 文件
netron.start('AlexNet1.onnx') #展示结构图
break
# print("output",output.size())
# print("hidden",hidden.size())
# print("====================")
topv, topi = output.topk(1)
topi = topi[0][0]
if topi == n_letters - 1:
break
else:
letter = all_letters[topi]
output_name += letter
input = inputTensor(letter)
return output_name
# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):
for start_letter in start_letters:
print(sample(category, start_letter))
break
samples('Russian', 'RUS')
运行查看结果
结果是在浏览器中,运行成功后会显示:
Serving ‘AlexNet.onnx’ at http://localhost:8080
打开这个网页就可以看见模型结构,如下图
需要关注的地方
- 关于参数
如果模型是一个参数的情况下,如下使用就可以了
import torch
from torchvision.models import AlexNet
import netron
model = AlexNet()
input = torch.ones((1,3,224,224))
torch.onnx.export(model, input, f='AlexNet.onnx')
netron.start('AlexNet.onnx')
如果模型有多个参数的情况下,则需要如下用括号括起来,如本文中的例子
torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx') #导出 .onnx 文件
netron.start('AlexNet1.onnx') #展示结构图
- 如果运行过程中发现报错找不到模型
有可能是你手动删除了生成的模型,最好的方法是重新生成这个模型,再运行