SelfAttention|自注意力机制ms简单实现

自注意力机制学习有感

  • 观看b站博主的讲解视频以及跟着他的pytorch代码实现mindspore的自注意力机制:
  • up主讲的很好,推荐入门自注意力机制。
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import context
context.set_context(device_target='Ascend',max_device_memory='1GB') 

class SelfAttention(nn.Cell):
    def __init__(self, dim):
        super(SelfAttention, self).__init__()
        wq_data = [[1.0, 0], [1., 1.]] # wq权重初始化 超参数
        wk_data = [[0., 1.], [1., 1.]] # wk权重初始化 超参数
        wv_data = [[0., 1., 1.], [1., 0., 0.]] # wv权重初始化 超参数
        
        self.q = nn.Dense(in_channels=dim, out_channels=2, has_bias=False)
        self.q.weight.set_data(ms.Tensor(wq_data).T)
        print("wq value:", self.q.weight.value())
        
        self.k = nn.Dense(in_channels = dim, out_channels=2, has_bias=False)
        self.k.weight.set_data(ms.Tensor(wk_data).T)
        print('wk value:', self.k.weight.value())
        
        self.v = nn.Dense(in_channels=dim, out_channels=3, has_bias=False)
        # print(self.v.weight.shape)
        self.v.weight.set_data(ms.Tensor(wv_data).T)
        print('wv value:',self.v.weight.value())
        print("*********************" * 2)
        
    def construct(self, x):
        q = self.q(x)
        print('q value:', q)
        k = self.k(x)
        print('k value:', k)
        v = self.v(x)
        # xx = x.matmul(ms.Tensor([[0., 1., 1.], [1., 0., 0.]]))
        print('v value:', v, '\n')
        print('#################################')
        x = (q @ k.T)/ms.ops.sqrt(ms.tensor(2.))
        x = ms.ops.softmax(x) @ v
        print("result:", x)
        

x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = ms.Tensor(x)
attn = SelfAttention(2)
attn(x)

结果如下:

wq value: [[1. 1.]
 [0. 1.]]
wk value: [[0. 1.]
 [1. 1.]]
wv value: [[0. 1.]
 [1. 0.]
 [1. 0.]]
******************************************
q value: [[2. 1.]
 [1. 0.]
 [3. 1.]
 [2. 2.]]
k value: [[1. 2.]
 [0. 1.]
 [1. 3.]
 [2. 2.]]
v value: [[1. 1. 1.]
 [0. 1. 1.]
 [1. 2. 2.]
 [2. 0. 0.]] 

#################################
result: [[1.5499581  0.71284014 0.71284014]
 [1.3395231  0.7726004  0.7726004 ]
 [1.7247156  0.4475609  0.4475609 ]
 [1.4366053  1.         1.        ]]

** 吐槽mindspore说明文档,对ms.nn.Dense的说明太过简单了,有对新手真不友好(对我) **

  • pytorch的文档:
    在这里插入图片描述
  • mindspore的文档:
    在这里插入图片描述
    pytorch有公式,至少提示A的转置有提示。mindspore没有,导致我这步实现的时候输出的结果不对,还是希望mindspore说明问昂也把公式写清楚点。其实mindspore的Dense和pytorch的Linear的公式实现是一样的。
    附上pytorch的实现:
#@title Default title text 
import torch
import torch_npu
import torch.nn as nn
class Self_Attention(torch.nn.Module):
    def __init__(self, dim):
        super(Self_Attention, self).__init__() #  其中qkv代表构建好训练好的wq,wk,wv的权重参数;
        self.scale = 2 ** -0.5
        self.q = torch.nn.Linear(dim, 2, bias=False) 
        q_list = [[1., 0.],[1., 1.]]
        self.q.weight.data = torch.Tensor(q_list).T
        print('q value:', self.q.weight.data)
        
        self.k = nn.Linear(dim, 2, bias=False)
        
        k_list = [[0., 1.], [1., 1.]]
        self.k.weight.data = torch.Tensor(k_list).T
        print('k value:', self.k.weight.data)
        
        self.v = nn.Linear(dim,3,bias=False)
        v_list = [[0., 1., 1.],[1., 0., 0.]]
        
        # print("origin shape:", self.v.weight.data.shape)
        
        self.v.weight.data = torch.Tensor(v_list).T
        print('init shape:',self.v.weight.data)
        
    def forward(self, x):
        q = self.q(x)  # 通过训练好的参数生成q参数
        print("q:", q)
        
        k = self.k(x)
        print("k:", k)
        
        v = self.v(x)
        print("v shape:", v.shape)
        
        # Att公式
        attn = (q.matmul(k.T)) / torch.sqrt(torch.tensor(2.0))
        print("attn1:", attn)
        
        # attn = (q @ k.transpose(-2, -1)) / torch.sqrt(torch.tensor(2.0))
        # print("attn11:", attn)
        # attn = (q @ k.transpose(-2, -1)) * self.scale
        # print("attn2:", attn)
        attn = attn.softmax(dim=-1)
        print("softmax attn:", attn)
        # print(attn.shape) # shape[4,4]
        x = attn @ v
        print(x.shape)  #shape[4,3]
        return x 
