基于Pytorch框架构建VGG-19模型

Pytorch

  • 一、训练模型
    • 1.导入资源包
    • 2.定义数据预处理
    • 3.读取数据
  • 二、定义VGG19模型
    • 1.定义自定义的 VGG19 模型
    • 运行结果:
  • 四、验证模型
    • 1. 定义验证过程
    • 2.用于训练模型并应用学习率调整策略的循环
    • 运行结果:
    • 3.保存模型的状态字典
  • 三、训练模型
    • 1. 定义训练函数
  • 五、创建 CustomVGG19 模型实例
    • 1. 导入资源包
    • 2.定义数据预处理
    • 4.创建 CustomVGG19 模型实例
    • 5.定义预测函数
    • 6.定义了一个可视化函数
    • 运行结果:

一、训练模型

1.导入资源包

import torch.utils.data: 导入了PyTorch的数据工具模块,这个模块提供了用于数据加载和处理的工具,如Dataset和DataLoader。
from torchvision import models: 导入了PyTorch的预训练模型模块,这个模块提供了多种预训练的型,如ResNet、VGG、AlexNet等,可以直接用于迁移学习或特征提取。

from sched import scheduler
import torch.optim as optim
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import  os
from torchvision import models

2.定义数据预处理

这些预处理操作的目的是为了增强模型的泛化能力,并确保模型在训练和验证时输入数据的格式一致。通过这些操作,模型能够接受不同尺寸、角度和方向的图像,从而提高其在实际应用中的表现。同时,归一化处理有助于稳定训练过程,加速模型收敛。,这些预处理操作的目的是为了增强模型的泛化能力,并确保模型在训练和验证时输入数据的格式一致。通过这些操作,模型能够接受不同尺寸、角度和方向的图像,从而提高其在实际应用中的表现。同时,归一化处理有助于稳定训练过程,加速模型收敛。

# 定义数据预处理
transform = {
'train': transforms.Compose([
transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(size=256),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
}

3.读取数据

读取和准备图像数据集,以便用于训练和验证深度学习模型,这段代码设置了数据加载器,它们将在训练和验证过程中提供经过预处理的图像数据。这些数据加载器是PyTorch中用于批量加载数据并使其易于迭代的重要工具。

# 读取数据
dataset = './dataset'
train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'val')

batch_size = 32
num_classes = 2  # 修改为您的分类数

data = {
'train': datasets.ImageFolder(root=train_directory, transform=transform['train']),
'val': datasets.ImageFolder(root=valid_directory, transform=transform['val'])
}

train_loader = DataLoader(data['train'], batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = DataLoader(data['val'], batch_size=batch_size, shuffle=False, num_workers=8)

二、定义VGG19模型

1.定义自定义的 VGG19 模型

这段代码定义了一个自定义的VGG-19模型,并将其适用于一个新的二分类任务。然后,它设置了损失函数和优化器,并检查了是否有可用的GPU以决定在哪个设备上进行训练。

from torchvision.models import vgg19
# 定义自定义的 VGG-19 模型
class CustomVGG19(nn.Module):
def __init__(self):
super(CustomVGG19, self).__init__()
self.vgg19_model = vgg19(pretrained=True)
for param in self.vgg19_model.features.parameters():
param.requires_grad = False
num_features = self.vgg19_model.classifier[6].in_features
self.vgg19_model.classifier[6] = nn.Sequential(
nn.Linear(num_features, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, 2)
)

def forward(self, x):
return self.vgg19_model(x)

# 创建 CustomVGG19 模型实例
vgg19_model = CustomVGG19()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg19_model.parameters(), lr=0.001, weight_decay=1e-4)

# 首先,检查是否有可用的 GPU
if torch.cuda.is_available():
# 定义 GPU 设备
device = torch.device('cuda')
print("CUDA is available! Using GPU for training.")
else:
# 如果没有可用的 GPU,则使用 CPU
device = torch.device('cpu')
print("CUDA is not available. Using CPU for training.")

# 将模型移动到 GPU
vgg19_model.to(device)

# 如果你有优化器和其他需要移动到 GPU 的参数,例如梯度
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg19_model.parameters(), lr=0.001, weight_decay=1e-4)

运行结果:

在这里插入图片描述

四、验证模型

1. 定义验证过程

