机器学习周报(12.2-12.8)

文章目录

    • 摘要
    • Abstract
  • Vision Transformer
    • 1 原理
    • 2 代码

摘要

本周学习了Vision Transformer (ViT) 的基本原理及其实现,并完成了基于PyTorch的模型训练、验证和预测任务。深入理解了ViT如何将图像分割成patch作为输入序列,并结合Transformer Encoder处理。通过迁移学习在花类数据集上训练模型,并验证了模型在预测任务中的优越性能。

Abstract

This week, I studied the fundamental principles and implementation of Vision Transformer (ViT) and completed model training, validation, and prediction tasks using PyTorch. I gained a deep understanding of how ViT splits an image into patches as input sequences and processes them using the Transformer Encoder. By leveraging transfer learning, I trained the model on a flower dataset and validated its superior performance in prediction tasks.

Vision Transformer

1 原理

  • 数据处理

我认为ViT的关键在于理解怎么将图片当作一个序列输入进模型之中。我们先看看ViT整体结构图,如下图所示

在这里插入图片描述
论文中提到将 224x224x3 的图像作为输入,将图像分为 16x16x3 大小的patch,也就是说将输入图像分为了 224 × 224 × 3 16 × 16 × 3 = 196 \frac{224×224×3}{16×16×3}=196 16×16×3224×224×3=196 个patch。其中每个patch拉直之后的维度为 16×16×3=768维,也就是Linear Projection of Flattened Patches层下面分割的小图像。

在具体实现中,使用卷积核大小为 16x16x3 、步距为16、卷积核个数为768的卷积层,就能将3维图像转换为Transformer所需要的输入token[组数,维度]。

  • 全连接层
    上述[196,768]的token将传入Linear Projection of Flattened Patches层,该层是 768x768 的全连接层,该层输出认为 196x768 。

  • 位置编码
    将经过全连接层后的输出进行位置编码,其位置编码和Transformer中的时序编码有异曲同工之妙,前者可以通过位置编码表示出token之间关于原输入图像的一些位置信息,后者可以表示输入先后的时序信息。
    该模型位置编码通过类似于坐标的形式表达,直接于输入相加,不改变维度大小。如下图所示:
    在这里插入图片描述

进行位置编码后,还需要加上一个特殊字符(最左输入0*),输入总组数从之前的196变为197,传入Transformer Encoder的token为[197,768]。

  • Transformer Encoder
    在这里插入图片描述
    ViT采用的是Transformer中编码器进行叠加,但其中的参数数量有所不同。
    经过位置编码和加入特殊字符的token[197,768]传入编码器,首先经过层归一化,再经过多头自注意力。这里的多头自注意力是采用12个头,也就是将768维分为12份,每份(Q、K、V)64维度,计算之后再进行合并为768维。

ViT中的编码器仍是采用残差连接,再经过一次层归一化后,就进入单个Transformer Encoder的最后一层MLP(多层感知机)。MLP将经过多头自注意力的输出维度升高4倍,即从768变为3072,最后再将维度降至768维

ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。

  • 输出
    ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。

在这里插入图片描述

最后,通过全连接层和softmax进行概率输出即可

2 代码

在理解完ViT的原理之后,我们来看看PyTorch代码如何实现。这里以ViT-base模型,输入图像 224x224x3,patch大小 16x16x3 为例

花类数据集:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzy

训练模型代码如下,需要自行更改数据集路径和权重路径。

import os
import math
import argparse
 
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
 
from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate
 
 
def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 
    if os.path.exists("../weights") is False:
        os.makedirs("../weights")
 
    tb_writer = SummaryWriter()
 
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
 
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
 
    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
 
    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])
 
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
 
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)
 
    model = create_model(num_classes=args.num_classes, has_logits=False).to(device)
 
    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        # 删除不需要的权重
        # del_keys = ['head.weight', 'head.bias'] if model.has_logits \
        #     else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
        del_keys = ['head.weight', 'head.bias']
        for k in del_keys:
            del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))
 
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head, pre_logits外,其他权重全部冻结
            if "head" not in name and "pre_logits" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))
 
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
 
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)
 
        scheduler.step()
 
        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
 
        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
 
        torch.save(model.state_dict(), "../weights/model-{}.pth".format(epoch))
 
 
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.01)
 
    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data_path', type=str, default='../data/flower_photos', help='path to dataset')
    parser.add_argument('--model-name', default='', help='create model name')
 
    # 预训练权重路径,如果不想载入就设置为空字符
    parser.add_argument('--weights', type=str, default='../weights/vit_base_patch16_224.pth', help='path to initial weights')
    # 是否冻结权重
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
 
    opt = parser.parse_args()
 
    main(opt)

