深度学习pytorch实战五:基于ResNet34迁移学习的方法图像分类篇自建花数据集图像分类(5类)超详细代码

1.数据集简介
2.模型相关知识
3.split_data.py——训练集与测试集划分
4.model.py——定义ResNet34网络模型
5.train.py——加载数据集并训练,训练集计算损失值loss,测试集计算accuracy,保存训练好的网络参数
6.predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试

一、数据集简介

1.自建数据文件夹

首先确定这次分类种类,采用爬虫、官网数据集和自己拍照的照片获取5类,新建个文件夹data,里面包含5个文件夹,文件夹名字取种类英文,每个文件夹照片数量最好一样多,五百多张以上。如我选了雏菊,蒲公英,玫瑰,向日葵,郁金香5类,如下图,每种类型有600~900张图像。如下图

花数据集下载链接https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
在这里插入图片描述
2.划分训练集与测试集

这是划分数据集代码,同一目录下运,复制改文件夹路径。

import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


# 获取 photos 文件夹下除 .txt 文件以外所有文件夹名(即3种分类的类名)
file_path = 'data/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

# 创建 训练集train 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/' + cla)

# 创建 验证集val 文件夹,并由3种类名在其目录下创建3个子目录
mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/' + cla)

# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1

# 遍历3种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'  # 某一类别动作的子目录
    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称
    num = len(images)
    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称
    for index, image in enumerate(images):
        # eval_index 中保存验证集val的图像名称
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)  # 将选中的图像复制到新路径

        # 其余的图像保存在训练集train中
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
    print()

print("processing done!")

二、模型相关知识

之前有文章介绍模型,如果不清楚可以点下链接转过去学习。

深度学习卷积神经网络CNN之ResNet模型网络详解说明(超详细理论篇)

在这里插入图片描述

三、model.py——定义ResNet34网络模型

这里还是直接复制给出原模型,不用改参数。模型包含34、50、101模型

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

四、train.py——训练,计算损失值loss,计算accuracy,保存训练好的网络参数

第一步,提前下载权重链接,复制链接网址打开直接下载,下载完,放在同一个工程文件夹,记得修改个名字,后面要用。

ResNet34权重链接https://download.pytorch.org/models/resnet34-333f7ec4.pth

第二步 71行类数、63行之前下载权重文件名字、83行保存最终权重文件名字

net.fc = nn.Linear(in_channel, 5)//修改5类的5
model_weight_path = "./resnet34-pre.pth"
save_path = './resNext34.pth'

其他参数bach_size=16;(根据cpu或GPU性能选择32、64等)
学习率 0.01
epoch 5

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34,resnet101


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "zjdata", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    for param in net.parameters():
        param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.01)

    epochs = 5
    best_acc = 0.0
    save_path = './resNext34.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

训练开始截图,我是用CPU训练
在这里插入图片描述

六、predict.py——利用训练好的网络参数后,用自己找的图像进行分类测试

注意图片位置和权重参数名字

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "./1.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = resnet34(num_classes=5).to(device)

    # load model weights
    weights_path = "./resNext34.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

预测结果截图
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/29405.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

(三)Kafka 生产者

文章目录 1. Kafka 发送消息的主要步骤2.创建 Kafka 生产者3.发送消息到 Kafka(1)发送并忘记(2)同步发送(3)异步发送 4.生产者配置(1)client.id(2)ack&#x…

Python基础(2)——Python解释器

Python基础(2)——Python解释器 文章目录 Python基础(2)——Python解释器目标一. 解释器的作用二. 下载Python解释器三. 安装Python解释器总结 目标 解释器的作用下载Python解释器安装Python解释器 一. 解释器的作用 Python解释…

对于ChatGPT,马化腾、马斯克等科技大佬竟然这么说!

ChatGPT一夜爆火之后,国内几乎是各大互联网公司都在摩拳擦掌,跃跃欲试,从百度的文心一言,到阿里的通义千问,还有360的智脑,讯飞的星火,语言大模型如雨后春笋一般涌出,犹如2014年新能…

Android 逆向安全行业前景如何?

前言 Android 逆向是指对已经发布的 Android 应用进行分析和研究,通过逆向工程,将 Android 应用中的底层实现原理、业务逻辑、源代码以及恶意行为等等信息进行破解和掌握。逆向工程可以让研究者深入了解 Android 应用的实现细节,从而识别和修…

麒麟V10服务器 安装samba 软件,并且实现远程连接(压缩包形式)

目录 1 安装包2 实现3 如何查看安装的sambd 的版本4 使用 1 安装包 百度网盘 链接: https://pan.baidu.com/s/1l6HDAGE4_Itj-cp7XtpUNg 提取码: 100w 复制这段内容后打开百度网盘手机App,操作更方便哦2 实现 以下是在Linux系统中使用压缩包方式安装Samba服务的步…

走向实用的AI编解码