这个验证函数是评估分类模型性能的基本框架,您可以根据需要调整打印频率或其他参数。在实际使用中,您需要确保在调用这个函数之前已经定义了模型、数据加载器和损失函数。

# 定义验证过程
def val(model, device, test_loader, criterion):
model.eval()
running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
running_loss += loss.item()
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()

print(f'Validation, Loss: {running_loss / len(test_loader)}, Accuracy: {100 * correct / total}%')

2.用于训练模型并应用学习率调整策略的循环

总的来说,这段代码将训练模型10个周期,并在每个周期结束后进行验证,同时使用学习率调度器来调整学习率。这种学习率调整策略可以帮助模型在训练过程中更好地收敛。在实际应用中,您可能需要根据您的具体任务和数据集调整周期数和学习率调度器的参数。

# 定义学习率调整策略
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# 训练模型
EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
train(vgg19_model, device, train_loader, optimizer, epoch)
val(vgg19_model, device, test_loader, criterion)
scheduler.step()  # 调整学习率

运行结果:

在这里插入图片描述

3.保存模型的状态字典

保存模型的状态字典是一个重要的步骤,因为它允许您在训练完成后保存模型的结果,以便将来使用或进行分析。在实际应用中,您可能需要根据您的具体需求选择不同的文件名和保存路径。

# 保存模型的状态字典
torch.save(vgg19_model.state_dict(), 'vgg19_model_weights.pth')

三、训练模型

1. 定义训练函数

这个训练函数是训练分类模型的基本框架,您可以根据需要调整打印频率或其他参数。在实际使用中,您需要确保在调用这个函数之前已经定义了模型、数据加载器、优化器和损失函数。

def train(model, device, train_loader, optimizer, epoch):
model.train()
running_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()

if batch_idx % 10 == 0:  # 每10个批次打印一次
print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

print(f'Epoch {epoch}, Loss: {running_loss / len(train_loader)}, Accuracy: {100 * correct / total}%')

五、创建 CustomVGG19 模型实例

1. 导入资源包

from torch.autograd import Variable: 导入了PyTorch的自动求导变量模块。在早期的PyTorch版本中,Variable是用于封装张量并记录计算图的工具。但在最新的PyTorch版本中,Variable已经不再推荐使用,因为PyTorch自动将普通张量转换为Variable。如果您使用的是最新版本的PyTorch,这行代码可能是不必要的。

import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models
from torch.autograd import Variable

2.定义数据预处理

这些步骤是图像识别任务中的常见操作,用于准备数据和选择计算设备。在实际应用中,您可能需要根据您的具体任务和数据集调整这些参数。

# 定义数据预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 定义类别
classes = ['cat', 'dog']  # 替换为您的实际类别名称

# 检查是否有可用的 GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3.定义自定义的 VGG-19 模型
这个自定义的VGG-19模型是用于二分类任务的,它保留了VGG-19的卷积层和池化层不变,只修改了最后的全连接层以适应新的类别数。在实际应用中,您可能需要根据您的具体任务调整最后的全连接层,以匹配您的类别数。

# 定义自定义的 VGG-19 模型
class CustomVGG19(nn.Module):
def __init__(self):
super(CustomVGG19, self).__init__()
self.vgg19_model = models.vgg19(pretrained=True)
for param in self.vgg19_model.features.parameters():
param.requires_grad = False
num_features = self.vgg19_model.classifier[6].in_features
self.vgg19_model.classifier[6] = nn.Sequential(
nn.Linear(num_features, 4096),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(4096, len(classes))
)

def forward(self, x):
return self.vgg19_model(x)

4.创建 CustomVGG19 模型实例

用于创建自定义的VGG-19模型实例,加载模型的权重,并将模型移动到指定的设备上。

# 创建 CustomVGG19 模型实例
model = CustomVGG19()
# 加载权重
model.load_state_dict(torch.load("vgg19_model_weights.pth"))
model.to(DEVICE)
model.eval()

5.定义预测函数

这个预测函数是使用模型进行图像分类的基本框架,我们可以根据需要调整打印频率或其他参数。在实际应用中,您需要确保在调用这个函数之前已经定义了模型、数据预处理和类别列表。

# 定义预测函数
def predict_image(image_path):
# 打开图片
image = Image.open(image_path)
# 应用预处理
image = transform(image).unsqueeze(0)  # 添加batch维度
# 转换为Variable(如果模型需要)
image = Variable(image).to(DEVICE)
# 获取模型预测
output = model(image)
_, prediction = torch.max(output.data, 1)
return classes[prediction.item()]

