《pytorch深度学习实战》学习笔记第2章

第2章 预训练网络

讨论3种常用的预训练模型:

        1、根据内容对图像进行标记(识别)

        2、从真实图像中生成新图像(GAN)

        3、使用正确的英语句子来描述图像内容(自然语言)

2.1 获取一个预训练好的网络用于图像识别

ImageNet数据集,用于大规模视觉识别挑战赛。

所有预训练好的模型都在TorchVision中。

2.1.1 导入已有的模型

所有模型都在torchvison的models中。导入并查看。

from torchvision import models
dir(models)

输出的是所有torchvison里面集成的模型框架。其中首字母大写的是一些流行的模型小写的名字是快捷函数,返回实例化模型函数

1.1.1 AlexNet模型

实例化AlexNet。

alexnet=models.AlexNet()
alexnet

可以像函数一样调用它。给alexnet输入数据,就会通过正向传播(forward pass)得到输出。比如output=alexnet(input)。由于网络没有初始化,没有经过训练。所以一般先要将模型从头训练或者加载训练好的网络。然后再调用。

1.1.2 Resnet模型

(1)加载在ImageNet数据集上训练好的权重,来实例化ResNet101
resnet=models.resnet101(pretrained=True)
resnet

然后就开始下载,下载完成后查看resnet101的结构。

神经网络由许多模块构成,包含过滤器和非线性函数,fc层结束,输出每个类的分数。

预训练好的模型可以跟函数一样调用,并输入图片实现预测。

(2)定义预处理函数:
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])

预处理包括:图像缩放到256*256像素,围绕中心裁剪到224*224像素,转为张量,归一化处理,使用定义的均值和标准差。

(3)导入图片并进行预处理

导入一张狗的照片并显示

from PIL import Image
img = Image.open('bobby.jpg')
img

待用预处理函数对图片进行预处理。

img_t=preprocess(img)
img_t.shape

输出为一个3维的张量。

给张量前面再增加一个维度。

import torch
batch_t = torch.unsqueeze(img_t,0)#加一个维度,数字0代表增加在第0维前面,如果为1就代表维度1前面
batch_t.shape

输出在第0维前面增加了一个1.

(4)运行模型

在新数据上运行训练过的模型的过程被称为推理(inference),为了推理需要先将网络放到eval模式。执行代码:

resnet.eval()

进行推理:

out=resnet(batch_t)
out

产生了一个1000分类的向量,每个ImageNet对应一个分数。

(5)查看预测结果

加载定义好的ImageNet标签。

with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

需要找出out输出在labels标签中的索引。可以利用max()函数输出张量中最大值以及最大值的索引。代码如下:

_,index=torch.max(out,1)
index

输出的索引不是一个数字,而是一个一维张量。

使用index[0]获得实际的数字作为标签列表的索引,用torch.nn.functional.softmax()将输出归一化到[0,1],然后除以总和。可以求出模型在预测中的置信度。

代码:

percentage = torch.nn.functional.softmax(out,dim=1)[0]*100
labels[index[0]],percentage[index[0]].item()

输出:('golden retriever', 96.29335021972656)

分类结果维金毛犬,置信度为96%。

也可以对预测结果的其它值进行排序输出。比如输出前5个。

_,indices = torch.sort(out,descending=True)
[(labels[idx],percentage[idx].item()) for idx in indices[0][:5]]

输出:

2.2 一个足以以假乱真的预训练模型

GAN是生成式对抗网络(generative adversarial network)的缩写。

cycleGAN是循环生成式对抗网络的缩写,可以将一个领域的图像转换为另一个领域的图像。

2.2.1 将马变为斑马的网络

CycleGAN从ImageNet数据集中提取的马和斑马的数据集进行训练。该网络学习获取一匹或多匹马的图像,并将它们全部变成斑马,图像的其余部分尽可能不被修改。

使用预训练好的CycleGAN将使我们有机会更进一步了解网络是如何实现的,对于本例就是生成器。

(1)以ResNet为例,定义一个ResNetGenerator类。

import torch
import torch.nn as nn

class ResNetBlock(nn.Module): # <1>

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x) # <2>
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> 

        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input): # <3>
        return self.model(input)

(2)实例化

netG = ResNetGenerator()

权重为随机权重。

(3)将预训练好的权重添加到ReNet Generator中。

