深度学习——优化器Optimizer

代码以及详细注释:

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible
"""
    超参数
"""
# 学习率
LR = 0.01
# 批大小
BATCH_SIZE = 32
# 轮次
EPOCH = 12

# 造数据
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))

# # plot dataset
# plt.scatter(x.numpy(), y.numpy())
# plt.show()


# put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
# 数据加载器
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)


# default network
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(1, 20)   # hidden layer
        self.predict = torch.nn.Linear(20, 1)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.predict(x)             # linear output
        return x

if __name__ == '__main__':
    # 相同的网络结构
    net_SGD         = Net()
    net_Momentum    = Net()
    net_RMSprop     = Net()
    net_Adam        = Net()
    # 将上面的网络集成到这里
    nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]

    # 不同的优化器
    opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)
    opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
    opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
    opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
    # 将上面的优化器集成到这里
    optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]

    # 损失函数
    loss_func = torch.nn.MSELoss()
    losses_his = [[], [], [], []]   # record loss

    # 训练轮次
    for epoch in range(EPOCH):
        print('Epoch: ', epoch)
        # 分批训练
        for step, (b_x, b_y) in enumerate(loader):          # for each training step
            for net, opt, l_his in zip(nets, optimizers, losses_his):
                output = net(b_x)              # get output for every net
                loss = loss_func(output, b_y)  # compute loss for every net
                opt.zero_grad()                # clear gradients for next train
                loss.backward()                # backpropagation, compute gradients
                opt.step()                     # apply gradients
                l_his.append(loss.data)        # loss recoder

    # 绘图
    labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
    for i, l_his in enumerate(losses_his):
        plt.plot(l_his, label=labels[i])
    plt.legend(loc='best')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.ylim((0, 0.2))
    plt.show()

运行结果:

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/37326.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

什么是RPC并实现一个简单的RPC

1. 基本的RPC模型 主要介绍RPC是什么,基本的RPC代码,RPC与REST的区别,gRPC的使用 1.1 基本概念 RPC(Remote Procedure Call)远程过程调用,简单的理解是一个节点请求另一个节点提供的服务本地过程调用&am…

详解Jenkins配置邮件通知

前言 这几天Darren洋在使用Jenkins定时构建jmeter脚本中,要用到邮箱配置,故记录之。 一、Jenkins默认邮箱通知 这里填好smtp服务器地址和邮箱后缀,这样下面的账号就不用加邮箱后缀了。 网易邮箱设置以下我就不说废话文学了,直接上…

智能优化算法——哈里鹰算法(Matlab实现)

目录 1 算法简介 2 算法数学模型 2.1.全局探索阶段 2.2 过渡阶段 2.3.局部开采阶段 3 求解步骤与程序框图 3.1 步骤 3.2 程序框图 4 matlab代码及结果 4.1 代码 4.2 结果 1 算法简介 哈里斯鹰算法(Harris Hawks Optimization,HHO),是由Ali Asghar Heid…

【深度剖析】 快速排序为什么不稳定?!

文章目录 零、前言一、快速排序的步骤原理二、什么是稳定性?三、不稳定的地方在哪里?四、怎么让快速排序变得稳定?1、采用双指针的快速排序A 思路简述B 参考代码 :C 算法分析 2、基于递归的快速排序A 思路简述B 参考代码C 算法分析 3、采用归…

【K8S系列】深入解析K8S调度

序言 做一件事并不难,难的是在于坚持。坚持一下也不难,难的是坚持到底。 文章标记颜色说明: 黄色:重要标题红色:用来标记结论绿色:用来标记论点蓝色:用来标记论点 Kubernetes (k8s) 是一个容器编…

使用docker部署rancher并导入k8s集群

前言:鉴于我已经部署了k8s集群,那就在部署rancher一台用于管理k8s,这是一台单独的虚拟环境,之前在k8s的master节点上进行部署并未成功,有可能端口冲突了,这个问题我并没有深究,如果非要通过修改…

C#使用Chart进行统计,切换不同的图表类型

WindowsForm应用程序中Chart图表控件所属的命名空间: Chart 命名空间: System.Windows.Forms.DataVisualization.Charting 对应的dll路径: C:\Program Files (x86)\Reference Assemblies\Microsoft\Framework\.NETFramework\v4.6.1\Syst…

COT、COT-SC、TOT 大预言模型思考方式||底层逻辑:prompt设定