训练结果如下:

在这里插入图片描述

因为是迁移学习的原因,只需要进行微调即可,所以9epoch之后准确率就达到97.9%了。

每训练一个epoch,就会将训练模型保存至weights文件夹,如下图所示

在这里插入图片描述

通过上述代码的训练之后,我们可以将保存的模型model-9.pth引入预测代码进行预测啦!需自行更改权重路径,以及需要测试的图片路径。

import os
import json
 
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
 
from vit_model import vit_base_patch16_224_in21k as create_model
 
 
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
    data_transform = transforms.Compose(
        [transforms.Resize(254),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
    # load image
    img_path = "../data/Image/flower.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    img2 = img
    plt.imshow(img)
    plt.show()
 
    img = img.convert('RGB')
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)  # [1, 3, 224, 224]
 
    # read class_indict
    json_path = 'class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
 
    with open(json_path, "r") as f:
        class_indict = json.load(f)
 
    # create model
    model = create_model(num_classes=5, has_logits=False).to(device)  # num_classes=5:表示模型将被训练来识别5个不同的类别;has_logits=False:模型不直接输出logits,在实际应用中,这通常意味着模型的输出层之后可能会跟随一个softmax激活函数
    # load model weights
    model_weight_path = "../weights/model-9.pth"  # 采用第10轮训练的参数
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
 
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))
    plt.imshow(img2)
    plt.show()
 
 
if __name__ == '__main__':
    main()

模型预测结果如下所示:
在这里插入图片描述

在这里插入图片描述

模型预测结果几乎100%为sunflowers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/933279.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Jmeter Address already in use: connect 解决

做压测接口时,并发一段时间后,会报java.net.BindException: Address already in use: connect 原因: windows提供给TCP/IP链接的端口为 1024-5000,并且要四分钟来循环回收它们,就导致在短时间内跑大量的请求时将端口占…

深入解析 HTML Input 元素:构建交互性表单的核心

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

力扣打卡10:K个一组翻转链表

链接:25. K 个一组翻转链表 - 力扣(LeetCode) 这道题需要在链表上,每k个为一组,翻转,链接。 乍一看好像比较容易,其实有很多细节。比如每一组反转后怎么找到上一组的新尾,怎么找到…

【银河麒麟操作系统真实案例分享】内存黑洞导致服务器卡死分析全过程

了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer.kylinos.cn 文档中心:https://documentkylinos.cn 现象描述 机房显示器连接服务器后黑屏&#xff…

Android显示系统(04)- OpenGL ES - Shader绘制三角形

Android显示系统(02)- OpenGL ES - 概述 Android显示系统(03)- OpenGL ES - GLSurfaceView的使用 Android显示系统(04)- OpenGL ES - Shader绘制三角形 Android显示系统(05)- OpenGL…

Ubuntu 22.04安装Nessus(离线激活模式)

Ubuntu 22.04安装Nessus 一、 Nessus 简介二、Nessus下载安装三、激活Nessus四、创建一个基础扫描五、 破解Nessus只能扫描16个地址的限制六、更新插件 一、 Nessus 简介 Nessus 官网: https://www.tenable.com/ Nessus号称世界上最流行的扫描程序,Nessu…

OpenAI 发布 o1 LLM,推出 ChatGPT Pro

OpenAI正式发布了专为复杂推理而构建的 OpenAI o1大型语言模型(LLM)。 该公司还推出了 ChatGPT Pro,这是一项每月 200 美元的套餐,包括无限制访问 OpenAI o1、o1-mini、GPT-4o 和高级语音对话。 OpenAI o1 从 9 月 12 日起在 ChatGPT 中推出预览版&…

