VGG(pytorch)

VGG:达到了传统串行结构深度的极限

学习VGG原理要了解CNN感受野的基础知识

model.py

import torch.nn as nn
import torch

# Official torchvision pretrained-weight URLs, keyed by model name.
# The keys match the keys of the `cfgs` dict below, so a pretrained
# checkpoint can be downloaded for any configuration built by `vgg()`.
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    """VGG image-classification network (Simonyan & Zisserman, 2014).

    Args:
        features: the convolutional feature extractor (built by
            ``make_features``); for a 224x224 RGB input it must produce a
            ``N x 512 x 7 x 7`` tensor, which the classifier head expects.
        num_classes: number of output classes. Defaults to 1000 (ImageNet).
        init_weights: when True, re-initialize all conv/linear weights.
            Leave False when pretrained weights will be loaded afterwards.
    """

    def __init__(self, features, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        # Fully-connected head from the original paper:
        # FC-4096 -> FC-4096 -> FC-num_classes, with dropout for regularization.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224  ->  N x 512 x 7 x 7
        x = self.features(x)
        # Flatten everything except the batch dimension: N x (512*7*7)
        x = torch.flatten(x, start_dim=1)
        # N x num_classes (raw logits; no softmax — CrossEntropyLoss applies it)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Xavier-uniform weights and zero biases for every conv/linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list, batch_norm: bool = False):
    """Build the VGG convolutional feature extractor from a config list.

    Each entry in ``cfg`` is either an integer (output channels of a 3x3
    conv, stride 1, padding 1, followed by ReLU) or the string ``"M"``
    (a 2x2 max-pool that halves the spatial resolution).

    Args:
        cfg: layer configuration, e.g. ``cfgs['vgg16']``.
        batch_norm: if True, insert ``nn.BatchNorm2d`` between each conv
            and its ReLU (as in torchvision's ``vgg*_bn`` variants).
            Defaults to False, preserving the original behavior.

    Returns:
        An ``nn.Sequential`` chaining the layers in order.
    """
    layers = []
    in_channels = 3  # RGB input
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(v))
            layers.append(nn.ReLU(True))
            in_channels = v
    return nn.Sequential(*layers)

# Per-variant layer configurations: integers are conv output channels,
# 'M' marks a 2x2 max-pool. Every variant ends with a 512-channel 7x7
# feature map for 224x224 input, matching the classifier head in ``VGG``.
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    """Construct a VGG model by configuration name.

    Args:
        model_name: one of the keys of ``cfgs``
            ('vgg11', 'vgg13', 'vgg16', 'vgg19').
        **kwargs: extra keyword arguments forwarded to ``VGG.__init__``
            (e.g. ``num_classes``, ``init_weights``).

    Returns:
        The constructed ``VGG`` model.

    Raises:
        ValueError: if ``model_name`` is not a known configuration.
    """
    if model_name not in cfgs:
        # Raise instead of assert: asserts are stripped under `python -O`,
        # and the old message wrongly said "model number".
        raise ValueError("model name '{}' not in cfgs dict!".format(model_name))
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs)
    return model

train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import vgg