x = [[1., 1.],[1,0],[2,1],[0, 2.]]
x = torch.Tensor(x)
att = Self_Attention(2)  
att(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/387773.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 0987.二叉树的垂序遍历:遍历时存节点信息,遍历完自定义排序

【LetMeFly】987.二叉树的垂序遍历:遍历时存节点信息,遍历完自定义排序 力扣题目链接:https://leetcode.cn/problems/vertical-order-traversal-of-a-binary-tree/ 给你二叉树的根结点 root ,请你设计算法计算二叉树的 垂序遍历…

前端秘法进阶篇之事件循环

目录 一.浏览器的进程模型 1.进程 2.线程 二.浏览器的进程和线程 1. 浏览器进程 2. 网络进程 3. 渲染进程 三.渲染主线程 四.异步 五.优先级 1. 延时队列: 2.交互队列: 3.微队列: 六.JS 的事件循环 附加:JS 中的计时器能做到精…

XMall 开源商城 SQL注入漏洞复现(CVE-2024-24112)

0x01 产品简介 XMall 开源电商商城 是开发者Exrick的一款基于SOA架构的分布式电商购物商城 前后端分离 前台商城:Vue全家桶 后台管理:Dubbo/SSM/Elasticsearch/Redis/MySQL/ActiveMQ/Shiro/Zookeeper等。 0x02 漏洞概述 XMall 开源商城 /item/list、/item/listSearch、/sys/…

【Android】使用Android Studio打包APK文件

文章目录 1. 新建项目2. 打包生成APK3. 安装APK 1. 新建项目 打包APK之前,首先需要新建项目,有基础的可以跳过。 无基础的可以参考:使用Android Studio运行Hello World项目 2. 打包生成APK 1.找到Build -> Generate Signed Bundle or …

【C/C++语法基础】2.输入与输出(✨新手推荐阅读)

前言 在C中,输入与输出是程序与用户进行交互的基本方式。C提供了多种方式进行数据的输入与输出,其中最常用的是printf、scanf、cin和cout。此外,我们还会讨论如何取消cin和cout的同步流,以及了解各种转义字符的用法。 1.printf函…

arkTS开发鸿蒙OS个人商城案例【2024最新 新年限定开发案例QAQ】

龙年前述 源码获取>文章下方二维码,回复关键字“鸿蒙OS商场源码” 前言 arkTS是华为自己研发的一套前端语言,是在js和ts技术的基础上又进行了升级而成! 本篇文章会带领大家通过arkTSnode.jsmongoDB来完成一个鸿蒙OS版本的商城案例&…

flask cors 跨域问题解决

座右铭:怎么简单怎么来,以实现功能为主。 欢迎大家关注公众号与我交流 环境安装 pip install -U flask-cors 示例代码 from flask import Flask from flask_cors import CORS, cross_originapp Flask(__name__) CORS(app, supports_credentialsTrue)…

__attribute__ ---Compile

Section for attribute attribute_?嵌入式C代码属性怎么定义 https://www.elecfans.com/d/2269222.html section 属性的主要作用是:在程序编译时,将一个函数或者变量放到指定的段,即指定的section 中。 一个可执行文件注意由代…

AI算法初识之分类汇总

一、背景 AI算法的分类方式多种多样,可以根据不同的学习机制、功能用途以及模型结构进行划分。以下是一些主要的分类方式及相应的代表性算法: 1. 按照学习类型 - **监督学习**: - 线性回归(Linear Regression) …

学会如何备份u盘数据,让数据安全有保障

随着科技的发展,U盘已成为我们日常生活和工作中不可或缺的数据存储设备。然而,无论U盘的质量如何,数据丢失的风险始终存在。可能是硬件故障、意外删除、病毒感染或其他不可预见的原因。 尽管当前提供了多种数据恢复方案,然而没有一…

【Midjourney】解密Midjourney付费订阅:畅享全新体验!(详细流程与各版本一览)

一、Midjourney 付费订阅流程 1、在首页点击Purchase plan 2、进入到midjourney年月选择页面 3、这里续费一个最便宜的版本 , 按年付费 8 , 按月 10 4、输入银行卡信息 , 用的WildCard虚拟信用卡 ,打开 5、填写完银行卡信息就订阅成功 二、Midjourney 各版本介绍…

山西电力市场日前价格预测【2024-02-12】

日前价格预测 预测说明: 如上图所示,预测明日(2024-02-12)山西电力市场全天平均日前电价为127.42元/MWh。其中,最高日前电价为369.24元/MWh,预计出现在18:45。最低日前电价为0.00元/MWh,预计出…

QT 菜单栏

添加/删除菜单栏 默认情况下QMainWindow项目一创建就自带了菜单栏,可以在对象树窗口中,右键菜单栏对象,移除菜单栏: 删除后也可以创建菜单栏,此时在对象树中右键MainWindow对象,菜单里边会多了创建菜单栏的…

[OPEN SQL] 新增数据

INSERT语句用于数据的新增操作 本次操作使用的数据库表为SCUSTOM&#xff0c;其字段内容如下所示 航班用户(SCUSTOM) 该数据库表中的部分值如下所示 1.插入单条数据 语法格式 INSERT <dbtab> FROM <wa>. INSERT INTO <dbtab> VALUES <wa>. INSERT &…

Hive的相关概念——分区表、分桶表

目录 一、Hive分区表 1.1 分区表的概念 1.2 分区表的创建 1.3 分区表数据加载及查询 1.3.1 静态分区 1.3.2 动态分区 1.4 分区表的本质及使用 1.5 分区表的注意事项 1.6 多重分区表 二、Hive分桶表 2.1 分桶表的概念 2.2 分桶表的创建 2.3 分桶表的数据加载 2.4 …

数据库第一次实验

目录 1 实验内容 2 SQL代码 3 效果截图 1 实验内容 熟悉SQL实验环境配置和进行实验数据准备&#xff0c;用SQL Server、PostgreSQL或MySQL创建数据库&#xff0c; 并按照下列关系模式定义数据表&#xff0c;加入适当约束&#xff1a; 学生&#xff08;学号、姓名、性别、…

free pascal:fpwebview 组件通过JSBridge调用本机TTS

从 https://github.com/PierceNg/fpwebview 下载 fpwebview-master.zip 简单易用。 先请看 \fpwebview-master\README.md cd \lazarus\projects\fpwebview-master\demo\js_bidir 学习 js_bidir.lpr &#xff0c;编写 js_bind_speak.lpr 如下&#xff0c;通过JSBridge调用本机…

在中国做 DePIN?你需要明白风险与机遇

撰文&#xff1a;肖飒团队 来源Techub News专栏作者 随着科技的发展&#xff0c;我们正在日益进入一个资源相对过剩的时代&#xff0c;这使我们在日常生活中虽然支付了该部分资源的使用费&#xff0c;但却时常不能将其「物尽其用」&#xff0c;难免出现资源浪费。例如&#x…

秒懂百科,C++如此简单丨第十九天:动态规划

目录 动态规划的初步理解 求最短路径数 洛谷 P1002 过河卒 题目描述 输入样例 输出样例 思路 AC Code 动态规划的初步理解 什么是动态规划&#xff1f;最直白的理解就是动态的规划。 那高级一点的理解呢&#xff1f;就是每时每刻都拿着一个小本本&#xff0c;也就是…

模型 人货场

系列文章 主要是 分享 思维模型&#xff0c;涉及各个领域&#xff0c;重在提升认知。连接消费者与商品的桥梁。 1 ”人货场“模型的应用 1.1 以抖音直播电商为背景的人货场应用-小杨哥的带货奇迹 小杨哥&#xff0c;一位知名的抖音主播&#xff0c;以其幽默风趣的直播风格和独…