霹雳吧啦Wz《pytorch图像分类》-p3VGG网络

《pytorch图像分类》p3VGG网络详解及感受野的计算

  • 一、零碎知识点
    • 1.nn.Sequential
    • 2.**kwargs
  • 二、VGG网络模型详解
    • 1.感受野
    • 2.模型手算
  • 三、代码
    • 1.module.py
    • 2.train.py
    • 3.predict.py

一、零碎知识点

论文连接:VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION
代码链接:霹雳吧啦Wzdeep-learning-for-image-processing

1.nn.Sequential

nn.Sequential是PyTorch中的一个类,用于按顺序组织和堆叠神经网络的层或模块。它提供了一种便捷的方式来构建简单的前向传播网络。

import torch
import torch.nn as nn

model = nn.Sequential(
in_channels,out_channels,kernel_size
    nn.Conv2d(in_channels,out_channels,kernel_size)
    nn.ReLU(),                                # 添加激活函数
    nn.Linear(hidden_features, out_features)  # 添加线性层
)

2.**kwargs

**kwargs是一个特殊的参数传递方式,它允许函数接受不定数量的关键字参数(Keyword Arguments)并将它们作为一个字典进行处理。

下面是一个简单的示例说明**kwargs的用法:

def example_func(**kwargs):
    for key, value in kwargs.items():
        print(key, value)

example_func(name='Maverick', age=22, location='cheng du')

输出结果:

name Maverick
age 22
location cheng du

二、VGG网络模型详解

1.感受野

感受野(receptive field)是指在卷积神经网络(CNN)中的某一层输出特征图上的像素位置所对应的输入图像上的区域大小。
随着卷积核的增多(即网络的加深),感受野会越来越大。
在这里插入图片描述
当我们说一个神经网络层的感受野大小为N时,可以简单解释为:在该层输出特征图上的一个像素点,它所"看到"的输入图像区域大小是N×N。
随着网络的层数增加,感受野也会逐渐增大。最早的卷积层(例如卷积核为3x3)的感受野较小,但后续的层会通过池化或步幅更大的卷积来逐渐增加感受野的大小。

在这里插入图片描述

2.模型手算

VGG网络的常用配置是D,有16个层(包括13个卷积层和3个全连接层)

LRN是一种对神经网络中的特征图进行局部归一化的操作。其目的是增加网络的鲁棒性,防止某些特征具有过大的响应值而抑制其他特征的重要性。
具有鲁棒性的模型能够在输入数据中存在一定程度的扰动、噪声或异常情况下仍然保持良好的性能。
在这里插入图片描述
反复记忆:输出的特征矩阵的深度out_channels和卷积核的个数相同
因为彩色图形有rgb三个通道,所以最开始的特征矩阵深度为3
后面都是根据卷积核个数的不同产生不同的改变。
在这里插入图片描述

三、代码

1.module.py

import torch.nn as nn
import torch

# official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model

2.train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 2
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

用的是老师的代码,我的gpu内存不够,我已经将批处理大小(batch size)减少到2了,还是运行不起来
CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 2.00 GiB total capacity; 718.01 MiB already allocated; 341.00 MiB free; 740.00 MiB reserved in total by PyTorch)

3.predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)
    
    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/287323.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

leetcode2487.从链表中移除节点

文章目录 题目思路复杂度Code 题目 给你一个链表的头节点 head 。 移除每个右侧有一个更大数值的节点。 返回修改后链表的头节点 head 。 示例 1: 输入:head [5,2,13,3,8] 输出:[13,8] 解释:需要移除的节点是 5 &#xff0…

Spring Security 6.x 系列(13)—— 会话管理及源码分析

一、会话概念 在实现会话管理之前,我们还是先来了解一下协议和会话的概念,连协议和会话都不知道是啥,还谈啥管理。 1.1 http 协议 因为我们现在的会话,基本上都是基于HTTP协议的,所以在讲解会话之前,我再…

Nginx(十三) 配置文件详解 - 反向代理(超详细)

本篇文章主要讲ngx_http_proxy_module和ngx_stream_proxy_module模块下各指令的使用方法。 1. 代理请求 proxy_pass 1.1 proxy_pass 代理请求 Syntax: proxy_pass URL; Default: — Context: location, if in location, limit_except 设置代理服务器的协议和地址以…

【模拟电路】模拟集成电路之神-NE555

一、集成电路NE555简介 二、功能框图与引脚说明 三、比较器(运放) 四、反相门(非门) 五、或非门 六、双稳态触发器 七、NE555的工作原理 集成电路NE555的芯片手册 C5157696 一、集成电路NE555简介 NE555起源于上个世纪70年代&a…

element-ui Tree 树形控件 过滤保留子级并获取过滤后的数据

本示例基于vue2 element-ui element-ui 的官网demo是只保留到过滤值一级的&#xff0c;并不会保留其子级 目标 1、Tree 树形控件 保留过滤值的子级 2、在第一次过滤数据的基础上进行第二次过滤 先看效果 Tree 树形控件 保留过滤值的子级 <el-treeclass"filter-t…

宝藏推荐:其实,这些工具一点不输Intercom-

在当今快节奏的商业环境中&#xff0c;提供优质的客户服务和支持是每个企业都追求的目标。为了实现这一目标&#xff0c;许多企业都在寻找可以帮助他们提供高效和个性化客户支持的工具。Intercom是一个非常受欢迎的客户支持工具&#xff0c;但它并不是唯一的选择。在本文中&…

玩转贝启科技BQ3588C开源鸿蒙系统开发板 —— 编译构建及此过程中的踩坑填坑(1)

