使用ResNet-50训练自己的数据集(花的种类)

1.数据集处理

具体操作

1.把不同类别的花(或者自己数据集的不同类别)放在不同的文件夹下

2.文件夹名字是花朵类别

代码预处理

# 对数据集进行处理
import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = 'flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1

# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # 某一类别动作的子目录
    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
    for index, image in enumerate(images):
        # eval_index 中保存验证集val的图像名称
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # 将选中的图像复制到新路径

        # 其余的图像保存在训练集train中
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")

代码实现效果

2.模型训练以及结果可视化

# 模型训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
import os
# 定义数据转换
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
data_dir = 'flower_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# 加载预训练的ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def train_model(model, criterion, optimizer, num_epochs=25):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # 每个epoch都有训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 设置模型为训练模式
            else:
                model.eval()   # 设置模型为评估模式

            running_loss = 0.0
            running_corrects = 0

            # 遍历数据
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清零参数梯度
                optimizer.zero_grad()

                # 前向传播
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 反向传播和优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    print('Training complete')

# 调用训练函数
train_model(model, criterion, optimizer, num_epochs=10)


# 可视化一些预测结果
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# 可视化模型预测结果
visualize_model(model)

plt.ioff()
plt.show()

可视化效果

3.保存模型

# 保存模型
torch.save(model, 'resnet50_flowers_model.pth')

4.使用网上图片进行测试

from PIL import Image


model.eval()  # 确保模型处于评估模式

# 加载和预处理图片
def process_image(image_path):
    # 这里需要根据您的模型和数据集来定义图片的预处理步骤
    # 例如,调整大小、归一化等
    img = Image.open(image_path)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img_tensor = transform(img).unsqueeze(0)  # 添加batch维度
    return img_tensor

# 预测图片
def predict_image(image_path, model):
    img_tensor = process_image(image_path)
    img_tensor = img_tensor.to(device)  # 确保图片在正确的设备上
    with torch.no_grad():  # 确保在预测过程中不计算梯度
        output = model(img_tensor)
    _, pred = torch.max(output, 1)  # 获取最高分数的类别
    return class_names[pred[0]]

# 比较预测结果和实际标签
image_path ="D:/PyCharm 2024.1.1/pythonProject/data/2.jpg " # 替换为您的图片路径
true_label = 'meigui'  # 替换为实际的标签
predicted_label = predict_image(image_path, model)

print(f'Predicted: {predicted_label}, True: {true_label}')

全部源码

# 对数据集进行处理
import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = 'flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1

# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # 某一类别动作的子目录
    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
    for index, image in enumerate(images):
        # eval_index 中保存验证集val的图像名称
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # 将选中的图像复制到新路径

        # 其余的图像保存在训练集train中
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")
# 模型训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
import os
# 定义数据转换
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 加载数据集
data_dir = 'flower_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# 加载预训练的ResNet-50模型
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 替换最后的全连接层以适配我们的分类问题
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def train_model(model, criterion, optimizer, num_epochs=25):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # 每个epoch都有训练和验证阶段
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # 设置模型为训练模式
            else:
                model.eval()   # 设置模型为评估模式

            running_loss = 0.0
            running_corrects = 0

            # 遍历数据
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清零参数梯度
                optimizer.zero_grad()

                # 前向传播
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 反向传播和优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    print('Training complete')

# 调用训练函数
train_model(model, criterion, optimizer, num_epochs=10)


# 可视化一些预测结果
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# 可视化模型预测结果
visualize_model(model)

plt.ioff()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/649754.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

简易计算器

前言 简易计算器,旨在实现一个简单的计算器功能。 整形,浮点型数据的加减乘除运算;数据的统计(如文件中某字符的出现频数);期望,方程运算;平均数,最小值,最大值,中位数…

每日5题Day11 - LeetCode 51 - 55

每一步向前都是向自己的梦想更近一步&#xff0c;坚持不懈&#xff0c;勇往直前&#xff01; 第一题&#xff1a;51. N 皇后 - 力扣&#xff08;LeetCode&#xff09; class Solution {public List<List<String>> solveNQueens(int n) {List<List<String>…

揭秘订单排队模式:社交电商新策略

随着移动互联网的蓬勃发展&#xff0c;社交电商正以其独特的魅力席卷全球。据权威机构预测&#xff0c;到2024年&#xff0c;全球社交电商市场规模有望达到惊人的2.8万亿美元。面对如此庞大的市场蛋糕&#xff0c;如何精准把握机遇&#xff0c;实现业务的跨越式增长&#xff0c…

SQL157 平均播放进度大于60%的视频类别

描述 用户-视频互动表tb_user_video_log iduidvideo_idstart_timeend_timeif_followif_likeif_retweetcomment_id110120012021-10-01 10:00:002021-10-01 10:00:30011NULL210220012021-10-01 10:00:002021-10-01 10:00:21001NULL310320012021-10-01 11:00:502021-10-01 11:01…

强大的医院绩效考核管理系统源码,支持行业内所有绩效方案,且每步核算都可自主进行方案的新建、调整。

医院绩效考核管理系统是采用B/S架构模式设计、使用JAVA语言开发、后台使用MySql数据库进行管理的一整套计算机应用软件源码。 系统和his系统进行对接&#xff0c;按照设定周期&#xff0c;从his系统获取医院科室和医生、护士、其他人员工作量&#xff0c;对没有录入信息化系统…

定时调度任务——kettle开发22

