基于pytoch卷积神经网络水质图像分类实战

具体怎么学习pytorch,看b站刘二大人的视频。

完整代码:

import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
'''https://zhuanlan.zhihu.com/p/156926543'''
# 定义图片目录
image_dir = 'images'

# 初始化图片路径列表
img_list = []

# 遍历指定目录及其子目录中的所有文件
for parent, _, filenames in os.walk(image_dir):
    for filename in filenames:
        # 拼接文件的完整路径
        filename_path = os.path.join(parent, filename)
        img_list.append(filename_path)

# 初始化图像张量列表和标签列表
image_tensors = []
y_list = []

for image_path in img_list:
    # 提取标签 (假设标签是文件名的第一个字符)
    label = int(os.path.basename(image_path)[0])
    y_list.append(label)

    # 打开图像
    img = Image.open(image_path)

    # 获取图像尺寸
    width, height = img.size

    # 定义裁剪的区域(假设要保留图像中心的 100x100 区域)
    left = (width - 100) / 2
    top = (height - 100) / 2
    right = (width + 100) / 2
    bottom = (height + 100) / 2

    # 裁剪图像
    img = img.crop((left, top, right, bottom))

    # 将图像转换为 NumPy 数组
    img_array = np.asarray(img)

    # 将 NumPy 数组转换为 PyTorch 张量
    img_tensor = torch.from_numpy(img_array).float()

    # 如果图像是 RGB,将其转换为 (C, H, W) 格式
    if img_tensor.ndimension() == 3 and img_tensor.shape[2] == 3:
        img_tensor = img_tensor.permute(2, 0, 1)  # 从 (H, W, C) 变为 (C, H, W)

    # 增加 batch 维度
    img_tensor = img_tensor.unsqueeze(0)  # 从 (C, H, W) 变为 (1, C, H, W)

    # 规范化到0-1之间
    img_tensor = img_tensor / 255.0

    # 添加到图像张量列表
    image_tensors.append(img_tensor)

    # 打印图像张量的形状
    print(f"当前图像形状: {img_tensor.shape}")

# 将图像张量列表转换为四维张量
x_data = torch.cat(image_tensors, dim=0)
# 遍历 y_list 中的每个元素,并将每个数减去 1
for i in range(len(y_list)):
    y_list[i] -= 1

# 将标签列表转换为张量
y_labels = torch.tensor(y_list).long()  # 注意这里使用 .long() 方法将标签转换为长整型

print(x_data.shape,y_labels.shape)
print(y_labels)

# 定义数据集和数据加载器
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x_data, y_labels):
        self.x_data = x_data
        self.y_labels = y_labels

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        return self.x_data[idx], self.y_labels[idx]


# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 25 * 25, 128)
        self.fc2 = nn.Linear(128, 5)  # 假设有5个类别

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 训练模型
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):  # 假设训练50个epoch
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}")

    # 在每个epoch结束后,计算并打印验证集的准确率
    model.eval()  # 将模型设置为评估模式
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_accuracy = correct / total
    print(f'Validation Accuracy after Epoch {epoch + 1}: {val_accuracy}')

本数据集中x_data的维度是四维张量(203,3,100,100),y_labels的维度是一维张量 

代码中需要注意的点,卷积模型接受的是四维张量,因此要转变为四维张量。

全连接层中输入的特征数,需要自己计算,通过前面卷积层和池化层后,计算总的维度数。一般是最后的通道数*高度*宽度

定义模型中的forward函数中,在经过全连接层计算前,需要将四维的x转为2维

如果 x 的形状是 (64, 32, 28, 28),表示一个批次大小为64的图像张量,其中每个图像有32个通道,高度和宽度都是28像素。现在,我们希望将这个张量展平为一个二维张量,以便输入到全连接层进行进一步处理。

通过 torch.flatten(x, 1) 操作,我们将在指定维度(这里是第一个维度,也就是通道维度)上对张量进行展平。展平后的张量形状将变为 (64, 32*28*28),其中64是批次大小,而 32*28*28 是展平后的特征数量,即每个图像的特征数量。这与前面定义的全连接层的输入特征数要一致。