model_path='horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

执行后,netG就获得了训练中需要的所有知识。

(4)推理

netG.eval()

输出:

程序是将一匹或多匹马逐像素修改。

导入随机马的图像进行测试。

导入需要的库:

from PIL import Image
from torchvision import transforms

定义预处理函数:

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor()
])

导入马的图片:

img = Image.open('horse.jpg')
img

对图片预处理:

img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t,0)

将变量传递给模型:

batch_out = netG(batch_t)

将生成器的输出转换为图像。

out_t = (batch_out.data.squeeze()+1.0)/2.0
out_img = transforms.ToPILImage()(out_t)
out_img

2.6 练习题

1.将金毛猎犬的图像输入马-斑马模型中。

参考资料:

1. 预训练网络 · 深度学习与PyTorch(中文版) (paper2fox.github.io)

4. PyTorch深度学习 Deep Learning with PyTorch ch.2, p2_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/517393.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c++宏有什么离谱操作?

Boost.Preprocessor确实是一个非常强大而复杂的C宏库&#xff0c;专门用于元编程&#xff0c;即在编译时进行代码生成和变换。我这里有一套编程入门教程&#xff0c;不仅包含了详细的视频讲解&#xff0c;项目实战。如果你渴望学习编程不妨点个关注&#xff0c;给个评论222&…

WordPress建站教程:10步快速搭建个人网站

WordPress是一个广泛使用的内容管理系统&#xff08;CMS&#xff09;&#xff0c;凭借其用户友好的界面和大量可定制的主题和插件&#xff0c;为WordPress 提供了多功能性和灵活性&#xff0c;可用于创建各种类型的网站&#xff0c;包括个人博客、B2B企业网站、B2C外贸网站等&a…

骨架屏:提升用户体验的巧妙技巧

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

CleanMyMac2024拥有优化Mac电脑所需的所有工具,深层清理、恶意软件移除、用CleanMyMac X就够了!

CleanMyMac X 拥有优化 Mac 电脑所需的所有工具&#xff1a;深层清理、恶意软件移除、性能提升…… 优化 Mac 电脑&#xff0c;用 CleanMyMac X 就够了&#xff01; CleanMyMac作为一款专为Mac用户设计的系统清理和优化软件&#xff0c;通过其独特的功能和特性&#xff0c;帮助…

Linux集群(二)集群搭建与keeplived配置

目录 一、集群搭建 1.复制3份Tomcat 2.测试Tomcat 3. 配置nginx.conf配置文件 二、keeplived 1.什么是keeplived 2.keeplived特点 3.下载和安装 3.1下载 3.2安装 3.3配置keeplived.conf配置文件 3.4测试 一、集群搭建 集群的主要目的是解决并发的问题。 1.复制3…

【QT+QGIS跨平台编译】056:【pdal_lazperf+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

点击查看专栏目录 文章目录 一、pdal_lazperf介绍二、pdal下载三、文件分析四、pro文件五、编译实践一、pdal_lazperf介绍 pdal_lazperf 是 PDAL(Point Data Abstraction Library)的一个插件,用于处理点云数据。PDAL 是一个开源的库,用于处理和分析地理空间数据,特别是点云…

详解TCP/IP五层模型

目录 一、什么是TCP五层模型&#xff1f; 二、TCP五层模型的详细内容 1. 应用层 2. 传输层 3. 网络层 4. 数据链路层 5. 物理层 三、网络设备所在分层 封装和分⽤ 三、Java示例 引言&#xff1a; 在网络通信中&#xff0c;TCP/IP协议是至关重要的。为了更好地理解TCP协议的工…

SD-WAN组网,协助企业业务需求灵活调整带宽和网络资源

随着企业网络需求的不断增长和变化&#xff0c;传统的网络架构已经不能满足企业灵活调整带宽和网络资源的需求。而SD-WAN&#xff08;软件定义广域网&#xff09;作为一种新型的网络架构&#xff0c;可以帮助企业实现灵活调整带宽和网络资源&#xff0c;满足不同业务需求。本文…

Cisco Modeling Labs (CML) 2.7 - 网络仿真工具

Cisco Modeling Labs (CML) 2.7 - 网络仿真工具 思科建模实验室 (CML) 请访问原文链接&#xff1a;https://sysin.org/blog/cisco-modeling-labs-2/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 Cisco Modeling Labs 是我们用于网络仿真的首要平台。凭借易…

