self.register_buffer方法使用解析(pytorch)

self.register_buffer就是pytorch框架用来保存不更新参数的方法。

列子如下:

self.register_buffer("position_emb", torch.randn((5, 3)))

第一个参数position_emb传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数torch.randn((5, 3),并一次初始化后保存于模型,不会有梯度传播给它,能被模型的model.state_dict()记录下来,可以理解为模型的常数。当然,你想保留固定值,使用如下代码:

self.register_buffer("position_emb", torch.tensorrt([[2,5],[8,9]]))

进一步探讨训练对该参数是否有影响,答案是:没影响。具体可看下面实现的列子代码:

import torch
from torch.nn import Embedding

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = Embedding(5, 3)
        self.register_buffer("position_emb", torch.randn((5, 3)))
    def forward(self,vec):
        input = torch.tensor([0, 1, 2, 3, 4])
        emb_vec1 = self.emb(input)
        emb_vec1=emb_vec1+self.position_emb
        output = torch.einsum('ik, kj -> ij', emb_vec1, vec)
        return output
def simple_train():
    model = Model()
    vec = torch.randn((3, 1))
    label = torch.Tensor(5, 1).fill_(3)
    loss_fun = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.015)
    print('初始化后position_emb参数:\n',model.position_emb)
    for iter_num in range(100):
        output = model(vec)
        loss = loss_fun(output, label)
        opt.zero_grad()
        loss.backward(retain_graph=True)
        opt.step()
    print('训练后position_emb参数:\n', model.position_emb)

if __name__ == '__main__':
   simple_train()  # 训练与保存权重

实现结果如下:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/124312.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

虚拟机网络没有有效的ip配置

