基于Pytorch框架的深度学习Swin-Transformer神经网络食物分类系统源码

 第一步:准备数据

5种鸟类数据:self.class_indict = ["苹果派", "猪小排", "果仁蜜饼", "生牛肉薄片", "鞑靼牛肉"]

,总共有5000张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个Swin-Transformer网络,其原理介绍如下:

Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper的荣誉称号。虽然Vision Transformer (ViT)在图像分类方面的结果令人鼓舞,但是由于其低分辨率特性映射和复杂度随图像大小的二次增长,其结构不适合作为密集视觉任务高分辨率输入图像的通过骨干网路。为了最佳的精度和速度的权衡,提出了Swin-Transformer结构。

Swin-Transformer的基础流程。

  1. 输入一张图片 [ H ∗ W ∗ 3 ] [H*W*3] [H∗W∗3]
  2. 图片经过Patch Partition层进行图片分割
  3. 分割后的数据经过Linear Embedding层进行特征映射
  4. 将特征映射后的数据输入具有改进的自关注计算的Transformer块(Swin Transformer块),并与Linear Embedding一起被称为第1阶段
  5. 与阶段1不同,阶段2-4在输入模型前需要进行Patch Merging进行下采样,产生分层表示。
  6. 最终将经过阶段4的数据经过输出模块(包括一个LayerNorm层、一个AdaptiveAvgPool1d层和一个全连接层)进行分类。
Swin-Transformer结构

简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。其中,图(a)表示Swin Transformer的网络结构流程,图(b)表示两阶段的Swin Transformer Block结构。注意:在Swin Transformer中,每个阶段的Swin Transformer Block结构都是2的倍数,因为里面使用的都是两阶段的Swin Transformer Block结构,如下图所示:

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # 删除有关分类类别的权重
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.0001)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\foods")

    # 预训练权重路径,如果不想载入就设置为空字符
    parser.add_argument('--weights', type=str, default='swin_tiny_patch4_window7_224.pth',
                        help='initial weights path')
    # 是否冻结权重
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习Swin-Transformer神经网络食物分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/725787.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

期望28K,5.14日蚂蚁java社招一面(杭州)

面经哥只做互联网社招面试经历分享,关注我,每日推送精选面经,面试前,先找面经哥 1、线程池的几个参数? 2、一道关于线程池的代码题目,数据库中存任务,通过一个有10个核心线程和无限队列的线程池…

基于springboot实现宠物商城网站管理系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现宠物商城网站管理系统演示 摘要 传统信息的管理大部分依赖于管理人员的手工登记与管理,然而,随着近些年信息技术的迅猛发展,让许多比较老套的信息管理模式进行了更新迭代,商品信息因为其管理内容繁杂&#xff…

C#.Net筑基-类型系统②常见类型

01、结构体类型Struct 结构体 struct 是一种用户自定义的值类型,常用于定义一些简单(轻量)的数据结构。对于一些局部使用的数据结构,优先使用结构体,效率要高很多。 可以有构造函数,也可以没有。因此初始…

数据结构:4.1.2二叉搜索树的插入

整个框架和FInd函数的实现是一样的&#xff0c;但是也有不同&#xff08;注意&#xff09; 35>30 向30的右子树 35<41 向41的左子树 35>33 向33的右子树&#xff0c;但33右边为空&#xff0c;所以35就挂在33的右边 因为要把35挂在33的右边&#xff0c;所以要把33的…

前端路线指导(2):前端基础版学习路线

前端基础路线的细节&#xff1a; 哈喽大家好&#xff01;我是小粉&#xff0c;双一流本科 自学前端一年&#xff0c;收获腾讯&#xff0c;字节等9家互联网大厂offer&#xff0c;秋招面试通过率100%&#xff0c;其中半数offer为ssp&#xff08;薪资最高档&#xff09; 以下是我根…

Houdini到UE地形流程

目录 Houidni地形制作 UE地形设置 Houdini engine插件安装 B站参考视频 Houidni地形制作 使用Terrain的HeightField相关节点制作地形&#xff1b;设置地形相关的材质层&#xff08;如rock、soil、grass等&#xff09;&#xff0c;注意材质的重叠&#xff1b; //detail层级&…

Stable Diffusion 3 大模型文生图实践

windows教程2024年最新Stable Diffusion本地化部署详细攻略&#xff0c;手把手教程&#xff08;建议收藏!!)_stable diffusion 本地部署-CSDN博客 linux本地安装教程 1.前期准备工作 1&#xff09;创建conda环境 conda create --name stable3 python3.10 2&#xff09;下…

