pytorch-激活函数与GPU加速

目录

  • 1. sigmod和tanh
  • 2. relu
  • 3. Leaky Relu
  • 4. selu
  • 5. softplus
  • 6. GPU加速
  • 7. 使用GPU加速手写数据训练

1. sigmod和tanh

sigmod梯度区间是0~1,当梯度趋近0或者1时会出现梯度弥散的问题。
tanh区间时-1~1,是sigmod经过平移和缩放而得到的,也存在梯度弥散的问题。
在这里插入图片描述

2. relu

relu函数当梯度<0时,梯度是0,梯度>0时梯度是1,不会出现梯度弥散和梯度爆炸,虽然relu函数使用广泛也不易出现梯度弥散和梯度爆炸,但是不代表它不会出现。
在这里插入图片描述

3. Leaky Relu

在梯度<0的时候,不在是等于0而是变成了a*x, a是一个比较小的系数,确保梯度小于0时不再是0
在这里插入图片描述

4. selu

由两部分组成一部分时Relu,另一部分是一个指数函数,从而使得selu在0点变成了连续的。
在这里插入图片描述

5. softplus

时relu的一个连续光滑的版本,在0处变得光滑而连续
在这里插入图片描述
总结:目前用的最大的sigmod、tanh、relu、leakyrelu,其他两种用的较少

6. GPU加速

torch.device(‘cuda:0’)中的cuda:0代表第几块显卡,如果使用CPU那么就是torch.device(‘cpu’)
使用.to(device)就把模块或者数据搬到了GPU上,然而模块和数据是有一些区别的,模块执行.to(device)返回一个reference和不使用初始化是完全一样的属于一个inplace操作,但是data就不一样了,比如:data2=data.to(device),data2和data是完全不一样的,data2是gpu数据,data是cpu数据。
注意:.cuda()方法已经不推荐使用了
在这里插入图片描述

7. 使用GPU加速手写数据训练

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)






class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x

device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

从代码中可以看到网络、loss函数和数据都搬到了GPU上,激活函数改成了LeakyRelu

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/576248.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

第 3 篇 : Netty离线消息处理(可跳过)

说明 仅是个人的不成熟想法, 未深入研究验证 1. 修改 NettyServerHandler类 package com.hahashou.netty.server.config;import com.alibaba.fastjson.JSON; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHan…

云计算时代:SFP、SFP+、SFP28、QSFP+和QSFP28光纤模块详解

随着数据中心的快速发展和云计算的广泛应用&#xff0c;高速、高效率的光纤网络传输成为关键需求。在众多光纤模块中&#xff0c;SFP、SFP、SFP28、QSFP和QSFP28是最常见的几种类型。本文将为您详细解析这几种光纤模块之间的区别&#xff0c;帮助您更好地了解和选择适合自己需求…

【产品经理修炼之道】- B端产品用户层级与需求优先级

B端的需求和C端有比较大的差异&#xff1a;C端的用户画像&#xff0c;在B端更多是以角色、权力和义务的划分。在这种情况下&#xff0c;我们的需求处理方式也会有所不同。 交互设计其实就是用户的行为设计&#xff0c;既然是围绕用户的行为&#xff0c;那么我们首先得清楚我们的…

flutter 微信输入框 (第二版)

微信的聊天输入框之前实现了一个版本&#xff08;flutter 微信聊天输入框_flutter 聊天输入框-CSDN博客&#xff09;&#xff0c; 但是之前实现的不太优雅。这两天重写了一遍。效果如下&#xff1a; 1.页面拆分 这里我们把 聊天的页面进行 拆分&#xff1a;Scaffold &#xff0…

免费预约即将截止,5月7日上海TCT亚洲3D打印展参观指南,收藏!

进入TCT亚洲展官网&#xff08;网页搜索TCT亚洲展&#xff09;&#xff0c;免费登记预约 2024年TCT亚洲展作为推动增材制造在亚洲市场的业务交流的重要平台&#xff0c;将于2024年5月7日至9日在国家会展中心&#xff08;上海&#xff09;7.1&8.1馆举办&#xff0c;与海内外…

二 SSM整合实操

SSM整合实操 一 依赖管理 数据库准备 mysql8.0.33 CREATE DATABASE mybatis-example;USE mybatis-example;CREATE TABLE t_emp(emp_id INT AUTO_INCREMENT,emp_name CHAR(100),emp_salary DOUBLE(10,5),PRIMARY KEY(emp_id) );INSERT INTO t_emp(emp_name,emp_salary) VALUE…

短视频素材有哪些?短视频素材哪一类最吸引人?

随着视频内容在全球各种媒体和平台上的普及&#xff0c;寻找能够让你的项目脱颖而出的视频素材变得尤为重要。以下视频素材网站各具特色&#xff0c;提供从自然风景到都市快照&#xff0c;从简单背景到复杂动画的多样选择。 1. 蛙学府&#xff08;中国&#xff09; 提供4K高解…