虚拟机网络没有有效的ip配置: 原因猜测:或许是之前使用的操作系统把网络给占了。 解决方法:点击虚拟机的 遍历->网络编辑器->移除不要的网络,然后添加网络。(下面的图就是我把虚拟网络全部移除,然后…

png怎么转jpg?这款图片转格式工具一学就会用

虽然png图片格式是一种无损压缩格式,但是png图片的内存大小也是比较大的,而且兼容性上也没有jpg图片好,许多平台推荐的也都是jpg格式,所以当我们需要把png转jpg格式的时候,就需要用到图片格式转换器,今天推…

【Qt绘制小猪】以建造者模式绘制小猪

效果 学以致用&#xff0c;使用设计模式之建造者模式绘制小猪。 代码 接口&#xff1a;申明绘制的步骤 PigBuilder.h #ifndef PIGBUILDER_H #define PIGBUILDER_H#include <QObject> #include <QPainter>class PigBuilder : public QObject {Q_OBJECT public:ex…

解密Elasticsearch:深入探究这款搜索和分析引擎

•开篇 最近使用Elasticsearch实现画像系统&#xff0c;实现的dmp的数据中台能力。同时调研了竞品的架构选型。以及重温了redis原理等。特此做一次es的总结和回顾。网上没看到有人用Elasticsearch来完成画像的。我来做第一次尝试。 背景说完&#xff0c;我们先思考一件事&…

计算机中丢失mfc140u.dll怎么解决

mfc140u.dll是一个Microsoft Visual C库文件&#xff0c;主要用于MFC&#xff08;Microsoft Foundation Class&#xff09;应用程序的开发。它包含了MFC应用程序所需的一些常用功能&#xff0c;如对话框、窗口、菜单等。当mfc140u.dll丢失时&#xff0c;可能会导致MFC应用程序无…

Android MotionLayout

MotionLayout exends ConstraintLayout(动画框架 过渡) View动画 API1 属性动画API11 过渡动画API18 root.width RootViewWidth TransitionManager.beginDelayedTransition(view) 过渡动画 可以改变其大小和流畅性 Fade 可以改变透明度 通过TrasitinManager管理 Go:动态替…

vue前端实现多个url下载并合并为zip文件

一、安装 npm install jszip npm install file-saver 二、引入 import axios from axios import JSZip from "jszip"; import FileSaver from "file-saver"; 三、核心代码 videoData:[/video/26519f026fc012521605563015227403.mp4,/video/f7b9cdae14…

数字通信和fpga概述——杜勇版本学习笔记

1数字通信处理流程 脉冲调制是每个数字通信系统中间必不可少的环节&#xff0c;通常是使用升余弦滚降滤波器来实现。 超外差接收机原理是利用本地产生的振荡波与输入信号混频&#xff0c;将输入信号频率变换为某个预先确定的频率的方法。超外差原理最早是由E.H.阿姆斯特朗于1…

2023年云计算发展趋势:生活的智能未来

目录 引言1 智能家居的崭新时代2 无人驾驶的崭新时代3 虚拟现实的扩展与改进4 人工智能的综合应用5 云计算的可持续性结语 引言 时光荏苒&#xff0c;科技的飞速发展已经成为当今社会的标志之一。在这个数字化时代&#xff0c;云计算已经成为推动技术革新和生活方式改变的关键…

软件测试|Python Faker库使用指南

简介 Faker是一个Python库&#xff0c;用于生成虚假&#xff08;假的&#xff09;数据&#xff0c;用于测试、填充数据库、生成模拟数据等目的。它可以快速生成各种类型的虚假数据&#xff0c;如姓名、地址、电子邮件、电话号码、日期等&#xff0c;非常适合在开发和测试过程中…

CSS实现鼠标移至图片上显示遮罩层及文字效果

效果图&#xff1a; 1、将遮罩层html代码与图片放在一个div 我是放在 .proBK里。 <div class"proBK"><img src"../../assets/image/taskPro.png" class"proImg"><div class"imgText"><h5>用户在线发布任务&l…

FreeRTOS学习笔记(二)

一、时间片调度 1、同等优先级任务轮流地享有相同的 CPU 时间(可设置)&#xff0c; 叫时间片&#xff0c;在FreeRTOS中&#xff0c;一个时间片就等于SysTick 中断周期 /* 任务一&#xff0c;实现LED0每500ms翻转一次 */ void task1( void * pvParameters ) {uint32_t task1_n…

Java算法(六):模拟评委打分案例 方法封装抽离实现 程序的节流处理

Java算法&#xff08;六&#xff09; 评委打分 需求&#xff1a; 在编程竞赛中&#xff0c;有 6 个评委为参赛选手打分&#xff0c;分数为 0 - 100 的整数分。 选手的最后得分为&#xff1a;去掉一个最高分和一个最低分后 的 4个评委的平均值。 注意程序的节流 package c…

Spring-循环依赖简述

什么是循环依赖 // A依赖了B class A {public B b; } ​ // B依赖了A class B {public A a; } ​ // 循环依赖 A a new A(); B b new B(); a.b b; b.a a; 对象之间的相互依赖很正常&#xff0c;但是在Spring中由于对象创建要经过Bean的生命周期&#xff0c;所以就有了循环…

【广州华锐互动】气象卫星监测AR互动教学软件为气象学习带来更多乐趣

由VR制作公司广州华锐互动开发的气象卫星监测AR互动教学软件是一款结合了增强现实(AR)技术与气象监测技术的教育软件。它通过直观、互动的方式&#xff0c;帮助学生更好地理解和掌握气象监测的基本知识和技能。本文将从气象卫星监测AR互动教学软件的应用场景、优势分析、实际意…

c#如何把字符串中的指定字符删除

可以使用以下四种方法&#xff1a; 一、使用关键字&#xff1a;Replace public string Replace(char oldChar,char newChar); 在对象中寻找oldChar&#xff0c;如果寻找到&#xff0c;就用newChar将oldChar替换掉。 1、实例代码&#xff1a; 2、执行结果&#xff1a; 二、Rem…

【CesiumJS入门】(11)加载LAS点云数据

前言 最近有两次投递简历以及面试都被问到了是否有三维点云数据处理相关的经验。然而我的岗位都没有和点云相关的工作任务&#xff0c;所以还是得自己加把劲呀。 本篇将从数据获取到加载来简易地介绍一个LAS点云数据的加载。 加载数据 首先&#xff0c;你得有一份LAS格式的…

node插件MongoDB(三)—— 库mongoose 的使用和数据类型

前言 提示&#xff1a;使用mongoose 的前提是你安装了node和 MongoDB。 mongoose 官网文档&#xff1a;http://mongoosejs.net/docs/index.html 文章目录 前言一、安装二、基本使用1. 打开bin目录的mongod.exe文件2. 基本使用的代码&#xff08;连接mongodb 服务&#xff09;3.…

虚拟机复制后,无法ping通问题解决

虚拟机复制后&#xff0c;无法ping通问题解决 可能出现的现象 ssh工具连接不上虚拟机&#xff1b;虚拟机ping不通外网或者ping不通内网其它虚拟机&#xff1b; 原因 原虚拟机和新复制出来的虚拟机的ip地址重复&#xff1b;原虚拟机和新复制出来的虚拟机的MAC地址重复&#…

【wp】2023鹏城杯初赛 Web web1(反序列化漏洞)

考点&#xff1a; 常规的PHP反序列化漏洞双写绕过waf 签到题 源码&#xff1a; <?php show_source(__FILE__); error_reporting(0); class Hacker{private $exp;private $cmd;public function __toString(){call_user_func(system, "cat /flag");} }class A {p…