pytorch自定义数据集分类resnet18

 

# 文件结构为:
# |--- data
#  |--- dog
#     |--- dog1_1.jpg
#     |--- dog1_2.jpg
#  |--- cat
#     |--- cat2_1.jpg
#     |--- cat2_2.jpg

 

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# 定义数据集的根目录和预处理的转换
data_dir = '../data'  # 数据集的根目录

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整图像大小为 224x224
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
])

# 创建 ImageFolder 数据集实例
dataset = torchvision.datasets.ImageFolder(root=data_dir, transform=transform)

# 划分训练集和测试集
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

print(len(train_dataset))
print(len(test_dataset))
# 创建数据加载器
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 定义预训练的卷积神经网络模型
model = torchvision.models.resnet18(pretrained=True)  #pretrained=False表示不使用预训练的权重,True表示使用预训练的权重
num_classes = len(dataset.classes) #获取图片的类别数量
model.fc = nn.Linear(model.fc.in_features, num_classes) #提取model.fc.in_features线性层中固定输入的size,
# num_classes分类图片的类型['cat','dog']

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(num_epochs):
    model.train() #(训练模式,这句代码主要是对模型中的Droupout层和Normsize(均值方差计算)起作用)
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device) #将图片放到GPU训练
        labels = labels.to(device) #标签放到GPU训练

        optimizer.zero_grad() #梯度清零

        outputs = model(images) #图片输入到模型
        loss = criterion(outputs, labels) #预测值和真是值之间计算损失
        loss.backward() #反向传播
        optimizer.step() #更新参数

        running_loss += loss.item() #每次损失相加

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_loader):.4f}")

# 在测试集上评估模型
model.eval() #训练模式,这句代码主要是对模型中的Droupout层和Normsize(均值方差计算)不加入计算
total_correct = 0
total_samples = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

accuracy = total_correct / total_samples
print(f"测试集准确率: {accuracy * 100:.2f}%")
torch.save(model,"model56")


# 文件结构为:
# |--- data
# 	|--- dog
# 		|--- dog1_1.jpg
# 		|--- dog1_2.jpg
# 	|--- cat
# 		|--- cat2_1.jpg
# 		|--- cat2_2.jpg



# 不同的模型构建细节
# AlexNet 模型结构
# torchvision.models.alexnet(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 构建一个resnet18模型
# torchvision.models.resnet18(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 构建一个ResNet-34 模型.
# torchvision.models.resnet34(pretrained=False, ** kwargs)
# Parameters: pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 构建一个ResNet-50模型
# torchvision.models.resnet50(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 构建一个ResNet-101模型
# torchvision.models.resnet101(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 构建一个ResNet-152模型
# torchvision.models.resnet152(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# VGG 11层模型(配置“A”)
# torchvision.models.vgg11(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 批量归一化的VGG 11层模型(配置“A”)
# torchvision.models.vgg11_bn(** kwargs)
#
# 构建一个VGG 13模型
# torchvision.models.vgg13(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 批量归一化的VGG 13层模型(配置“B”)
# torchvision.models.vgg13_bn(** kwargs)
#
# VGG 16层模型(配置“D”)
# torchvision.models.vgg16(pretrained=False, ** kwargs)
# Parameters: pretrained (bool) = True, returns a model pre-trained on ImageNet
#
# 批量归一化的VGG 16层模型(配置“D”)
# torchvision.models.vgg16_bn(** kwargs)
#
# VGG 19层模型(配置“E”)
# torchvision.models.vgg19(pretrained=False, ** kwargs)
# pretrained (bool) = True, 返回在ImageNet上训练好的模型。
#
# 批量归一化的VGG 16层模型(配置“E”)
# torchvision.models.vgg19_bn(** kwargs)

predict.py保存模型之后预测:

from torchvision import datasets, transforms
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from cov01 import Model