全志ARM-蜂鸣器

操作准备&#xff1a; 1.使Tab键的缩进和批量对齐为4格 在/etc/vim/vimrc 中添加一项配置 set tabstop 4; 也可以再加一行 set nu显示代码的行数 vim的设置&#xff0c;修改/etc/vim/vimrc文件&#xff0c;需要用超级用户权限 /etc/vim/vimrc set shiftwidth4 设置批量对…

VsCode一直连接不上 timed out

前言 前段时间用VsCode连接远程服务器&#xff0c;正常操作后总是连接不上&#xff0c;折磨了半个多小时&#xff0c;后面才知道原来是服务器设置的问题&#xff0c;故记录一下&#xff0c;防止后面的小伙伴也踩坑。 我使用的是阿里云服务器&#xff0c;如果是使用其他平台服务…

web(微博发布案例)

示例&#xff1a; 1、检测空白内容 2、发布内容 html: <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><meta …

vue+element之解决upload组件上传文件失败后仍显示在列表上、自动上传、过滤、findIndex、splice、filter

MENU 前言错误案例(没有用)正确方法结束语 前言 el-upload上传失败后&#xff0c;文件仍显示在列表上。 这个pdf文件上传失败&#xff0c;仍显示在列表&#xff0c;给人错觉是上传成功&#xff0c;所以要把它去掉。 在element中&#xff0c;file-list和v-model:file-list是用于…

苹果一次性开源了8个大模型! 包含模型权重、训练日志和设置,OpenELM全面开源

不以开放性著称的苹果居然同时开源了大模型的权重、训练和评估框架&#xff0c;涵盖训练日志、多个保存点和预训练设置。同时升级计算机视觉工具包 CVNets 为 CoreNet&#xff01;支持 OpenELM&#xff01; ▲图1.由Stable Diffusion3生成。 OpenELM是Apple苹果公司最新推出的…

【产品经理修炼之道】- 如何分析一个产品

新人产品经理面试的时候&#xff0c;常被问到的一个问题是&#xff1a;如何评价一款产品。这个问题&#xff0c;我们可以从五个层级一个模型来解答&#xff0c;看你能分析到哪一层。 初级产品经理面试时&#xff0c;经常会问这样的问题&#xff1a; 1&#xff09;你是最喜欢的…

U盘格式转换GPT格式转回DOS

当前格式 fdisk /dev/sdb# 在 fdisk 提示符下&#xff0c;输入以下命令删除分区&#xff1a; d # 选择要删除的分区编号&#xff08;如 1、2 等&#xff09; w开始转换 [rootnode-24 ~]# fdisk /dev/sdbWelcome to fdisk (util-linux 2.37.4). Changes will remain in memory o…

网络安全实训Day17and18

写在前面 第17和18天都讲的sql注入&#xff0c;故合并 ​​​​​​ 网络空间安全实训-渗透测试 Web渗透 定义 针对Web站点的渗透攻击&#xff0c;以获取网站控制权限为目的 Web渗透的特点 Web技术学习门槛低&#xff0c;更容易实现 Web的普及性决定了Web渗透更容易找到目…

python项目练习-1

获取无忧书城的小说内容&#xff01; import requests # 导入请求包 from lxml import etree # 导入处理xml数据包url https://www.51shucheng.net/wangluo/douluodalu/21750.html book_num 1 # 文章页数 download_urls [] # 定义一个空列表&#xff0c;表示我们下载过小…

提升你的C编程技能:使用cURL下载Kwai视频

概述 本文将介绍如何利用C语言以及cURL库来实现Kwai视频的下载。cURL作为一个功能强大的网络传输工具&#xff0c;能够在C语言环境下轻松地实现数据的传输。我们还将探讨如何运用代理IP技术&#xff0c;提升爬虫的匿名性和效率&#xff0c;以适应Kwai视频平台的发展趋势。 正…

《欢乐钓鱼大师》攻略,钓友入坑必备!

欢迎来到《欢乐钓鱼大师》&#xff01;在这个游戏里&#xff0c;你可以尽情享受垂钓的乐趣&#xff0c;通过不断更换和升级高阶鱼竿&#xff0c;轻松地钓到各种稀有鱼类。因为许多玩家在挑战关卡时遇到了一些困难&#xff0c;所以今天我给大家带来了《欢乐钓鱼大师攻略指南》&a…

自动化机器学习流水线:基于Spring Boot与AI机器学习技术的融合探索

&#x1f9d1; 作者简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

【毕设绝技】基于 SpringCloud 的在线交易平台商城的设计与实现-数据库设计(三)

毕业设计是每个大学生的困扰&#xff0c;让毕设绝技带你走出低谷迎来希望&#xff01; 基于 SpringCloud 的在线交易平台商城的设计与实现 一、数据库设计原则 在系统中&#xff0c;数据库用来保存数据。数据库设计是整个系统的根基和起点&#xff0c;也是系统开发的重要环节…