使用pytorch构建ResNet50模型训练猫狗数据集

数据集

1.导包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

2.数据预处理

ResNet50模型适合的图片大小为224x244

# 定义数据转换
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

3.加载数据集和模型构建

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=4)
               for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练

# 训练次数
num_epochs = 10

# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数
    for phase in ['train', 'test']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        # 使用tqdm显示进度条
        with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:
            for inputs, labels in dataloaders[phase]:
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)
                progress_bar.update(1)

        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
     # 更新训练次数计数器
    train_count += 1
    print(f'Training Count: {train_count}')

训练过程

5.预测

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# 定义模型的类别数量
num_classes = 2

# 加载模型
model = torchvision.models.resnet50(pretrained=False)
# 修改模型的fc层以匹配训练时的结构
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 加载保存的权重
model.load_state_dict(torch.load('mg_ResNet50model.pth'))
model.eval()

# 图像预处理
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 测试图片
img_path = 'mao_1.jpg'  # 替换为你的图片路径
img = Image.open(img_path)
img_t = preprocess(img)

# 扩展维度,因为模型需要4维输入(Batch, Channels, Height, Width)
batch_t = torch.unsqueeze(img_t, 0)

# 预测
with torch.no_grad():
    out = model(batch_t)

# 获取最高分数的类别
_, index = torch.max(out, 1)

# 可视化结果
plt.imshow(img)
plt.title(f'Predicted: {index.item()}')
plt.show()

预测效果

0就是猫咪,1就是小狗

全部代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.auto import tqdm  # 引入tqdm库以显示进度条

# 定义数据转换
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
data_dir = 'catdog_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'test']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=4)
               for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
class_names = image_datasets['train'].classes

# 加载ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练次数
num_epochs = 10

