CNN的小体验

用的pytorch。

训练代码cnn.py:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

# 定义超参数
num_epochs = 10
batch_size = 100
learning_rate = 0.001

# 数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 定义卷积神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# 测试模型
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy of the model on the 10000 test images: {100 * correct / total}%')

# 保存模型
torch.save(model.state_dict(), 'cnn.pth')

推断代码cnn2.py

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F

# 定义卷积神经网络(与之前的定义保持一致)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2) # 第一个卷积层
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) # 池化层
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2) # 第二个卷积层
        self.fc1 = nn.Linear(32 * 8 * 8, 128) # 全连接层
        self.fc2 = nn.Linear(128, 10) # 输出层

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # 通过第一个卷积层和池化层
        x = self.pool(F.relu(self.conv2(x))) # 通过第二个卷积层和池化层
        x = x.view(-1, 32 * 8 * 8) # 展平
        x = F.relu(self.fc1(x)) # 通过全连接层
        x = self.fc2(x) # 通过输出层
        return x

# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load('cnn.pth'))
model.eval() # 设置模型为评估模式

# 预处理图片
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((32, 32)), # 调整图像大小到32x32
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    image = Image.open(image_path)
    image = transform(image)
    image = image.unsqueeze(0) # 增加批量维度
    return image

# 加载并预处理图片
image_path = 'test.jpg' # 替换为你要分析的图片路径
image = preprocess_image(image_path)

# 使用模型进行推理
with torch.no_grad():
    outputs = model(image)
    _, predicted = torch.max(outputs.data, 1)
    class_index = predicted.item()

# CIFAR-10类别标签
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# 输出预测结果
print(f'Predicted class: {classes[class_index]}')

可惜出来的东西跟弱智一般。

python3 cnn2.py
Predicted class: horse

python3 cnn2.py
Predicted class: bird

几个小点:

1 使用的数据集是CIFAR10

2 训练真的挺耗时的,我用的阿里云,一共搞了差不多10分钟(训练一个弱智)。

3 环境依然麻烦,python,numpy的版本都不能太高。否则要出问题。。。

4 最后实事求是的说,我不太懂的一点是怎么分类出来的。。。晚点再看看。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/762232.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

论文翻译 | (DSP)展示-搜索-预测:为知识密集型自然语言处理组合检索和语言模型

摘要 检索增强式上下文学习已经成为一种强大的方法,利用冻结语言模型 (LM) 和检索模型 (RM) 来解决知识密集型任务。现有工作将这些模型结合在简单的“检索-读取”流程中,其中 RM 检索到的段落被插入到 LM 提示中。 为了充分发挥冻结 LM 和 RM 的…

数据结构预科