Dataloader中batch_size就是设置第一个维度,比如这里的batch_size是32,那么

for inputs, labels in train_loader:

 这里的inputs维度是(32,3,100,100)

新学习pytorch中的分割数据集与测试集方法。

# 使用自定义数据集和数据加载器
custom_dataset = CustomDataset(x_data, y_labels)
train_size = int(0.8 * len(custom_dataset))
val_size = len(custom_dataset) - train_size
train_set, val_set = torch.utils.data.random_split(custom_dataset, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

结果展现,可以看见准确率有0.82:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/696242.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

模板显式、隐式实例化和(偏)特化、具体化的详细分析

最近看了<The C Programing Language>看到了模板的特化&#xff0c;突然想起来<C Primer>上说的显式具体化、隐式具体化、特化、偏特化、具体化等概念弄得头晕脑胀&#xff0c;我在网上了找了好多帖子&#xff0c;才把概念给理清楚。 看着这么多叫法&#xff0c;其…

晨控CK-UR12-E01与欧姆龙NX/NJ系列EtherNet/IP通讯手册

晨控CK-UR12-E01与欧姆龙NX/NJ系列EtherNet/IP通讯手册 晨控CK-UR12-E01 是天线一体式超高频读写器头&#xff0c;工作频率默认为902MHz&#xff5e;928MHz&#xff0c;符合EPC Global Class l Gen 2&#xff0f;IS0-18000-6C 标准&#xff0c;最大输出功率 33dBm。读卡器同时…

C语言怎样初始化图形模式?

一、问题 在C语⾔中&#xff0c;initgraph( ) 函数⽤于初始化图形模式。初始化时&#xff0c;那么多参数都是⼲什么的&#xff1f;怎样设置&#xff1f; 二、解答 initgraph( ) 函数⽤于初始化图形模式&#xff0c;其语法格式如下。 void far initgraph(int far * gdriver, i…

0基础学习区块链技术——入门

大纲 区块链构成区块链相关技术Hash算法区块链区块链交易 参考资料 本文力求简单&#xff0c;不讨论任何技术细节&#xff0c;只是从简单的组成来介绍区块链技术&#xff0c;以方便大家快速入门。同时借助一些可视化工具&#xff0c;辅助大家有直观的认识。 区块链构成 顾名思…

python导入非当前目录(如:父目录)下的内容

在开发python项目时&#xff0c;通常会划分不同的目录&#xff0c;甚至不同层级的目录&#xff0c;这时如果直接导入不在当前目录下的内容时&#xff0c;会报如下的错误&#xff1a;ModuleNotFoundError: No module named miniai其实这里跟操作系统的环境变量很类似的&#xff…

绘唐官网绘唐科技

绘唐AI工具是一种基于人工智能技术的绘画辅助工具。 使用教程&#xff1a;https://iimenvrieak.feishu.cn/docx/CWwldSUU2okj0wxmnA0cHOdjnF 它可以根据用户提供的输入或指令生成各种类型的图像。 绘唐AI工具可以理解用户的绘画需求&#xff0c;并根据用户的要求生成具有艺术…

文件操作(Python和C++版)

一、C版 程序运行时产生的数据都属于临时数据&#xff0c;程序—旦运行结束都会被释放通过文件可以将数据持久化 C中对文件操作需要包含头文件< fstream > 文件类型分为两种: 1. 文本文件 - 文件以文本的ASCII码形式存储在计算机中 2. 二进制文件- 文件以文本的二进…

【图解IO与Netty系列】Netty核心组件解析

Netty核心组件解析 Bootstrap & ServerBootstrapEventLoop & EventLoopGroupChannelChannelHandler & ChannelPipeline & ChannelHandlerContextChannelHandlerChannelPipelineChannelHandlerContext ChannelFuture Bootstrap & ServerBootstrap Bootstra…

免费!GPT-4o发布,实时语音视频丝滑交互

We’re announcing GPT-4o, our new flagship model that can reason across audio, vision, and text in real time. 5月14日凌晨&#xff0c;OpenAI召开了春季发布会&#xff0c;发布会上公布了新一代旗舰型生成式人工智能大模型【GPT-4o】&#xff0c;并表示该模型对所有免费…

AI智能体做高考志愿填报分析

关注公众号&#xff0c;赠送AI/Python/Linux资料&#xff0c;对AI智能体有兴趣的朋友也可以添加一起交流 高考正在进行时&#xff0c;学生焦虑考试&#xff0c;家长们焦虑的则是高考志愿怎么填。毕竟一个好的学校&#xff0c;好的专业是进入社会的第一个敲门砖 你看张雪峰老师…

Tomcat源码解析(八):一个请求的执行流程(附Tomcat整体总结)

Tomcat源码系列文章 Tomcat源码解析(一)&#xff1a;Tomcat整体架构 Tomcat源码解析(二)&#xff1a;Bootstrap和Catalina Tomcat源码解析(三)&#xff1a;LifeCycle生命周期管理 Tomcat源码解析(四)&#xff1a;StandardServer和StandardService Tomcat源码解析(五)&…

Python opencv读取深度图,网格化显示深度

效果图&#xff1a; 代码&#xff1a; import cv2 import osimg_path "./outdir/180m_norm_depth.png" depth_img cv2.imread(img_path, cv2.IMREAD_ANYDEPTH) filename os.path.basename(img_path) img_hig, img_wid depth_img.shape # (1080, 1920) print(de…

设计模式- 责任链模式(行为型)

责任链模式 责任链模式是一种行为模式&#xff0c;它为请求创建一个接收者对象的链&#xff0c;解耦了请求的发送者和接收者。责任链模式将多个处理器串联起来形成一条处理请求的链。 图解 角色 抽象处理者&#xff1a; 一个处理请求的接口&#xff0c;可以通过设置返回值的方…

Python基础教程(八):迭代器与生成器编程

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; &#x1f49d;&#x1f49…

Buildroot和Debian文件系统修改方法

本文档主要介绍在没有编译环境的情况下&#xff0c;如何修改buildroot和debian文件系统方法&#xff0c;如在buildroot文件系统中添加文件、修改目录等文件操作&#xff0c;在debian文件系统中&#xff0c;安装软件库、工具、扩大文件系统空间等等操作。 1.Debian文件系统 …

【Python从入门到进阶】57、Pandas入门:背景、应用场景与基本操作

一、引言 1、Pandas简介 在数字化时代&#xff0c;数据已经成为企业决策和个人洞察的重要基础。无论是金融市场的波动、零售业的销售趋势&#xff0c;还是科研实验的结果&#xff0c;都蕴含在大量的数据之中。然而&#xff0c;如何有效地提取、分析和解读这些数据&#xff0c…

嵌入式应用之FIFO模块原理与实现

FIFO介绍与原理 FIFO是First-In First-Out的缩写&#xff0c;它是一个具有先入先出特点的缓冲区。FIFO在嵌入式应用的非常广泛&#xff0c;可以说有数据收发的地方&#xff0c;基本就有FIFO的存在。或者为了降低CPU负担&#xff0c;提高数据处理效率&#xff0c;可以在积累到一…

2、数据操作

索引从0开始 一行 [1,:] 一列[:,1] 子区域&#xff1a;[1:3,1:] 第一行和第二行&#xff0c;从第一列开始 [::3,::2] 每3行一跳&#xff0c;每2列一跳 torch.tensor([[1,2,3,4]] 按位置算 xy ,x-y x*y x**y&#xff08;幂&#xff09; 1、广播机制形状不一样&#xff0c;…

一个简单好用的 C# Easing Animation 缓动动画类库

文章目录 1.类库说明2.使用步骤2.1 创建一个Windows Form 项目2.2 安装类库2.3 编码2.4 效果 3. 扩展方法3.1 MoveTo 动画3.2 使用回调函数的Color动画3.3 属性动画3.4 自定义缓动函数 4.该库支持的内置缓动函数5.代码下载 1.类库说明 App.Animations 类库是一个很精炼、好用的…