classes = ('cat','dog')

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.load('model56')  # 加载模型
    model = model.to(device)
    model.eval()  # 把模型转为test模式

    img = Image.open("../dog.jpg")

    trans = transforms.Compose(
        [
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

    img = trans(img)
    img = img.to(device)
    img = img.unsqueeze(0)  # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]
    # 扩展后,为[1,1,28,28]
    output = model(img)
    prob = F.softmax(output, dim=1)  # prob是10个分类的概率
    print(prob)
    value, predicted = torch.max(output.data, 1) #按照维度返回最大概率dim = 0 表示按列求最大值,并返回最大值的索引,dim = 1 表示按行求最大值,并返回最大值的索引
    print(predicted.item())
    print(value)
    pred_class = classes[predicted.item()]
    print(pred_class)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/408744.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【软件测试面试】要你介绍项目-如何说?完美面试攻略...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 1、测试面试时&am…

UI风格汇:扁平化风格来龙去脉,特征与未来趋势

Hello,我是大千UI工场,设计风格是我们新开辟的栏目,主要讲解各类UI风格特征、辨识方法、应用场景、运用方法等,本次带来的扁平化风格的解读,有设计需求,我们也可以接单。 一、什么是扁平化风格 扁平化风格…

C# EF Core迁移数据库

现象: 在CodeFirst时,先写字段与表,创建数据库后,再添加内容 但字段与表会变更,比如改名删除增加等 需求: 当表字段变更时,同时变更数据库,执行数据库迁移 核心命令 Add-Migrat…

陪诊小程序:温暖您的就医之路,让关怀触手可及

随着社会的进步和科技的发展,人们对于医疗健康的需求日益增长。然而,在繁忙的生活节奏中,许多人在面对就医时却面临着无人陪伴的困境。为了解决这一问题,陪诊小程序应运而生。 陪诊小程序是一种便捷、高效、人性化的医疗服务应用…

9-pytorch-现有模型使用及修改

b站小土堆pytorch教程学习笔记 1 使用ImageNet测试模型vgg16 train_datatorchvision.datasets.ImageNet(dataset/ImageNet,trainTrue ,downloadTrue ,transformtorchvision.transforms.ToTensor())代码运行报错:ImageNet数据集过大,导致现在无法公开访问…

聊聊 Go 边界检查消除

前言 在这篇文章中碰巧看到了Go边界检查消除相关的讨论. 我也借此简单聊聊. 有这样一段代码, 非常简单, 就是一段求向量点积的程序: func sum(a, b []int) int {if len(a) ! len(b) {panic("must be same len")}ret : 0for i : 0; i < len(a); i {ret a[i] * …

SAM轻量化的终点竟然是RepViT + SAM

本文首发&#xff1a;AIWalker&#xff0c;欢迎关注~~ 殊途同归&#xff01;SAM轻量化的终点竟然是RepViT SAM&#xff0c;移动端速度可达38.7fps。 对于 2023 年的计算机视觉领域来说&#xff0c;「分割一切」&#xff08;Segment Anything Model&#xff09;是备受关注的一项…

0-1背包问题-动态规划

解法归纳&#xff1a; 一、如果装不下当前物品&#xff0c;那么前n个物品的最佳组合和前n-1个物品的最佳组合是一样的。 二、如果装得下当前物品。 假设1 :装当前物品&#xff0c;在给当前物品预留了相应空间的情况下&#xff0c;前n-1 个物品的最佳组 合加上当前物品的价值就…

作业 找单身狗2

方法一&#xff1a; 思路&#xff1a; 我们可以先创建一个新的数组&#xff0c;初始化为0&#xff0c;然后让原来的数组里面的元素作为新数组的下标 如果该下标对应的值为0&#xff0c;说明没有出现过该数&#xff0c;赋值为1作为标记&#xff0c;表示出现过1次 如果该下标…

#FPGA(基础知识)

1.IDE:Quartus II 2.设备&#xff1a;Cyclone II EP2C8Q208C8N 3.实验&#xff1a;正点原子-verilog基础知识 4.时序图&#xff1a; 5.步骤 6.代码&#xff1a;

代码随想录刷题第41天

首先是01背包的基础理论&#xff0c;背包问题&#xff0c;即如何在有限数量的货物中选取使具有一定容量的背包中所装货物价值最大。使用动规五步曲进行分析&#xff0c;使用二维数组do[i][j]表示下标从0到i货物装在容量为j背包中的最大价值&#xff0c;dp[i][j]可由不放物品i&a…

物理备份的方式

完全备份恢复流程 停止数据库清理环境重演回滚&#xff0d;&#xff0d;> 恢复数据修改权限启动数据库 1.关闭数据库&#xff1a; [rootmysql-server ~]# systemctl stop mysqld [rootmysql-server ~]# rm -rf /var/lib/mysql/* //删除所有数据// [rootmysql-server ~]# …

Sora:颠覆性AI视频生成工具

Sora是一款基于人工智能&#xff08;AI&#xff09;技术的视频生成工具&#xff0c;它彻底改变了传统视频制作的模式&#xff0c;为创作者提供了高效、便捷、高质量的视频内容生成方式。通过深度学习和自然语言处理等先进技术&#xff0c;Sora实现了从文字描述到视频画面的自动…

并发编程(5)共享模型之不可变

7 共享模型之不可变 本章内容 不可变类的使用不可变类设计无状态类设计 7.1 日期转换的问题 问题提出 下面的代码在运行时&#xff0c;由于 SimpleDateFormat 不是线程安全的, 有很大几率出现 java.lang.NumberFormatException 或者出现不正确的日期解析结果&#xff0c;…

SpringCloud Alibaba 2022之Nacos学习

SpringCloud Alibaba 2022使用 SpringCloud Alibaba 2022需要Spring Boot 3.0以上的版本&#xff0c;同时JDK需要是17及以上的版本。具体的可以看官网的说明。 Spring Cloud Alibaba版本说明 环境搭建 这里搭建的是一个聚合项目。项目结构如下&#xff1a; 父项目的pom.xm…

(拦截器)学习SpringMVC的第三天

一 .拦截器简介 拦截器的几个处理阶段 二 . 拦截器快速入门 2.1 实现拦截器接口 public class MyInterceptor1 implements HandlerInterceptor {Overridepublic boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Excep…

微信小程序开启横屏调试

我们先打开小程序项目 开启真机运行 目前是一个竖屏的 然后打开全局配置文件 app.json 给下面的 window 对象 下面加一个 pageOrientation 属性 值为 landscape 运行结果如下 然后 我们开启真机运行 此时 就变成了个横屏的效果

(done) Positive Semidefinite Matrices 什么是半正定矩阵?如何证明一个矩阵是半正定矩阵? 可以使用特征值

参考视频&#xff1a;https://www.bilibili.com/video/BV1Vg41197ew/?vd_source7a1a0bc74158c6993c7355c5490fc600 参考资料(半正定矩阵的定义)&#xff1a;https://baike.baidu.com/item/%E5%8D%8A%E6%AD%A3%E5%AE%9A%E7%9F%A9%E9%98%B5/2152711?frge_ala 看看半正定矩阵的…

ubantu设置mysql开机启动

阅读本文之前请参阅----MySQL 数据库安装教程详解&#xff08;linux系统和windows系统&#xff09; 在Ubuntu系统中设置MySQL开机启动&#xff0c;通常有以下几种方法&#xff1a; 1. **使用systemctl命令**&#xff1a; Ubuntu 16.04及更高版本使用systemd作为…

Facebook群控:利用代理IP克服多账号关联

拥有多个 Facebook 帐户对于区分您的个人和企业在线形象或维护客户页面非常有用。然而&#xff0c;Facebook 的服务条款正式限制用户只能使用一个个人帐户&#xff0c;想要多账号运营&#xff0c;下面的干货必须看&#xff01; 一、Facebook群控是什么&#xff1f; Facebook群…