Delphi编写的图片查看器

UNIT Unit17;INTERFACEUSESWinapi.Windows, Winapi.Messages, System.SysUtils, System.Variants,System.Classes, Vcl.Graphics, Vcl.Controls, Vcl.Forms, Vcl.Dialogs,Vcl.StdCtrls, Vcl.ExtDlgs, Vcl.ExtCtrls, Vcl.Imaging.jpeg; //注意&#xff1a;要加入jpej 否侧浏览图…

Listener

简介 概念&#xff1a;Listeener表示监听器&#xff0c;是JavaWeb三大组件&#xff08;Servlet&#xff0c;Filter&#xff0c;Listener&#xff09;之一 监听器就是可以在application&#xff0c;session&#xff0c;request三个对象创建&#xff0c;销毁或者往其中添加修改…

Python语言在地球科学领域中的应用

Python是功能强大、免费、开源&#xff0c;实现面向对象的编程语言&#xff0c;Python能够运行在Linux、Windows、Macintosh、AIX操作系统上及不同平台&#xff08;x86和arm&#xff09;&#xff0c;Python简洁的语法和对动态输入的支持&#xff0c;再加上解释性语言的本质&…

定义一棵松弛红黑树及其根结点颜色转换后的影响

定义一棵松弛红黑树及其根结点颜色转换后的影响 1. 红黑树的性质2. 松弛红黑树的定义3. 根节点颜色变化的影响4. 伪代码实现5. C语言代码实现6. 结论 在计算机科学中&#xff0c;红黑树是一种自平衡的二叉搜索树&#xff0c;它在许多数据结构和算法问题中都有着广泛的应用。红黑…

LangChain Demo | Agent X ReAct X wikipedia 询问《三体》的主要内容

背景 LangChain学习中&#xff0c;尝试改了一下哈里森和吴恩达课程当中的问题&#xff0c;看看gpt-3.5-turbo在集成了ReAct和wikipedia后&#xff0c;如何回答《三体》的主要内容是什么这个问题&#xff0c;当然&#xff0c;主要是为了回答这问题时LangChain内部发生了什么。所…

DFS:深搜+回溯+剪枝解决矩阵搜索问题

创作不易&#xff0c;感谢三连&#xff01;&#xff01; 一、N皇后 . - 力扣&#xff08;LeetCode&#xff09; class Solution { public:vector<vector<string>> ret;vector<string> path;bool checkcol[9];bool checkdig1[18];bool checkdig2[18];int n…

LabVIEW电动汽车供电设备接触电流测试

LabVIEW电动汽车供电设备接触电流测试 随着电动汽车技术的迅猛发展和普及率的不断提高&#xff0c;电动汽车供电设施的电气安全显得尤为重要。为了优化电动汽车供电设备接触电流的测试方案&#xff0c;设计了一种基于LabVIEW的测试方案&#xff0c;通过平台校准测试和电动汽车…

Stable diffusion 加载扩展列表报错解决方法

项目场景&#xff1a; 在使用Stable diffusion webui时&#xff0c;使用扩展列表出现错误 问题描述 点击loadfrom后&#xff0c;出现加载扩展列表报错 原因分析&#xff1a; 下载的扩展的时候&#xff0c;都是github 的url&#xff0c;需要科学上网&#xff0c;如果不能科学…

P6维护:Oracle P6服务性能优化

前言 本文将介绍如何对ORACLE Primavera P6 EPPM软件进行性能调优&#xff0c;考虑到P6主要采用JAVA语言编制&#xff0c;且其使用的是Weblogic Server应用服务器部署P6各项服务器&#xff0c;其性能优化的原理便是基于其JVM特征参数进行设置 方法一&#xff1a;修改配置文件…

探索前端架构:MVC、MVVM和MVP模式

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

前端三剑客 —— CSS (第六节)

目录 内容回顾&#xff1a; 弹性布局属性介绍 案例演示 商品案例 布局分析 登录案例 网格布局 内容回顾&#xff1a; 变量&#xff1a;定义变量使用 --名称&#xff1a;值&#xff1b; 使用变量&#xff1a; 属性名&#xff1a;var&#xff08;--名称&#xff09;&a…