# 上传的图片路径
uploaded_image_path = '1111.jpg'
# 进行预测
predicted_class = predict_image(uploaded_image_path)

print(f"The uploaded image is predicted as: {predicted_class}")

运行结果:
在这里插入图片描述

6.定义了一个可视化函数

当我们运行这个脚本时,它会打开一个Matplotlib窗口,显示上传的图片,并在图片上添加一个标题显示预测的类别。同时,脚本会打印出预测结果。这个脚本是用于可视化图像分类结果的基本框架,您可以根据需要调整打印频率或其他参数。在实际应用中,您需要确保在调用这个函数之前已经定义了模型、数据预处理和类别列表。

import matplotlib.pyplot as plt

# 定义可视化函数
def visualize_prediction(image_path, predicted_class):
# 打开图片
image = Image.open(image_path)
# 显示图片
plt.imshow(image)
plt.axis('off')
plt.title(f'Predicted: {predicted_class}')
plt.show()
# 上传的图片路径
uploaded_image_path = '44.jpg'
# 进行预测
predicted_class = predict_image(uploaded_image_path)

# 可视化预测结果
visualize_prediction(uploaded_image_path, predicted_class)

print(f"The uploaded image is predicted as: {predicted_class}")

运行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/746666.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL—存储过程(详细介绍与基本语法)

目录 一、存储过程——介绍 (1)基本介绍 (2)基本特点 二、存储过程——语法 (1)基本语法 创建 调用 (2)实操(创建和调用) 1、创建一个叫 "p1&qu…

2024年6月26日 (周三) 叶子游戏新闻

老板键工具来唤去: 它可以为常用程序自定义快捷键,实现一键唤起、一键隐藏的 Windows 工具,并且支持窗口动态绑定快捷键(无需设置自动实现)。 土豆录屏: 免费、无录制时长限制、无水印的录屏软件 《Granblue Fantasy Versus: Risi…

K210视觉识别模块学习笔记6: 识别苹果_图形化操作函数_

今日开始学习K210视觉识别模块: 图形化操作函数 亚博智能 K210视觉识别模块...... 固件库: canmv_yahboom_v2.1.1.bin 训练网站: 嘉楠开发者社区 今日学习如何在识别到目标的时候添加图形化操作:(获取坐标、框出目标等) 在识别苹果的基础上 学习与添加 这些操…

在前端开发过程中如果函数参数很多,该如何精简

1. 在前端开发过程中如果函数参数很多,该如何精简 1.1. 对象参数(对象字面量):1.2. 默认参数和解构赋值:1.3. 使用类或构造函数:1.4. 利用闭包或者高阶函数:1.5. 利用ES6的扩展运算符&#xff1…

# 深入理解 Java 虚拟机 (二)

深入理解 Java 虚拟机 (二) Java内存模型 主内存与工作内存 所有的变量存储在主内存(虚拟机内存的一部分)每条线程有自己的工作内存,线程对变量的所有操作(读取、赋值)都必须在工作内存中进行…

数据质量低下会造成什么后果?应从哪些维度衡量数据质量?

大数据时代的到来,预示着前所未有的商业机遇和洞察力。然而,要将这些海量数据中蕴含的巨大价值转化为实际的业务成果,一个关键的前提条件是必须确保所收集数据的质量。数据质量是大数据价值链上的第一道关卡,它的高低直接关系到数…

【QT】设置QTabWidget样式:上、下边线的显示与去除

目录 0.简介 1.环境 2.详细介绍 2.1我的原代码和显示效果 2.2 去掉QTabWidget的边框 2.3 单独留下边线 2.3.1 法一:通过【this->setDocumentMode(true);】设置下边线 2.3.2 通过【QTabWidget::pane】设置下边线 2.4单独设置上边线 2.5 优化界面tab 2.…

Ceil()——向上取整函数

函数原型为: double ceil(double x); 大家可以在这个网站里更清晰的了解ceil - C Reference (cplusplus.com) 下面借助一道例题来帮助大家理解:牛牛的快递_牛客题霸_牛客网 (nowcoder.com) 我们分析题得知,在大于1的情况下,只要…

AI在软件开发中的应用