基于AI的端到端数据压缩方法受到越来越多的关注,研究对象已经包括图像、视频、点云、文本、语音和基因组等,其中AI图像压缩的研究最为活跃。图像编解码的研究和应用历史悠久,AI方法要达到实用,需要解决诸多问题才能取得相比于传统…

Gradle版本目录(Version Catalog)

Gradle版本目录(Version Catalog) “版本目录是一份依赖项列表,以依赖坐标表示,用户在构建脚本中声明依赖项时可以从中选择。” 我们可以使用版本目录将所有依赖项声明及其版本号保存在单个位置。这样,我们可以轻松地在模块和项目之间共享依…

串口协议说明

文章目录 关系波特率概念波特率相对误差UART误差保证 协议常见的串行接口协议之间的比较USB 转串口PL2303USB 转串口CP2102USB转232终端电阻 串口电平TTL电平485电平 帧奇偶校验 关系 两个半双工,一发一收,就是Uart 在一根线的基础上,多加一…

iPhone手机UDID获取方法

UDID:iOS设备的唯一识别码,每台iOS设备都有一个独一无二的编码,这个编码,就称为识别码,也叫做UDID(Unique Device Identifier) 一、通过Xcode查看 手机连接电脑打开Xcode,选择wind…

初探 transformer

大部分QA的问题都可以使用seq2seq来实现。或者说大多数的NLP问题都可以使用seq2seq模型来解决。 但是呢最好的办法还是对具体的问题作出特定的模型训练。 概述 Transformer就是一种seq2seq模型。 我们先看一下seq2seq这个模型的大体框架(其实就是一个编码器和一个解码器)&a…

Vue中如何进行表单图片裁剪与预览

Vue中如何进行表单图片裁剪与预览 在前端开发中,表单提交是一个常见的操作。有时候,我们需要上传图片,但是上传的图片可能会非常大,这会增加服务器的负担,同时也会降低用户的体验。因此,我们通常需要对上传…

选择合适的采购系统,实现企业数字化转型

随着数字化技术的飞速发展,企业数字化转型已经成为了当今市场的必然趋势。而采购系统作为企业数字化转型的重要组成部分,选择合适的采购系统对于企业来说至关重要。本文将围绕选择合适的采购系统,实现企业数字化转型展开讨论。 一、企业数字化…

OpenCV项目开发实战-- 的单应性(Homography)实例Python/C++代码实现

文末附基于Python和C++两种方式实现的测试代码下载链接 什么是单应性(Homography)? 考虑图 1 中所示的平面(书的顶部)的两个图像。红点表示两个图像中的相同物理点。在计算机视觉术语中,我们称这些为对应点。图 1. 显示了四种不同颜色的四个对应点——红色、绿色、黄色和…

YUM源安装,在线YUM,本地YUM

YUM源 一、定义 YUM(全称为 Yellow dog Updater, Modified)是一个在 Fedora 和 RedHat 以及 CentOS 中的 Shell 前端软件包管理器。基于RPM包管理,能够从指定的服务器自动下载RPM包并且安装,**可以自动处理依赖性关系&…

【八大排序(五)】快排进阶篇-挖坑法+前后指针法

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:八大排序专栏⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习排序知识   🔝🔝 快排进阶篇 1. 前情回顾2. 思路回顾3. 单…

java方法

文章目录 一、java方法总结 一、java方法 在前面几个章节中我们经常使用到 System.out.println(),那么它是什么呢? println() 是一个方法。 System 是系统类。 out 是标准输出对象。这句话的用法是调用系统类 System 中的标准输出对象 out 中的方法 pr…

docker部署prometheus+grafana视图监控

效果 一、grafana可视化平台部署 docker run -d \--namegrafana \--restartalways \-p 3000:3000 \grafana/grafanagrafana我也是部署在170.110服务器上,192.168.170.110:3000访问grafana 默认账号密码都是admin 二、部署exportor采集信息 针对各类数据库平台系统…

ch8_1_CPU的结构和功能

1. cpu的结构 1.1CPU 的功能 控制器的功能 控制器的功能具体作用取指令指令控制分析指令操作控制执行指令, 发出各种操作命令控制程序输入与结果的输出时间控制总线管理处理中断处理异常情况和特殊请求数据加工 运算器的功能 实现算术运算 和 逻辑运算&#x…

我的256创作纪念日

机缘 挺开心的,想到自己未曾写过一些非技术类的博客,恰巧今天刚好也是我的256创作纪念日,就乘着这个日子,写一点自己过去的收获、内心的想法和对未来的展望吧。 本人不才,只就读于一所民办本科之中,我挺不想…

【linux】探索Linux命令行中强大的网络工具:netstat

文章目录 前言一、netstat是什么?二、使用方法1.常用参数2.实例演示3.更多功能 总结 前言 在Linux命令行中,有许多实用的工具可帮助我们管理和监控网络连接。其中一个最重要的工具就是netstat,它提供了丰富的网络连接和统计信息,…