先讲一下具体缩写的意思 COT-chain of thoughts COT-SC (Self-consistency) Tree of thoughts:Deliberate problem solving with LLM 我理解其实不复杂 1. 最简单的是:直接大白话问一次 (IO) 2. 进阶一点是:思维链,…

PDF转CAD后尺寸如何保持一致?这几种方法可以尝试一下

CAD文件是可编辑的,可以进行修改、添加和删除,这使得在CAD软件中进行编辑更加容易和灵活。这意味着,如果需要对图纸进行修改或者添加新的元素,可以直接在CAD软件中进行操作,而不需要重新制作整个图纸。那么将PDF文件转…

Linux嵌入式项目-智能家居

一、资料下载 二、框架知识 三、MQTT通信协议 1、上位机APP主要工作 1.wait for msg / while(1)订阅等待消息 2.处理消息 客户端创建了两个线程,一个线程用于发布消息,一个线程用于监听订阅消息 (那我的仿真系统也可以啊,一个…

DVDNET A FAST NETWORK FOR DEEP VIDEO DENOISING

DVDNET: A FAST NETWORK FOR DEEP VIDEO DENOISING https://ieeexplore.ieee.org/document/8803136 摘要 现有的最先进视频去噪算法是基于补丁的方法,以往的基于NN的算在其性能上无法与其媲美。但是本文提出NN的视频去噪算法性能要好: 其相比于基于补丁…

Oracle通过函数调用dblink同步表数据方案(全量/增量)

创建对应的包,以方便触发调用 /*包声明*/ CREATE OR REPLACE PACKAGE yjb.pkg_scene_job AS /*创建同步任务*/FUNCTION F_SYNC_DRUG_STOCK RETURN NUMBER;/*同步*/PROCEDURE PRC_SYNC_DRUG_STOCK(RUNJOB VARCHAR2) ; END pkg_scene_job; /*包体*/ CREATE OR REPL…

深入理解netfilter和iptables

目录 Netfilter的设计与实现 内核数据包处理流 netfilter钩子 钩子触发点 NF_HOOK宏与Netfilter裁定 回调函数与优先级 iptables 内核空间模块 xt_table的初始化 ipt_do_table() 复杂度与更新延时 用户态的表,链与规则 conntrack Netfilter(结合iptable…

100种思维模型之安全边际思维模型-92

安全边际, 简而言之即距离某一件糟糕的事件发生,还有多大的空间,安全边际越高,我们就越安全! 安全边际思维模型一个 让生活变得更从容 的 思维模型。 01、何谓安全边际思维模型 一、安全边际思维 安全边际 源于…

ACL 2023 | 持续进化中的语言基础模型

尽管如今的 AI 模型已经具备了理解自然语言的能力,但科研人员并没有停止对模型的不断改善和理论探索。自然语言处理(NLP)领域的技术始终在快速变化和发展当中,酝酿着新的潮流和突破。 NLP 领域的顶级学术会议国际计算语言学年会 …

声网 Agora音视频uniapp插件跑通详解

一、前言 在使用声网SDK做音视频会议开发时, 通过声网官方论坛 了解到,声网是提供uniapp插件的,只是在官方文档中不是很容易找到。 插件地址如下: Agora音视频插件 Agora音视频插件(JS) 本文讲解如何跑通演示示例 二、跑通Demo 2.1 环境安装: 参考: 2. 通过vue-…

vue3+element+sortablejs实现table表格 行列动态拖拽

vue3elementsortablejs实现table动态拖拽 1.第一步我们要安装sortablejs依赖2.在我们需要的组件中引入3.完整代码4.效果 1.第一步我们要安装sortablejs依赖 去博客设置页面,选择一款你喜欢的代码片高亮样式,下面展示同样高亮的 代码片. npm install so…

巩固一下NodeJs

1、初始化(确保当前电脑有node环境) npm init 2、安装express npm i expressnpm i ws文件结构 3、编写相关代码启动node服务(server.js) //导入下列模块,express搭建服务器,fs用来操作文件、ws用来实现webscoket const express require("expr…

Android 使用webView打开网页可以实现自动播放音频

使用webview 自动播放音视频,场景如,流媒体自动部分,音视频通话等。会出现如下问题: 解决方案如下: 配置webview 如下,这样可以自动播放音频。 webView.getSettings().setMediaPlaybackRequiresUserGestur…

原生JS实现图片裁剪功能

功能介绍:图片通过原生input上传,使用canvas进行图片裁剪。 裁剪框限制不允许超出图片范围,图片限制了最大宽高(自行修改要的尺寸),点击确认获取新的base64图片数据 注:fixed布局不适用该方案&…