接前一篇文章&#xff1a;玩转贝启科技BQ3588C开源鸿蒙系统开发板 —— 代码下载&#xff08;2&#xff09; 本文主要参考&#xff1a; BQ3588C_代码下载 上一回完成了代码下载&#xff0c;本回开始进行编译构建。 1. 编译构建 &#xff08;1&#xff09;执行prebuilts 在源…

用 MATLAB 产生单位抽样序列、单位阶跃序列、矩形序列、正弦序列和复指数序列

%% 单位抽样&#xff08;脉冲&#xff09;序列&#xff08;冲激函数&#xff09; % 参数设置 n -10:10; % 定义时间范围 delta (n 0); % 生成单位抽样序列% 绘图 figure; stem(n, delta); title(单位抽样序列); xlabel(n); ylabel(delta[n]);%% 单位阶跃序列 % 参数设置 n …

LanChatRoom局域网聊天室

CPP已经结课&#xff0c;我提交的项目是Qt的入门项目&#xff0c;局域网聊天室LanChatRoom。 这个代码重构了很多遍。第一遍是照着明哥推荐到书&#xff0c;把代码抄了一遍。 但抄下来之后&#xff0c;各种问题&#xff0c;而且是清朝老代码。抄了一遍之后&#xff0c;对代码的…

深度学习 Day23——J3DenseNet算法实战与解析

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制&#x1f680; 文章来源&#xff1a;K同学的学习圈子 文章目录 前言1 我的环境2 pytorch实现DenseNet算法2.1 前期准备2.1.1 引入库2.1.2 设…

JavaScript中alert、prompt 和 confirm区别及使用【通俗易懂】

✨前言✨   本篇文章主要在于&#xff0c;让我们看几个与用户交互的函数&#xff1a;alert&#xff0c;prompt 和confirm的使用及区别 &#x1f352;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f4dd;私信必回哟&#x1f601; &#x1f352;博主将持续更新学习记录收获&…

HarmonyOS页面和自定义组件生命周期

页面和自定义组件生命周期 在开始之前&#xff0c;我们先明确自定义组件和页面的关系&#xff1a; 自定义组件&#xff1a;Component装饰的UI单元&#xff0c;可以组合多个系统组件实现UI的复用。页面&#xff1a;即应用的UI页面。可以由一个或者多个自定义组件组成&#xff…

【算法每日一练]-数论 (保姆级教程 篇2 )#行列式 #甜甜花研究 #约数个数 #模数 #数树 #盒子与球

目录 今日知识点&#xff1a; 辗转相减法化下三角求行列式 组合数动态规划打表 约数个数等于质因数的次方1的乘积 求一个模数 将n个不同的球放入r个不同的盒子&#xff1a;f[i][j]f[i-1][j-1]f[i-1][j]*j 行列式 甜甜花的研究 约数个数 模数 数树 盒子与球 行列…

LinkedList与ArrayList的比较

1.LinkedList 基于双向链表&#xff0c;无需连续内存 随机访问慢&#xff08;要沿着链表遍历&#xff09; 头尾插入删除性能高 占用内存多 2.ArrayList 基于数组&#xff0c;需要连续内存 随机访问快&#xff08;指根据下标访问&#xff09; 尾部插入、删除性能可以&…

西门子宣布SIMATIC S-300停产?找找理想替代品——钡铼BL30x系列工控机

1994年&#xff0c;西门子发布了S7系列的第一批产品&#xff0c;其中包括S7-300。SIMATIC S7的推出也见证了新的现场总线标准Profibus的发布&#xff0c;以及率先使用工业总线来促进自动化设备之间的通信。S7-300 CPU系列的巨大成功也帮助西门子进一步巩固了其全球自动化技术领…

格局初现:京东阿里都瞄准了这个万亿级的大市场

核 心 要 点 ▪ 企业采购有哪些痛点和解决方案&#xff1f;行业的关键赛点是什么&#xff1f; ▪ 现行格局是何情况&#xff1f;代表性玩家各自有何特点&#xff1f; ▪ 未来企业采购将往何处去&#xff1f; 当这样一组数据摆在眼前的时候&#xff0c;你或许会感到难以置…

Gitee

Gitee码云 0. 笔记说明1. Gitee概述2. Gitee和GitHub3. 创建Git远程仓库4. 分享已有项目到Gitee5. 文件恢复和合并6. 文件push或pull冲突7. 添加项目成员 0. 笔记说明 该笔记以IDEA 2023专业版进行操作需提前注册好个人gitee账号安装好IDEA的相关gitee插件或者安装Git Bash软件…

C语言:二分查找查找有序数组中的元素

前言 在我们学习C语言的过程中&#xff0c;如果要查找一个数组当中是否存在某一个元素&#xff0c;我们可能会遍历整个数组&#xff0c;来依次判断是否存在这个函数&#xff0c;但这么做是效率极低的&#xff0c;如果数组中有很多个元素&#xff0c;那么我们要查找半天 二分查…

【CASS精品教程】CASS11计算城镇建筑密度

CASS中可以很方便计算建筑密度。 文章目录 一、建筑密度介绍二、CASS计算建筑密度1. 绘制宗地范围2. 绘制建筑物3. 计算建筑密度三、注意事项一、建筑密度介绍 建筑密度(building density;building coverage ratio),指在一定范围内,建筑物的基底面积总和与占用地面积的比…

CTFshow web入门web128-php特性31

开启环境: 一个新的姿势&#xff0c;当php扩展目录下有php_gettext.dll时&#xff1a; _()是一个函数。 _()gettext() 是gettext()的拓展函数&#xff0c;开启text扩展get_defined_vars — 返回由所有已定义变量所组成的数组。 call_user_func — 把第一个参数作为回调函数调…