pytorch模型保存及加载
代码
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# 1. save model 1 保存模型结构及模型参数
torch.save(vgg16, './vgg16_save1.model')
# 2. save model 2 只保存模型参数 比第一种保存方法保存的文件要小
torch.save(vgg16.state_dict(), './vgg16_save2.model')
# 3. load model 1
vgg16_load1 = torch.load('./vgg16_save1.model')
print(vgg16_load1) # 打印的是模型网络结构
# 3. load model 2
vgg16_load2 = torch.load('./vgg16_save2.model')
print(vgg16_load2) # 打印的是模型参数
# 将参数导入到网络
vgg16.load_state_dict(vgg16_load2)
print(vgg16)
# 5. 保存模型方式1的陷阱
# 当用方法1导入模型的时候,模型结构是要已知的
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
# class MySeq2(nn.Module):
# def __init__(self):
# super(MySeq2, self).__init__()
# self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
# MaxPool2d(2),
# Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
# MaxPool2d(2),
# Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
# MaxPool2d(2),
# Flatten(),
# Linear(1024, 64),
# Linear(64, 10)
# )
#
# def forward(self, x):
# x = self.model1(x)
# return x
# myseq2 = MySeq2()
# torch.save(myseq2, 'myseq_self.model')
# 当用方法1导入模型的时候,模型结构是要已知的 否则就会报下面的错误 可以在代码里重新定义 但一般都是写在另一个单独的文件里面 比如上面注释的模型结构是前面已经写在p19_nn_seq 文件里面的,执行了模型保存
# AttributeError: Can't get attribute 'MySeq2' on <module '__main__' from 'C:/工作文档/learn_pytorch/p23_save_load_model.py'>
from p19_nn_seq import *
myseq2 = torch.load('myseq_self.model')
print(myseq2)
执行结果
只打印模型参数
打印模型结构,在调试模式下 可以在feature–保护属性–models–0–weight下看到模型参数
自己写过的模型文件保存后加载