# 初始化训练次数计数器
train_count = 0
for epoch in range(num_epochs):  # num_epochs 是你希望训练的轮数
    for phase in ['train', 'test']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        # 使用tqdm显示进度条
        with tqdm(total=len(dataloaders[phase]), desc=f'Epoch {epoch+1}/{num_epochs}', leave=False) as progress_bar:
            for inputs, labels in dataloaders[phase]:
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                progress_bar.set_postfix(loss=epoch_loss, acc=epoch_acc)
                progress_bar.update(1)

        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
     # 更新训练次数计数器
    train_count += 1
    print(f'Training Count: {train_count}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/663864.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

java期末突击

目录 1. 快速入门 2. 类 3. 成员方法 4. 构造器 5. 单例模式 6. 继承 7. 抽象类 8. 多线程(Thread与Runnable) 1. 快速入门 public class Hello {public static void main(String[] args){System.out.println("Hello,World&q…

开源项目的 README.md 这样写,hr 看了都心动

大家好,我是南城余。 最近在搞项目的开源 README.md 时,看到有的大佬的 README.md 前面都会有 badge 标识和多彩的 tag 标签。都看着超级高级,就想着也来实践下。 今天一篇文章带着大家个性化你的 README.md。如下图片: 这里我先…

【Linux】进程池

大致草稿 思维导图 学习目标 一、进程池的代码编写顺序 在进程池中,我们要创建多个子进程,并且对多个子进程和父进程建立管道的关系,确保父进程和子进程之间可以进行相互通信。 父进程就是master,而子进程就是work/slaver。大致…

低代码平台如何革新客户管理

引言 在当今快节奏的商业环境中,企业对于客户管理的需求日益增长。同时,随着技术的不断发展,低代码开发作为一种创新的软件开发方法正在逐渐崭露头角。低代码开发为企业提供了一种更快速、更灵活的方式来构建应用程序,使得传统的软…

烧脑的逻辑图又来了,精力绝对不是花费在做图上。

逻辑图设计之所以比较耗费精力,主要是因为它需要进行深入的思考和分析,以确保设计的逻辑正确、完整和可行。以下是一些可能导致逻辑图设计耗费精力的原因: 复杂性:逻辑图设计通常涉及到复杂的业务流程和系统架构。设计师需要理解各…

04-树5 Root of AVL Tree(浙大数据结构PTA习题)

04-树5 Root of AVL Tree 分数 25 作者 陈越 单位 浙江大学 An AVL tree is a self-balancing binary search tree. In an AVL tree, the heights of the two child subtrees of any node differ by at most one; if at any time they differ by more th…

VMware虚拟机安装Ubuntu-Server版教程(超详细)

目录 1. 下载2. 安装 VMware3. 安装 Ubuntu3.1 新建虚拟机3.2 安装操作系统 4. SSH方式连接操作系统4.1 好用的SSH工具下载:4.2 测试SSH连接 5. 开启root用户登录5.1 设置root用户密码5.2 传统方式切换root用户5.3 直接用root用户登录5.4 SSH启用root用户登录 6. 安…

Kali 我来了

Kali 我来了 1、官网下载2、修改密码3、开启SSH远程登录服务4、关闭kali图形化界面 1、官网下载 官方链接: https://www.kali.org/ 下载链接: https://cdimage.kali.org/kali-2024.1/kali-linux-2024.1-vmware-amd64.7z 解压后 直接导入 VmWare 就可使用可爱的小 Kali 了。 …

太阳能智能农业气象站:创新科技引领现代农业革命

TH-NQ14在科技日新月异的今天,现代农业正迎来一场由智能化、精准化、绿色化驱动的深刻变革。太阳能智能农业气象站作为这场变革的先锋力量,以其独特的优势和强大的功能,为现代农业的发展注入了新的活力,并引领着农业生产的未来趋势…

el-date-picker的使用,及解决切换type时面板样式错乱问题

这里选择器的类型可以选择日月年和时间范围&#xff0c;根据类型不同&#xff0c;el-date-picker的面板也展示不同&#xff0c;但是会出现el-date-picker错位&#xff0c;或者面板位置和层级等问题。 源代码&#xff1a; <el-selectv-model"dateType"placeholder&…

【漏洞复现】Check Point 安全网关任意文件读取漏洞(CVE-2024-24919)

0x01 产品简介 Check Point Security Gateways 是 Check Point Sofware 提供的一系列 网络安全Q解决方案。这些解决方案包括下一代防火墙(NGFW)、数据中心安全网关和 A1驱动的量子网关&#xff0c;旨在为企业提供针对复杂网络威胁的先进防护。它们通过集成的威胁防护、统的安全…

微信小程序 npm构建+vant-weaap安装

微信小程序&#xff1a;工具-npm构建 报错 解决&#xff1a; 1、新建miniprogram文件后&#xff0c;直接进入到miniprogram目录&#xff0c;再次执行下面两个命令&#xff0c;然后再构建npm成功 npm init -y npm install express&#xff08;Node js后端Express开发&#xff…

初出茅庐的小李博客之使用立创开发板(ESP32)连接到EMQX Platform【MQTT TLS/SSL 端口连接】

介绍 手上有一块立创开发板&#xff0c;本着不吃灰的原则把它用起来&#xff0c;今天就来用它来连接上自己部署的MQTT服务器进行数据通信。 硬件&#xff1a;立创开发板 开发环境&#xff1a;Arduino IDE Win11 MQTT 平台&#xff1a;EMQX Platform 立创开发板介绍&#xff1…

ChatGPT AI专题资料合集【65GB】

介绍 ChatGPT & AI专题资料合集【65GB】 &#x1f381;【七七云享】资源仓库&#xff0c;海量资源&#xff0c;无偿分享√

OSError: [Errno 117] Structure needs cleaning

一 问题描述 OSError: [Errno 117] Structure needs cleaning: /tmp/pymp-wafeatri 我重新使用SSH登录也会提示这个类似问题 二 解决方法 2.1 尝试删除报错的文件 &#xff08;想直接看最终解决方法的可忽略此处&#xff09; sudo rm -rf /tmp/pymp-wafeatri 此种方法只能保证…

手机怎么投屏?值得收藏的5个投屏软件(2024全新)

在大屏幕上观看自己喜欢的电影、电视节目和其他媒体节目总是一种很棒的体验。因此&#xff0c;如果你正在寻找通过无线网络或USB数据线在windows或mac电脑或笔记本电脑上屏幕镜像、投屏到电脑显示屏上的指南&#xff0c;那么可以花3分钟时间往下看看。在本文中&#xff0c;小编…

现在,所有人都能免费用GPT-4o了!

OpenAI今日官宣&#xff0c;ChatGPT正式向所有用户免费开放&#xff01;所有用户均可以访问定制化GPT、分析图表、询问有关照片的问题以及5月初GPT-4o添加的其他功能。 OpenAI今天在X上发布推文&#xff1a; 「所有ChatGPT免费用户现在都可以使用浏览、视觉、数据分析、文件上…

CSS+Canvas绘制最美星空(一闪一闪亮晶晶效果+流星划过)

1.效果 2.代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><style>body,html {margin: 0;padding: 0;ov…

tree2retriever:面向RAG场景的递归摘要树检索器实现

tree2retriever 面向RAG场景的递归摘要树检索器实现 Recursive Abstractive Processing for Tree-Organized Retrieval Github:https://github.com/yanqiangmiffy/tree2retriever Example import logging import picklefrom tree2retriever.cluster_tree_builder import Clus…

整理GTX收发器示例工程(高速收发器十一)

前文分析了xilinx官方提供的GTX IP示例工程&#xff0c;该代码的结构比较混乱&#xff0c;本文将该代码进行梳理&#xff0c;形成一个便于使用的模块&#xff0c;后续如果要使用多通道的收发器&#xff0c;多次例化某个模块就行了。 下图是官方例程中GTX IP相关模块的RTL视图&a…