AI在软件开发中的应用可以帮助开发人员更高效地编写和测试代码,并提高软件的质量和性能。它能够帮助加快软件的部署和维护过程,提供更好的开发体验。 编码辅助 帮助开发人员更快地编写代码。例如,AI可以识别代码中的语法错误,并提…

实时美颜技术解析:视频美颜SDK如何改变直播行业

实时美颜技术的出现,尤其是视频美颜SDK的应用,正逐渐改变着直播行业的生态。 一、实时美颜技术的原理 实时美颜技术利用人工智能和图像处理算法,对视频中的人物面部进行优化和修饰。该技术通常包含以下几个步骤: 1.人脸检测和识…

Linux文件编程详解

Linux文件编程详解 在Ubuntu(Linux)系统下进行文件操作涉及一系列的系统调用,这些调用是基于Unix风格的文件操作API。这些操作包括打开或创建文件、从文件中读取数据、向文件中写入数据、移动文件指针以及关闭文件。以下是这些函数的详细介绍…

std::enable_if和std::is_base_of

std::enable_if,其主要为了完成模板特偏化,有两个参数,第一个为布尔值类型,第二个如果布尔值为true,其为默认空值,如果已经赋值,则为对应的类型。 std::is_base_of,其一共存在两个参数&#xff…

ora-15025 ora-27041问题处理

这个问题先排查 [oracleracdg2-2 ~]$ cd $ORACLE_HOME/bin [oracleracdg2-2 bin]$ ls -ld oracle -rwsr-s--x 1 oracle oinstall 239626641 Jun 25 19:09 oracle 正常的属组是 [gridracdg2-1 ~]$ setasmgidwrap -o /u01/app/oracle/product/11.2.0.4/dbhome_1/bin/oracle […

玩转AI之四个免费热门的AI工具

2023年,可以说称之为人工智能元年,随着 AI 人工智能、机器学习技术的不断发展,各种 AI 算法的应用也越来越广泛,在AI这一领域中,软件、工具和网站如雨后春笋般涌现。下半年,预计会有更多王炸级别的产品问世…

windows10/win11截图快捷键 和 剪贴板历史记录 快捷键

后知后觉的我今天又学了两招: windows10/win11截图快捷键 按 Windows 徽标键‌ Shift S。 选择屏幕截图的区域时,桌面将变暗。 默认情况下,选择“矩形模式”。 可以通过在工具栏中选择以下选项之一来更改截图的形状:“矩形模式”…

线性代数基础概念:行列式

目录 线性代数基础概念:行列式 1. 行列式的定义 1.1 递归定义 1.2 代数余子式定义 1.3 几何定义 2. 行列式的性质 2.1 行列式等于其转置的行列式 2.2 交换两行或两列,行列式变号 2.3 将一行或一列乘以一个数 k,行列式乘以 k 2.4 将…

植物大战僵尸杂交版技巧大全(附下载攻略)

《植物大战僵尸杂交版》为策略游戏爱好者带来了全新的挑战和乐趣。如果你是新手玩家,可能会对游戏中的植物和僵尸感到困惑。以下是一些实用的技巧,帮助你快速掌握游戏并享受其中的乐趣。 技巧一:熟悉基本玩法 游戏的基本玩法与原版相似&…

Android 11.0 修改系统显示大小导航栏消失

Android 11.0 修改系统显示大小导航栏消失 1.显示大小设置为大时,导航栏图标不显示。 设置为大,较大,最大时,导航栏图标不显示。 2.开始怀疑是导航栏被隐藏了,各种折腾无效。 3.发现: frameworks/base/packages/SystemUI/src/com/android/systemui/statusbar/phone/Edg…

OpenCV cv::Mat到 Eigen 的正确转换——cv2eigen

在进行计算机视觉项目时,我们经常需要处理相机位姿的变换。最近,我在项目中遇到了一个看似简单但实际上颇具挑战性的问题:从 OpenCV 的 cv::Mat 格式转换到 Eigen 库的格式。这个过程中遇到了一些问题,但最终找到了一个稳健的解决…

高考成绩加分,西藏学生推荐使用的《藏文翻译词典》APP,藏文作文高考大纲,初中高中学习内容与考试同步更新!

2024年高考成绩出炉啦!在这个特别的时刻,我想向大家表达最真挚的祝贺。高考不仅是一场考试,更是你多年学习旅程的一次总结。当你的成绩揭晓,无论结果如何,你都应该为自己感到骄傲。 在高原,藏语如同雪山上…