def main():
    """Train VGG-16 from scratch on the 5-class flower dataset, validating
    each epoch and saving the best-accuracy checkpoint to ./vgg16Net.pth."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Train-time augmentation (random crop + flip) vs. deterministic
    # resize-only preprocessing for validation.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # e.g. {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    # Invert to {index: class_name} so predict.py can map logits back to names.
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # DataLoader that yields shuffled mini-batches for training.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    model_name = "vgg16"
    # Training from scratch (no pretrained weights), hence init_weights=True.
    net = vgg(model_name=model_name, num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train — enables dropout in the classifier head
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            # NOTE(review): `loss` here is a tensor; formatting works, but
            # `loss.item()` would be the conventional scalar.
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate — disables dropout; no gradients needed
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # argmax over the class dimension -> predicted class index
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()

这里由于训练时间太长,运行了19个epoch中断。结果如下

using cuda:0 device.
Using 8 dataloader workers every process
using 3306 images for training, 364 images for validation.
train epoch[1/30] loss:1.542: 100%|██████████| 104/104 [08:39<00:00,  4.99s/it]
100%|██████████| 12/12 [01:13<00:00,  6.15s/it]
[epoch 1] train_loss: 1.605  val_accuracy: 0.245
train epoch[2/30] loss:1.399: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.12s/it]
[epoch 2] train_loss: 1.476  val_accuracy: 0.401
train epoch[3/30] loss:1.310: 100%|██████████| 104/104 [08:34<00:00,  4.94s/it]
100%|██████████| 12/12 [01:18<00:00,  6.53s/it]
[epoch 3] train_loss: 1.293  val_accuracy: 0.456
train epoch[4/30] loss:0.958: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.11s/it]
[epoch 4] train_loss: 1.185  val_accuracy: 0.519
train epoch[5/30] loss:1.327: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.11s/it]
[epoch 5] train_loss: 1.135  val_accuracy: 0.527
train epoch[6/30] loss:1.209: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.12s/it]
[epoch 6] train_loss: 1.077  val_accuracy: 0.571
train epoch[7/30] loss:0.725: 100%|██████████| 104/104 [1:25:27<00:00, 49.30s/it]
100%|██████████| 12/12 [01:21<00:00,  6.82s/it]
[epoch 7] train_loss: 1.051  val_accuracy: 0.596
train epoch[8/30] loss:1.146: 100%|██████████| 104/104 [08:50<00:00,  5.10s/it]
100%|██████████| 12/12 [01:27<00:00,  7.31s/it]
[epoch 8] train_loss: 1.008  val_accuracy: 0.615
train epoch[9/30] loss:1.381: 100%|██████████| 104/104 [08:48<00:00,  5.08s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 9] train_loss: 0.995  val_accuracy: 0.640
train epoch[10/30] loss:0.466: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 10] train_loss: 0.966  val_accuracy: 0.673
train epoch[11/30] loss:0.867: 100%|██████████| 104/104 [08:33<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.13s/it]
[epoch 11] train_loss: 0.926  val_accuracy: 0.659
train epoch[12/30] loss:0.804: 100%|██████████| 104/104 [08:34<00:00,  4.94s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 12] train_loss: 0.916  val_accuracy: 0.665
train epoch[13/30] loss:0.377: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 13] train_loss: 0.879  val_accuracy: 0.648
train epoch[14/30] loss:0.588: 100%|██████████| 104/104 [08:35<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.16s/it]
[epoch 14] train_loss: 0.841  val_accuracy: 0.676
train epoch[15/30] loss:0.725: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.13s/it]
[epoch 15] train_loss: 0.830  val_accuracy: 0.687
train epoch[16/30] loss:0.977: 100%|██████████| 104/104 [08:35<00:00,  4.96s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 16] train_loss: 0.811  val_accuracy: 0.720
train epoch[17/30] loss:0.923: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.14s/it]
[epoch 17] train_loss: 0.796  val_accuracy: 0.703
train epoch[18/30] loss:1.150: 100%|██████████| 104/104 [08:34<00:00,  4.95s/it]
100%|██████████| 12/12 [01:13<00:00,  6.15s/it]
[epoch 18] train_loss: 0.794  val_accuracy: 0.720
train epoch[19/30] loss:0.866:  19%|█▉        | 20/104 [01:54<07:59,  5.71s/it]

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet


def main():
    """Load the VGG-16 checkpoint trained by train.py and predict one image's class."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Must match the "val" transform used during training.
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read the {index: class_name} mapping written by train.py
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model — the SAME architecture train.py was run with.
    # (Fixes a copy-paste bug: this file previously built a GoogLeNet and
    # loaded './googleNet.pth', which does not match the saved VGG weights.
    # Local import keeps the fix self-contained.)
    from model import vgg
    model = vgg(model_name="vgg16", num_classes=5).to(device)

    # load model weights saved by train.py (save_path = './vgg16Net.pth').
    # Strict loading so any architecture/checkpoint mismatch fails loudly
    # instead of being silently skipped.
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class: squeeze the batch dim, softmax over classes
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