在堆区申请两个长度为32的空间,实现两个字符串的比较【非库函数实现】 要求: 1> 定义函数,在对区申请空间,两个申请,主函数需要调用2次 2> 定义函数,实现字符串的输入,void input(char …

一文全概括,建议收藏,那些你不可错过的IC设计书籍合集(可下载)

集成电路设计工程师的角色不仅是推动技术创新的中坚力量,更是实现产品从概念到现实的关键桥梁。随着对高性能、低功耗芯片的需求不断增长,IC设计工程师的专业技能和知识深度成为了衡量其职业价值的重要标准。无论是在数字逻辑设计、功能验证、可测试性设…

GPON-GPON帧链路层知识学习

前言: 引用: gpon学习_gpon帧结构-CSDN博客 了解 GPON 技术 - Cisco GPON、XG(S)-PON基础_网络_门牙会稍息-开放原子开发者工作坊 gpon学习_gpon帧结构-CSDN博客 广域网宽带接入技术七GPON技术_gtc帧-CSDN博客 https://www.cnblogs.com/aliyunyun/…

全球首款搭载Google Gemini和GPT-4o的智能眼镜发布

智能眼镜仍然是一个尚未完全成熟的未来概念,但生成式人工智能的到来显著提升了这些设备的能力。Meta 的 Ray-Ban 智能眼镜被许多人视为当今最好的选择之一,而现在 Solos AirGo Vision 正在为其带来竞争,这款眼镜还集成了 Google Gemini 支持。…

[论文精读]Variational Graph Auto-Encoders

论文网址:[1611.07308] Variational Graph Auto-Encoders (arxiv.org) 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎…

炎黄数智人:国家体育总局冬运中心——AI裁判与教练“观君”赋能冰雪运动新篇章

在科技创新的浪潮下,国家体育总局冬季运动管理中心(以下简称“冬运中心”)揭开了人工智能在体育领域应用的新篇章。隆重宣布推出革命性的AI裁判与教练系统——“观君”,该系统将在冰雪运动项目中大放异彩,为运动员的训…

地下水电站3D虚拟仿真展示平台

借助先进的VR技术,我们将水电站的每一个角落、每一处细节都以三维全景的形式真实呈现。您可以自由穿梭于水电站的各个区域,无论是发电机组、巍峨的水坝,还是错综复杂的输水管道,都近在咫尺。感受水流的澎湃力量,聆听机…

python自动化之schedule

目录 代码(以每5秒1次为例): 每5分钟1次 每2小时1次 每天18:00执行 用到的库:schedule,time 实现的效果:按秒来运行任务,按分钟来运行任务,按小时来运行任务,按天来运行任务 代…

Canvas 指纹:它是什么以及如何绕过它

什么是 Canvas 指纹? 网络浏览器在执行其功能时会收集各种信息。当这些信息中的某些被用于识别网站用户时,这被称为浏览器指纹。 浏览器指纹包括以下有关浏览器的信息:设备型号、浏览器类型和版本、操作系统 (OS)、屏幕分辨率、时区、p0p 文…

预约小程序源码,云开发技术,无需服务器

介绍: 很多企业的业务都需要通过服务预约来完成,比如酒店、美容、家政等等。 但很多商家因缺少合适的服务预订工具,而不知道如何让客户尽快预约。 这种情况下,制作一个自己的预约小程序,客户只需要扫码或者在微信里…

工程化:Commitlint / 规范化Git提交消息格式

一、理解Commitlint Commitlint是一个用于规范化Git提交消息格式的工具。它基于Node.js,通过一系列的规则来检查Git提交信息的格式,确保它们遵循预定义的标准。 1.1、Commitlint的核心功能 代码规则检查:Commitlint基于代码规则进行检查&a…

销量位列第一!强力巨彩LED单元板成绩斐然

据全球知名科技研究机构Omdia《LED显示产品出货分析-中国-2023》报告显示,2023年强力巨彩LED显示屏销量与单元板产品销量均位列第一,其品牌和市场优势可见一斑。 厦门强力巨彩自2004年成立之初,便以技术创新和严格品控为核心竞争力&#xff0…

【Kaggle】Telco Customer Churn 电信用户流失预测案例

⭐️前言:案例学习说明与案例建模流程 我们将围绕Kaggle中的电信用户流失数据集(Telco Customer Churn)进行用户流失预测。在此过程中,将综合应用此前所介绍的各种方法与技巧,并在实践中提炼总结更多实用技巧。 ⭐️对…

智慧的网络爬虫之CSS概述

智慧的网络爬虫之CSS概述 ​ CSS 是“Cascading Style Sheet”的缩写,中文意思为“层叠样式表”,用于描述网页的表现形式。如网页元素的位置、大小、颜色等。css的主要作用是定义网页的样式。 CSS样式 1. 行内样式 行内样式:直接定义在 HT…

造一个交互式3D火山数据可视化

本文由ScriptEcho平台提供技术支持 项目地址:传送门 使用 Plotly.js 创建交互式 3D 火山数据可视化 应用场景 本代码用于将火山数据库中的数据可视化,展示火山的高度、类型和状态。可用于地质学研究、教育和数据探索。 基本功能 该代码使用 Plotly…

模型部署:C++libtorch实现全连接模型10分类和卷积模型ResNet18的四分类的模型部署推理

Clibtorch实现模型部署推理 模型 全连接模型:公开mnist手写识别数字的十分类卷积模型:自行采集的鲜花四分类 部署 语言环境:C 对比Python python是解释性语言,效率很慢,安全性很低 系统开发一般是java、C/C&…

昂科烧录器支持BPS晶丰明源半导体的多相Buck控制器BPD93004E

芯片烧录行业领导者-昂科技术近日发布最新的烧录软件更新及新增支持的芯片型号列表,其中BPS晶丰明源半导体的多相Buck控制器BPD93004E已经被昂科的通用烧录平台AP8000所支持。 BPD93004E是一款多相Buck控制器,支持原生1~4相,数字方式控制&am…

麒麟桌面操作系统上解决任务栏消失问题

原文链接:麟桌面操作系统上解决任务栏消失问题 Hello,大家好啊!今天给大家带来一篇关于在麒麟桌面操作系统上解决任务栏消失问题的文章。任务栏是我们日常操作系统使用中非常重要的部分,它提供了快速访问应用程序和系统功能的便捷…

OpenBayes 教程上新 | CVPR 获奖项目,BioCLlP 快速识别生物种类,再也不会弄混小浣熊和小熊猫了!

市面上有很多植物识别的 App,通过对植物的叶片、花朵、果实等特征进行准确的识别,从而确定植物的种类、名称。但动物识别的 App 却十分有限,这使我们很难区分一些外形相似的动物,例如小浣熊和小熊猫。 左侧为小浣熊,右…