一、流查询 流查询在查询前把数据加载到内存中&#xff0c;并且只能进行等值查询。 流查询的操作包括指定数据源&#xff0c;即我们将数据进行流查询的数据源头&#xff0c;如下图所示流查询的数据源是我们的Excel输入-departments&#xff0c;然后我们还需要指定我们进行流查…

手机远程协助,3个方法,开启沟通新篇章

已经不仅仅是一个通讯工具&#xff0c;更是一个多功能的魔法棒。它轻轻一挥&#xff0c;就能让我们跨越千山万水&#xff0c;与亲朋好友、合作伙伴进行亲密无间的交流。今天&#xff0c;就让我们一起探讨如何利用手机远程协助的三大法宝&#xff0c;开启沟通的新篇章&#xff0…

二叉树——堆的实现

一.前言 前面我们讲解了二叉树的概念以及二叉树的存储结构&#xff1a;https://blog.csdn.net/yiqingaa/article/details/139224974?spm1001.2014.3001.5502 今天我们主要讲讲二叉树的存储结构&#xff0c;以及堆的实现。 二.正文 1.二叉树的顺序结构及实现 1.1二叉树的顺序…

手动操作很麻烦?试试这个自动加好友神器吧!

你是不是也觉得手动逐一输入号码或是微信号&#xff0c;再搜索添加很麻烦&#xff1f;试试这个自动加好友神器——个微管理系统&#xff0c;帮助你省去繁琐的手工操作&#xff0c;节省时间和精力。 首先&#xff0c;在系统上登录微信号&#xff0c;无论你有多少个微信号&#…

服务器重装系统与磁盘操作

诱因&#xff1a;服务器原来装的EXSI&#xff0c;现在要重装一个ubuntu server&#xff0c;出现了下面一些问题&#xff0c;在此记录一下。 目录 1、过程中出现的问题&#xff08;2024.5.26&#xff09;1.1 问题1&#xff1a;如何磨掉原来的ESXI&#xff1f;1.2 问题2&#xf…

【Typescript】通过变量的值即可获取变量的类型【typeof 变量】

注意&#xff1a;只要变量的类型准确,则typeof获取变量的类型就不会错 enum Test {a "a0",b "b0" }// 这里的a是一个变量的值 let a: Test.a "a0" as Test.a// 这里的typeof a是一个类型【Test.a】 let x: typeof a Test.a

玩转香橙派 AIpro,高性能AI开发板评测与项目案例分享

公司最近刚忙完一个项目&#xff0c;闲暇之余&#xff0c;看着手里的树莓派、stm32、Esp32又有些手痒了&#xff0c;准备再搞点小项目出来&#xff0c;但一直没有什么好想法。 说来也巧&#xff0c;恰好收到了CSDN官方的OrangePi AIpro测评活动&#xff0c;平时一直都在用树莓…

matplotlib---气泡图

气泡图简介&#xff1a; 气泡图&#xff08;Bubble Chart&#xff09;是一种数据可视化图形&#xff0c;主要用于展示多个数据点之间的关系。 气泡图通过气泡的大小&#xff0c;位置和颜色可以展示数据之间的关系。在气泡图中&#xff0c;横轴和纵轴通常表示数据的两个维度&a…

NoSQL Redis配置与优化

一、关系数据库与非关系型数据库 1. 关系型数据库&#xff1a; 关系型数据库是一个结构化的数据库&#xff0c;创建在关系模型&#xff08;二维表格模型&#xff09;基础上&#xff0c;一般面向于记录。 SQL 语句&#xff08;标准数据查询语言&#xff09;就是一种基于关系型…

3d火灾救援模拟仿真培训软件复用性强

消防VR安全逃生体验系统是深圳VR公司华锐视点引入了前沿的VR虚拟现实、web3d开发和多媒体交互技术&#xff0c;为用户打造了一个逼真的火灾现场应急逃生模拟演练环境。 相比传统的消防逃生模拟演练&#xff0c;消防VR安全逃生体验系统包含知识讲解和模拟实训演练&#xff0c;体…

(2024,基于熵的激活函数动态优化,具有边界条件的最差激活函数,修正正则化 ReLU)寻找更优激活函数

A Method on Searching Better Activation Functions 公众号&#xff1a;EDPJ&#xff08;进 Q 交流群&#xff1a;922230617 或加 VX&#xff1a;CV_EDPJ 进 V 交流群&#xff09; 目录 0. 摘要 3. 动机 4. 方法论 4.1 问题设定 4.1.1 贝叶斯错误率和信息熵 4.1.2 激活…

物业

用户报修 审核专员可以操作&#xff08;前端&#xff09;&#x1f197; 工程部可以看到不可以操作&#xff08;前端&#xff09;&#x1f197; 项目经理可以看到不可以操作&#xff08;前端&#xff09;&#x1f197; 经理可以看到不可以操作&#xff08;前端&#xff09;&…

Kivy 项目51斩百词 6 播放读音

为了给小喇叭图像绑定点击事件&#xff0c;实现当用户点击按钮时&#xff0c;触发该事件对应的回调方法。 在方法内对于不同的系统Kivy使用不同的播放语音方法&#xff0c; 对于Windows系统 使用SoundLoader播放语音&#xff0c; 对于其他的Unix系统 使用Pyjnjus播放…

C语言 数组——排序算法的函数实现

目录 交换法排序 用交换法对成绩数组升序排序 选择法排序 冒泡法排序 归并法排序 交换法排序 用交换法对成绩数组升序排序 选择法排序 冒泡法排序 归并法排序