机器学习——卷积神经网络

卷积神经网络CNN

多层感知机MLP的层数足够,理论上可以用其提取出二位特征,但是毕竟复杂,卷积神经网络就可以更合适的来提取高维的特征。
而卷积其实是一种运算
在这里插入图片描述
二维离散卷积的公式
在这里插入图片描述
可以看成g是一个图像的像素点,f是每个像素点对应的权重,权重越大,重要程度越大,这里的权重f可以根据梯度反向传播的方式训练
在CNN中进行卷积运算的层称为卷积层,层中的权重f被称为卷积核
如果将f进行翻转,得到的参数在位置上是翻转的,对参数数值没有影响。这样的运算称为互相关。

卷积的运算例子

在这里插入图片描述

用卷积神经网络完成图像分类任务

class CNN(nn.Module):

    def __init__(self, num_classes=10):
        super().__init__()
        # 类别数目
        self.num_classes = num_classes
        # Conv2D为二维卷积层,参数依次为
        # in_channels:输入通道
        # out_channels:输出通道,即卷积核个数
        # kernel_size:卷积核大小,默认为正方形
        # padding:填充层数,padding=1表示对输入四周各填充一层,默认填充0
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, 
            kernel_size=3, padding=1)
        # 第二层卷积,输入通道与上一层的输出通道保持一致
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        # 最大池化,kernel_size表示窗口大小,默认为正方形
        self.pooling1 = nn.MaxPool2d(kernel_size=2)
        # 丢弃层,p表示每个位置被置为0的概率
        # 随机丢弃只在训练时开启,在测试时应当关闭
        self.dropout1 = nn.Dropout(p=0.25)
        
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.pooling2 = nn.MaxPool2d(2)
        self.dropout2 = nn.Dropout(0.25)

        # 全连接层,输入维度4096=64*8*8,与上一层的输出一致
        self.fc1 = nn.Linear(4096, 512)
        self.dropout3 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    # 前向传播,将输入按顺序依次通过设置好的层
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pooling1(x)
        x = self.dropout1(x)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pooling2(x)
        x = self.dropout2(x)

        # 全连接层之前,将x的形状转为 (batch_size, n)
        x = x.view(len(x), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout3(x)
        x = self.fc2(x)
        return x
#%%
batch_size = 64 # 批量大小
learning_rate = 1e-3 # 学习率
epochs = 5 # 训练轮数
np.random.seed(0)
torch.manual_seed(0)

# 批量生成器
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

model = CNN()
# 使用Adam优化器
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 使用交叉熵损失
criterion = F.cross_entropy

# 开始训练
for epoch in range(epochs):
    losses = 0
    accs = 0
    num = 0
    model.train() # 将模型设置为训练模式,开启dropout
    with tqdm(trainloader) as pbar:
        for data in pbar:
            images, labels = data
            outputs = model(images) # 获取输出
            loss = criterion(outputs, labels) # 计算损失
            # 优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # 累积损失
            num += len(labels)
            losses += loss.detach().numpy() * len(labels)
            # 精确度
            accs += (torch.argmax(outputs, dim=-1) \
                == labels).sum().detach().numpy()
            pbar.set_postfix({
                'Epoch': epoch, 
                'Train loss': f'{losses / num:.3f}', 
                'Train acc': f'{accs / num:.3f}'
            })
    
    # 计算模型在测试集上的表现
    losses = 0
    accs = 0
    num = 0
    model.eval() # 将模型设置为评估模式,关闭dropout
    with tqdm(testloader) as pbar:
        for data in pbar:
            images, labels = data
            outputs = model(images)
            loss = criterion(outputs, labels)
            num += len(labels)
            losses += loss.detach().numpy() * len(labels)
            accs += (torch.argmax(outputs, dim=-1) \
                == labels).sum().detach().numpy()
            pbar.set_postfix({
                'Epoch': epoch, 
                'Test loss': f'{losses / num:.3f}', 
                'Test acc': f'{accs / num:.3f}'
            })
# 该工具包中有AlexNet、VGG等多种训练好的CNN网络
from torchvision import models 
import copy

# 定义图像处理方法
transform = transforms.Resize([512, 512]) # 规整图像形状

def loadimg(path):  
    # 加载路径为path的图像,形状为H*W*C
    img = plt.imread(path)
    # 处理图像,注意重排维度使通道维在最前
    img = transform(torch.tensor(img).permute(2, 0, 1))
    # 展示图像
    plt.imshow(img.permute(1, 2, 0).numpy())
    plt.show()
    # 添加batch size维度
    img = img.unsqueeze(0).to(dtype=torch.float32)
    img /= 255 # 将其值从0-255的整数转换为0-1的浮点数
    return img

content_image_path = os.path.join('style_transfer', 'content', '04.jpg')
style_image_path = os.path.join('style_transfer', 'style.jpg')

# 加载内容图像
print('内容图像')
content_img = loadimg(content_image_path)
# 加载风格图像
print('风格图像') 
style_img = loadimg(style_image_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/693964.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

312. 戳气球

题目 有 n 个气球,编号为 0 到 n - 1,每个气球上都标有一个数字,这些数字存在数组 nums 中。 现在要求你戳破所有的气球。戳破第 i 个气球,你可以获得 nums[i - 1] * nums[i] * nums[i 1] 枚硬币。 这里的 i - 1 和 i 1 代表和…

Pytorch 实现目标检测一(Pytorch 23)

一 目标检测和边界框 在图像分类任务中,我们假设图像中只有一个主要物体对象,我们只关注如何识别其类别。然而,很多时候图像里有多个我们感兴趣的目标,我们不仅想知 道它们的类别,还想得到它们在图像中的具体位置。在…

【Python】数据处理:OS目录文件操作

Python的os模块是一个用于与操作系统进行交互的标准库模块。它提供了丰富的功能来处理文件和目录、执行系统命令、获取和设置环境变量等。 工作目录操作 获取当前工作目录 os.getcwd()参数:无返回值:一个字符串,表示当前工作目录的路径。这…

算数运算符与表达式(打印被10整除的数)

打印100以内&#xff08;包含100&#xff09;能被10整除的正整数 #include <stdio.h>#define UPPER 100int main() {int i 1;while (i < UPPER)if (i % 10 0)printf("%d\n", i);return 0; } 自增运算符 i 用于递增变量 i 的值。在 while 循环中&#xf…

Word多级标题编号不连续、一级标题用大写数字二级以下用阿拉伯数字

Word多级标题编号不连续 &#xff1a; 一级标题用大写数字二级以下用阿拉伯数字&#xff1a;

Golang——gRPC与ProtoBuf介绍

一. 安装 1.1 gRPC简介 gRPC由google开发&#xff0c;是一款语言中立&#xff0c;平台中立&#xff0c;开源的远程过程调用系统。gRPC客户端和服务器可以在多种环境中运行和交互&#xff0c;例如用java写一个服务器端&#xff0c;可以用go语言写客户端调用。 1.2 gRPC与Protob…

Gitte的使用(Windows/Linux)

Gitte的使用&#xff08;Windows/Linux&#xff09; 一、Windows上使用Gitte1.下载程序2.在Gitte上创建远程仓库3.连接远程仓库4.推送文件到远程仓库 二、Linux上使用Gitte1.第一次从仓库上传1.1生成公钥1.2配置SSH公钥1.3新建一个仓库1.4配置用户名和邮箱在Linux中1.5创建仓库…

在vscode 中使用npm的问题

当我装了 npm和nodejs后 跑项目在 文件中cmd的话可以直接运行但是在 vscode 中运行的时候就会报一下错误 解决方法就是在 vscode 中吧 power shell换成cmd 来运行就行了

Java——简单图书管理系统

前言&#xff1a; 一、图书管理系统是什么样的&#xff1f;二、准备工作分析有哪些对象&#xff1f;画UML图 三、实现三大模块用户模块书架模块管理操作模块管理员操作有这些普通用户操作有这些 四、Test测试类五、拓展 哈喽&#xff0c;大家好&#xff0c;我是无敌小恐龙。 写…

C++输入输出与IO流

C 输入输出与I/O流 文章目录 C 输入输出与I/O流IO类型与基础特性概念与特性IO状态输出缓冲区 文件输入输出文件模式 string流IO处理中常用的函数及操作符综合练习与demo一、 创建文件并写入二、控制台输入数据并拆分存储三、读写电话簿 IO类型与基础特性 C11标准提供了几种IO处…

string经典题目(C++)

文章目录 前言一、最长回文子串1.题目解析2.算法原理3.代码编写 二、字符串相乘1.题目解析2.算法原理3.代码编写 总结 前言 一、最长回文子串 1.题目解析 给你一个字符串 s&#xff0c;找到 s 中最长的回文子串。 示例 1&#xff1a; 输入&#xff1a;s “babad” 输出&am…

Spring @Transactional 事务注解

一、spring 事务注解 1、实现层(方法上加) import org.springframework.transaction.annotation.Transactional;Transactional(rollbackFor Exception.class)public JsonResult getRtransactional() {// 手动标记事务回滚TransactionAspectSupport.currentTransactionStatus…

# 梯影传媒T6投影仪刷机方法及一些刷机工具链接

梯影传媒T6投影仪刷机方法及一些刷机工具链接 文章目录 梯影传媒T6投影仪刷机方法及一些刷机工具链接1、安装驱动程序2、备份设备rom【boot、system】3、还原我要刷进设备的rom【system】4、打开开发者模式以便于安装apk5、root设备6、更多好链接&#xff1a; 梯影传媒T6使用的…

【嵌入式】波特率9600,发送8个字节需要多少时间,如何计算?

问题&#xff1a; 波特率9600&#xff0c;发送 01 03 00 00 00 04 44 09 (8字节) 需要多少时间&#xff0c;如何计算&#xff1f; 在计算发送数据的时间时&#xff0c;首先要考虑波特率以及每个字符的数据格式。对于波特率9600和标准的UART数据格式&#xff08;1个起始位&…

预期值与实际值对比

编辑实际值和预期值变量 因为在单独的代码当中&#xff0c;我们先定义了变量str&#xff0c;所以在matcher时传入str参数&#xff0c;但当我们要把这串代码写在testrun当中&#xff0c;改下传入的参数&#xff0c;与excel表做连接 匹配的结果是excel表中的expect结果&#xf…

质量小议38 -- 60岁退休的由来

总是要有个标准&#xff0c;质量更是如些。 标准不是固定不变的&#xff0c;与时俱进。 关键词&#xff1a;当时的人均寿命&#xff1b;渐进式 60岁退休。 22大学毕业开始工作&#xff08;当然可能会更早&#xff09;&#xff0c;到60岁退休&#xff0c;要工作38年。 …

从零入手人工智能(2)——搭建开发环境

1.前言 作为一名单片机工程师&#xff0c;想要转型到人工智能开发领域的道路确实充满了挑战与未知。记得当我刚开始这段旅程时&#xff0c;心中充满了迷茫和困惑。面对全新的领域&#xff0c;我既不清楚如何入手&#xff0c;也不知道能用人工智能干什么。正是这些迷茫和困惑&a…

SpringBoot+Vue体育馆管理系统(前后端分离)

技术栈 JavaSpringBootMavenMySQLMyBatisVueShiroElement-UI 角色对应功能 学生管理员 功能截图

(四)React组件、useState

1. 组件 1.1 组件是什么 概念&#xff1a;一个组件就是用户界面的一部分&#xff0c;它可以有自己的逻辑和外观&#xff0c;组件之间可以相互嵌套&#xff0c;也可以复用多次。 组件化开发可以让开发者像搭积木一样构建一个完整的庞大应用 1.2 React组件 在React中&#xf…

java中集合List,Set,Queue,Map

Java SE中的集合框架是一组用于存储和操作对象的类和接口。它提供了丰富的数据结构&#xff0c;可以用于解决各种问题。Java SE中的集合框架包含以下主要类和接口&#xff1a; 一. Collection接口&#xff1a; 是集合框架的根接口&#xff0c;它定义了一些通用的集合操作方法…