上海理工大学《2024年867自动控制原理真题》 (完整版)

本文内容,全部选自自动化考研联盟的:《上海理工大学867自控考研资料》的真题篇。后续会持续更新更多学校,更多年份的真题,记得关注哦~ 目录 2024年真题 Part1:2024年完整版真题 2024年真题

汽配行业数字化解决方案(一)

汽配行业数字化解决方案,是通过整合云计算、大数据、人工智能、物联网等先进技术,构建一个全面、高效、智能的数字化生态系统,以实现汽配供应链的全程可视化与智能化管理。该解决方案涵盖了从供应商管理、库存优化、订单处理、物流跟踪到客户…

华为开源自研AI框架昇思MindSpore应用案例:基于MindSpore框架的SGD优化器案例实现

SGD优化器基本原理讲解 随机梯度下降(SGD)是一种迭代方法,其背后基本思想最早可以追溯到1950年代的Robbins-Monro算法,用于优化可微分目标函数。 它可以被视为梯度下降优化的随机近似,因为它用实际梯度(从…

集成学习综合教程

一、前置知识 一个分类器的分类准确率在60%-80%,即:比随机预测略好,但准确率却不太高,我们可以称之为 “弱分类器”,比如CART(classification and regression tree 分类与回归树)。 反之&#x…

大语言模型技术相关知识-笔记整理

系列文章目录 这个系列攒了很久。主要是前段之间面试大语言模型方面的实习(被拷打太多次了),然后每天根据面试官的问题进行扩展和补充的这个笔记。内容来源主要来自视频、个人理解以及官方文档中的记录。方便后面的回顾。 2024-12-7: 对公式…

【开源免费】基于SpringBoot+Vue.JS中小型医院网站(JAVA毕业设计)

博主说明:本文项目编号 T 078 ,文末自助获取源码 \color{red}{T078,文末自助获取源码} T078,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析…

哈希知识详解

一、哈希 以前,在面对海量数据的查找时,最快就是红黑树 logN,无法满足需求。 于是探索出了另一种用关键字 key 值与其存储位置建立映射的思想,让查找效率提升到 O(1) ,这个就是哈希。 二、哈希函数 1、直接定值法 ①…

红日靶场vulnstark 4靶机的测试报告[细节](一)

目录 一、测试环境 1、系统环境 2、注意事项 3、使用工具/软件 二、测试目的 三、操作过程 1、信息搜集 2、漏洞利用Getshell ①Struts 2 s2-045漏洞 手工利用s2-45漏洞 Msf综合利用 ②Tomcat框架(CVE-2017-12615) ③phpMyAdmin(CVE-2018-12613) 构造语句写入冰蝎木…

D92【python 接口自动化学习】- pytest基础用法

day92 pytest的skip和skipif用法 学习日期:20241208 学习目标:pytest基础用法 -- pytest的skip和skipif用法 学习笔记: 测试用例跳过 skip和skipif用法,测试用例跳过 pytest.mark.skip 跳过标记的用例 pytest.mark.skipif(1 …

【Java】类似王者荣耀游戏

r77683962/WangZheYouDianRongYao 运行效果图: 类似王者荣耀游戏运行效果图_哔哩哔哩_bilibili

【数字电路与逻辑设计】实验二 数值比较器

文章总览:YuanDaiMa2048博客文章总览 【数字电路与逻辑设计】实验二 数值比较器 一、实验内容二、设计过程(一)真值表(二)设计思路 三、源代码(一)代码说明:(二&#xff…

探索十个 AI 对话提示词网站,提升交互体验

learning prompt 网址:Hello from Learning Prompt | Learning Prompt 简介:这是一个学习 提示词 的网站,有 ChatGPT,Midjourney 的提示词教程、技巧等,他在右上角有中文语言的选择,教程非常详尽 LangCha…

C# Decimal

文章目录 前言1. Decimal 的基本特性2. 基本用法示例3. 特殊值与转换4. 数学运算示例5. 精度处理示例6. 比较操作示例7. 货币计算示例8. Decimal 的保留小数位数9. 处理 Decimal 的溢出和下溢10. 避免浮点数计算误差总结 前言 decimal 是 C# 中一种用于表示高精度十进制数的关键…