# Script entry point: run prediction only when executed directly, not on import.
if __name__ == '__main__':
    main()

预测结果

class: daisy        prob: 0.00207
class: dandelion    prob: 0.00144
class: roses        prob: 0.101
class: sunflowers   prob: 0.00535
class: tulips       prob: 0.89

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/251469.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

WEB 3D技术 简述React Hook/Class 组件中使用three.js方式

之前 已经讲过了 用vue结合three.js进行开发 那么 自然是少不了react 我们 还是先创建一个文件夹 终端执行 npm init vitelatest输入一下项目名称 然后技术选择 react 也不太清楚大家的基础 那就选择最简单的js 然后 我们就创建完成了 然后 我们用编辑器打开创建好的项目目…

卷积的计算 - im2col 2

卷积的计算 - im2col 2 flyfish import numpy as np np.set_printoptions(linewidth200)# F filter kerneldef im2col(images, kernel_size, stride1, padding0):#process imagesif images.ndim 2:images images.reshape(1, 1, *images.shape)elif images.ndim 3:N, I_…

RK3568平台开发系列讲解(Linux系统篇)如何优化Linux驱动的稳定性和效率

🚀返回专栏总目录 文章目录 一、检测 ioctl 命令二、检测传递地址是否合理三、分支预测优化沉淀、分享、成长,让自己和他人都能有所收获!😄 📢在 Linux 中应用程序运行在用户空间,应用程序错误之后,并不会影响其他程序的运行,而驱动工作在内核层,是内核代码的一部分…

响应者链概述