一种基于非线性滤波过程的旋转机械故障诊断方法(MATLAB)

在众多的旋转机械故障诊断方法中&#xff0c;包络分析&#xff0c;又称为共振解调技术&#xff0c;是目前应用最为成功的方法之一。首先&#xff0c;对激励引起的共振频带进行带通滤波&#xff0c;然后对滤波信号进行包络谱分析&#xff0c;通过识别包络谱中的故障相关的特征频…

代码随想录——全排列(Leetcode LCR083)

题目链接 回溯 class Solution {List<List<Integer>> res new ArrayList<List<Integer>>();List<Integer> list new ArrayList<Integer>();boolean[] used;public List<List<Integer>> permute(int[] nums) {used new bo…

数据资产安全保卫战:构建多层次、全方位的数据安全防护体系,守护企业核心数据资产安全

一、引言 在信息化时代&#xff0c;数据资产已成为企业运营的核心&#xff0c;其安全性直接关系到企业的生存与发展。然而&#xff0c;随着网络技术的飞速发展&#xff0c;数据泄露、黑客攻击等安全威胁日益增多&#xff0c;给企业的数据资产安全带来了严峻挑战。因此&#xf…

基于esp-idf的arm2d移植

什么是ARM2D Arm在Github上发布了一个专门针对“全体” Cortex-M处理器的2D图形加速库——Arm-2D 我们可以简单的把这个2D图形加速库理解为是一个专门针对Cortex-M处理器的标准“显卡驱动”。虽然这里的“显卡驱动”只是一个夸张的说法——似乎没有哪个Cortex-M处理器“配得上…

怎么生成活码类型的二维码?在线制作活码的简单方法

活码是现在很多人会选择使用的一种二维码类型&#xff0c;制作活码二维码可以展现更多类型的内容&#xff0c;而且二维码可以随时在图案不变的情况下修改内容&#xff0c;与静态码相比使用起来更加的灵活。目前&#xff0c;活码可以用来展示图片、视频、音频、文件、网址、表单…

一个小例子助你彻底理解协程

一个小例子助你彻底理解协程 协程&#xff0c;可能是Python中最让初学者困惑的知识点之一&#xff0c;它也是Python中实现并发编程的一种重要方式。Python中可以使用多线程和多进程来实现并发&#xff0c;这两种方式相对来说是大家比较熟悉的。事实上&#xff0c;还有一种实现…

css 文字两端对齐

<body><div class"box"><p>姓名</p><p>性与别</p><p>家庭住址</p><p>how are you</p><p>hello</p><p>1234</p><p>1 2 3 4</p></div> </body> text-a…

Java零基础之多线程篇:线程的多种创建方式

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。运营社区&#xff1a;C站/掘金/腾讯云&#xff1b;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一…

【递归、搜索与回溯】综合练习四

综合练习四 1.单词搜索2.黄金矿工3.不同路径 III 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&#x1f603; 1.单词搜索 题目链接&#xff1a;79. 单词搜…

前端路线指导(3):前端进阶版学习路线

前端进阶版学习路线&#xff1a; 哈喽大家好&#xff01;我是小粉&#xff0c;双一流本科&#xff0c;自学前端一年&#xff0c;收获腾讯&#xff0c;字节等9家互联网大厂offer&#xff0c;秋招面试通过率100%&#xff0c;其中半数offer为ssp&#xff08;薪资最高档&#xff09…

如何查看公网IP?

什么是公网IP&#xff1f; 公网IP&#xff08;Internet Protocol&#xff09;是指分配给互联网上的计算机设备的唯一标识符。公网IP地址是由互联网服务提供商&#xff08;ISP&#xff09;分配给用户设备&#xff0c;使其可以与全球范围内的其他设备进行通信。公网IP地址通常采…

【超越拟合:深度学习中的过拟合与欠拟合应对策略】

如何处理过拟合 由于过拟合的主要问题是你的模型与训练数据拟合得太好&#xff0c;因此你需要使用技术来“控制它”。防止过拟合的常用技术称为正则化。我喜欢将其视为“使我们的模型更加规则”&#xff0c;例如能够拟合更多类型的数据。 让我们讨论一些防止过拟合的方法。 获…

代理模式(静态代理/动态代理)

代理模式&#xff08;Proxy Pattern&#xff09; 一 定义 为其他对象提供一种代理&#xff0c;以控制对这个对象的访问。 代理对象在客户端和目标对象之间起到了中介作用&#xff0c;起到保护或增强目标对象的作用。 属于结构型设计模式。 代理模式分为静态代理和动态代理。…