文章目录
- 一、修改
- 1.方法
- 2.代码
- 二、保存和读取
- 1.方法
- 2.代码
- (1)保存
- (2)加载
- 3.陷阱
一、修改
1.方法
add_module(name: str, module: Module) -> None
name 是要添加的子模块的名称。
module 是要添加的子模块。
调用 add_module 方法会向当前模块中添加一个子模块,并使用指定的名称进行标识。
2.代码
import torchvision
from torch import nn
# 实例化一个未经过预训练的 VGG16 模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
# 实例化一个经过预训练的 VGG16 模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("ok")
# 输出经过预训练的 VGG16 模型及修改后的模型
print(vgg16_true)
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_true)
# 输出未经过预训练的 VGG16 模型及修改后的模型
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
修改前的vgg16_true:
修改后的vgg16_true:
修改前的vgg16_true:
修改后的vgg16_true:
二、保存和读取
1.方法
保存: torch.save(要保存的模型,“文件路径”)
加载: torch.load(“文件路径”)
2.代码
(1)保存
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1:模型结构+模型参数
torch.save(vgg16, "vgg16_module1.pth")
# 保存方式2:模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_module2.pth")
(2)加载
import torch
import torchvision
# 方式1 加载模型
module1 = torch.load("vgg16_module1.pth")
print(module1)
#
module2 = torch.load("vgg16_module2.pth")
print(module2)
# 方式2 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_module2.pth"))
print(vgg16)
运行加载的代码后,打印结果如下
module1:
module2:
vgg16:
可以看到,第二种方式保存的数据,加载后是向量形式,需要通过别的方法加载为模型
3.陷阱
第一种方式加载,在某些条件下可能会报错
例如:
假设自定义一个神经网络,保存:
import torch
import torchvision
from torch import nn
# 陷阱
class Guodong(nn.Module):
def __init__(self):
super(Guodong,self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self,x):
x = self.conv1(x)
return x
guodong = Guodong()
torch.save(guodong,"guodong_method1.pth")
在另一个文件中加载:
import torch
# 陷阱
module = torch.load("guodong_method1.pth")
print(module)
就会报错:
AttributeError: Can’t get attribute ‘Guodong’ on <module ‘main’ from ‘E:\deepLearning\Pycharm\pytroch_project\theFirstFile\module_load.py’>
解决办法:
(1)把Guodong类放在这个文件里
import torch
from torch import nn
import torchvision
class Guodong(nn.Module):
def __init__(self):
super(Guodong,self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self,x):
x = self.conv1(x)
return x
# 陷阱
module = torch.load("guodong_method1.pth")
print(module)
(2)from module_save import *
(module_save)是保存自定义模型的文件
from module_save import *
# 陷阱
module = torch.load("guodong_method1.pth")
print(module)