响应者链 iOS事件的3大类型 Touch Events(触摸事件)Motion Events(运动事件&#xff0c;比如重力感应和摇一摇等)Remote Events(远程事件&#xff0c;比如用耳机上得按键来控制手机) 触摸事件 处理触摸事件的两个步骤 寻找事件的最佳响应者事件的响应在响应链中的传递 寻…

记录 | gpu docker启动报错libnvidia-ml.so.1: file exists: unknown

困扰了两天的问题&#xff0c;记录一下 问题出在启动一个本身已经安装 cuda 的镜像上&#xff0c;具体来说&#xff0c;我是启动地平线天工开物工具链镜像的时候出现的问题&#xff0c;具体报错如下&#xff1a; docker: Error response from daemon: failed to create task …

System作为系统进程该如何关闭?

一、简介 system进程是不可以关闭的&#xff0c;它是用来运行一些系统命令的&#xff0c;比如reboot、shutdown等&#xff0c;以及用来运行一些后台程序&#xff0c;比如ntfs-3g、v4l2loopback等。system进程也被用于运行一些内核模块&#xff0c;比如nvidia、atd等。system进程…

结构体基础全家桶(1)创建与初始化

目录 结构体概念&#xff1a; 结构体类型&#xff1a; 结构体变量的创建&#xff1a; 定义结构体变量的三种方式&#xff1a; 结构体变量的引用&#xff1a; 结构体变量的初始化: 结构体数组&#xff1a; 结构体数组定义&#xff1a; 结构体数组初始化&#xff1a; 结…

AlexNet(pytorch)

AlexNet是2012年ISLVRC 2012&#xff08;ImageNet Large Scale Visual Recognition Challenge&#xff09;竞赛的冠军网络&#xff0c;分类准确率由传统的 70%提升到 80% 该网络的亮点在于&#xff1a; &#xff08;1&#xff09;首次利用 GPU 进行网络加速训练。 &#xff…

介绍strncpy函数

strncpy函数需要引用#include <string.h>头文件 函数原型&#xff1a; char *_Dest 是字符串的去向 char *_Source是字符串的来源 size_t_Count是复制字符串的大小 #include <stdio.h> #include <string.h> int main() { char arr[128] { \0 }; …

数据结构之排序

目录 ​ 1.常见的排序算法 2.插入排序 直接插入排序 希尔排序 3.交换排序 冒泡排序 快速排序 hoare版本 挖坑法 前后指针法 非递归实现 4.选择排序 直接选择排序 堆排序 5.归并排序 6.排序总结 一起去&#xff0c;更远的远方 1.常见的排序算法 排序&#xff1a;所…

积分球均匀光源遥感器如何保持稳定

积分球的亮度均匀性取决于其内部涂层的反射率和分布情况。当光源通过积分球时&#xff0c;光会被内部的涂层多次反射&#xff0c;最终从出光口均匀地散射出去。为了提高亮度均匀性&#xff0c;可以采用具有高反射率和均匀分布的光源&#xff0c;同时选择合适的涂层材料和涂层厚…

NXP应用随记(五):eMios功能点阅读随记

目录 1、概念点 2、eMios功能点 2.1、eMIOS - Single Action Input Capture (SAIC) 2.2、eMIOS - Single Action Output Compare (SAOC) 2.3、eMIOS - Double Action Output Compare (DAOC) 2.4、eMIOS - Pulse/Edge Counting (PEC) – Single Shot 2.5、eMIOS - Pulse/E…

算法:程序员的数学读书笔记

目录 ​0的故事 ​一、按位计数法 二、不使用按位计数法的罗马数字 三、十进制转二进制​​​​​​​ ​四、0所起到的作用​​​​​​​ 逻辑 一、为何逻辑如此重要 二、兼顾完整性和排他性 三、逻辑 四、德摩根定律 五、真值表 六、文氏图 七、卡诺图 八、逻…

【算法Hot100系列】最长回文子串

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

向华为学习:基于BLM模型的战略规划研讨会实操的详细说明,含研讨表单(二)

上一篇文章&#xff0c;华研荟结合自己的经验和实践&#xff0c;详细介绍了基于BLM模型的战略规划研讨会的设计和组织流程&#xff0c;提高效率的做法。有朋友和我私信沟通说&#xff0c;其实这个流程不单单适合于BLM模型的战略规划研讨会&#xff0c;实际上&#xff0c;使用其…

Linux centos7安装redis 6.2.14 gz并且使用systemctl为开机自启动 / 彻底删除 redis

1.下载 && 减压 wget http://download.redis.io/releases/redis-6.2.14.tar.gz tar -zvxf redis-6.2.14.tar.gz 2.编译&#xff08;分开运行&#xff09; cd redis-6.2.14 make cd src make install 安装目录展示 3.redis.conf 配置更改 daemonize yes supervised s…

【STM32入门】4.1中断基本知识

1.中断概览 在开展红外传感器遮挡计次的实验之前&#xff0c;有必要系统性的了解“中断”的基本知识. 中断是指&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转…

交友网站的设计与实现(源码+数据库+论文+开题报告+说明文档)

项目描述 临近学期结束&#xff0c;还是毕业设计&#xff0c;你还在做java程序网络编程&#xff0c;期末作业&#xff0c;老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下&#xff0c;你想解决的问…

听力健康“吃”出来

大多数的研究报告都指出&#xff0c;听力下降的最常见原因是年龄和噪音暴露。然而&#xff0c;近年来越来越多的文章开始探讨其他因素对听力的影响。食物不仅是维持人类基本生存的必需品&#xff0c;随着营养学的进步&#xff0c;人们也逐渐认识到食物中的营养与保持健康之间存…

Unity中Shader URP 简介

文章目录 前言一、URP&#xff08;Universal Render Pipeline&#xff09;由名字可知&#xff0c;这是一个 通用的 渲染管线1、Universal&#xff08;通用性&#xff09;2、URP的由来 二、Build-in Render Pipeline&#xff08;内置渲染管线&#xff09;1、LWRP&#xff08;Lig…