无脑入门pytorch系列(一)—— nn.embedding

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

  • 官方定义
  • demo
  • 练习1——改变**embedding_dim**
  • 练习2——index越界
  • 练习3——sequence长度不一致
  • 练习4——改变输入

官方定义

nn.embedding就是一个简单的查找表,存储固定字典和大小的嵌入。

该模块通常用于存储词嵌入并使用索引检索它们。模块的输入是索引列表,输出是相应的词嵌入。

个人理解:

  • nn.embedding就是一个字典映射表,比如它的大小是128,0~127每个位置都存储着一个长度为3的数组,那么我们外部输入的值可以通过index (0~127)映射到每个对应的数组上,所以不管外部的值是如何都能在该nn.embedding中找到对应的数组。想想哈希表,就很好理解了。
  • 既然是映射表,那么外部的输入的值肯定不能超过最大长度,比如128,同时下限也是。

官方的文档如下,torch.nn.embedding:

image-20230802145811801

从官方的定义来看实在是非常复杂,下面看个例子:

demo

下面是一个官方文档给出的例子:

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出的结果:

image-20230802150024797

我们一步步理解代码:

  1. 首先,embedding = nn.Embedding(10, 3)即定义一个embedding模块,包含了一个长度为10的张量,每个张量的大小是3。举个例子,[-1.0556, -0.2404, -0.4578]就是一个tensor,那么如何取该tensor?使用下标index去取,注意,理解这点非常重要。
  2. 其次,input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])即输入一个我们需要embedding的变量,输入的每个值最终映射到张量空间中。
  3. 最后,我们发现输出e变成了[2, 4, 3]的张量,那么没有学习过的同学自然是一脸懵逼。我们需要,说说怎么看张量的维度,从最外层的**[]开始,计算里面的独立个体,发现是2;接着从第二维度的[]**开始数,发现是4;依次类推就可以得到张量的维度是[2, 4, 3]。

仍然十分迷茫,但是没关系,我们看看embedding的weight:

embedding.weight

输出:

image-20230802150606779

我们发现embedding.weight是个[10, 3]的向量,那么embedding.weight的值是怎么被我们input取到的呢?
比如index = 1,那么我们取[-1.0556, -0.2404, -0.4578]; index = 2, 取[ 1.3328, 2.5743, -0.7375]; index = 4, 取[-0.0584, -0.6458, 0.8236]。
这时候,聪明的小伙伴已经发现了,这不就刚好对应了e的输入为1/2/4的值吗?只是我们把输入1作为index去embedding.weight取对应的值去填充新的张量e。

所以说,我们待输入的张量[[1,2,4,5],[4,3,2,9]],在经过nn.embedding后,从[2, 4]维度变换为[2, 4, 3],其实就是[2, 4]中的每个值作为索引去nn.embedding中取对应的权重。

练习1——改变embedding_dim

embedding = nn.Embedding(10, 4) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出:

image-20230802152757460

很明显,当embedding是个[10, 4]的张量时,映射出的张量为[2, 4, 4]

练习2——index越界

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4,5],[4,3,2,10]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

报错:IndexError: index out of range in self

输出会报错,那是因为我们的embedding的维度是[10, 3],所以index的取值从0~9,那么我们取10肯定就出现问题了。如果出现对应的问题时,就可以大致猜到输入的值越界了。

练习3——sequence长度不一致

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[1,2,4],[4,3,2,9]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

报错:ValueError: expected sequence of length 3 at dim 1 (got 4)

将第一维[1, 2, 4, 5]减去5变成[1,2,4],出现ValueError: expected sequence of length 3 at dim 1 (got 4)的问题,所以需要每个维度的长度都一致。

练习4——改变输入

embedding = nn.Embedding(10, 3) # an Embedding module containing 10 tensors of size 3
input = torch.LongTensor([[[1,2],[2,3],[4,5],[5,7]],[[4,5],[3,4],[2,3],[8,9]]]) # a batch of 2 samples of 4 indices each
e = embedding(input)
print(e)

输出:

image-20230802153045211

当输入的的维度为[2,4,2]时,经过embedding得到[2,4,2,3]的张量,也是很好理解的。

喜欢的朋友可以点赞三连一下,谢谢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/57199.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

线扫激光算法原理

一:线扫激光算法原理 激光器发出的激光束经准直聚焦后垂直入射到物体表面上,表面的散射光由接收透镜成像于探测器的阵列上。光敏面于接收透镜的光轴垂直。如图: 当被测物体表面移动x,反应到光敏面上像点位移为x’。a为接收透镜到物体的距离(物距),b为接收后主面到成像…

SQL-每日一题【1174. 即时食物配送 II】

题目 配送表: Delivery 如果顾客期望的配送日期和下单日期相同,则该订单称为 「即时订单」,否则称为「计划订单」。 「首次订单」是顾客最早创建的订单。我们保证一个顾客只会有一个「首次订单」。 写一条 SQL 查询语句获取即时订单在所有用户的首次订…

无人驾驶实战-第一课(自动驾驶概述)

在七月算法上报了《无人驾驶实战》课程,老师讲的真好。好记性不如烂笔头,记录一下学习内容。 课程入口,感兴趣的也可以跟着学一下。 ————————————————————————————————————————— 无人驾驶汽车的定义…

HTTP——五、与HTTP协作的Web服务器

HTTP 一、用单台虚拟主机实现多个域名二、通信数据转发程序 :代理、网关、隧道1、代理2、网关3、隧道 三、保存资源的缓存1、缓存的有效期限2、客户端的缓存 一台 Web 服务器可搭建多个独立域名的 Web 网站,也可作为通信路径上的中转服务器提升传输效率。…

windows服务器iis PHP套件出现FastCGI等错误解决方法汇总

如果您的服务器安装了PHP套件,出现了无法打开的情况,请参照如下办法解决: 首先,需要设置IIS允许输出详细的错误信息到浏览器,才好具体分析 错误一: 处理程序“FastCGI”在其模块列表中有一个错误模块“Fast…

应用案例|基于3D视觉的高反光金属管件识别系统解决方案

Part.1 项目背景 在现代制造业中,高反光金属管件的生产以及质量的把控是一个重要的挑战。传统的2D视觉系统常常难以准确地检测和识别高反光金属管件,因为它们的表面特征不够明显,容易受到光照和阴影的干扰。为了应对这个问题,基于…

UE5 c++ 的文件操作(记录备忘)

函数库.h // Fill out your copyright notice in the Description page of Project Settings.#pragma once#include "CoreMinimal.h" #include "Kismet/BlueprintFunctionLibrary.h" #include "Microsoft/AllowMicrosoftPlatformTypes.h" #incl…

windows编译ncnn

官方代码https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-windows-x64-using-visual-studio-community-2017 编译工具 visual studio 2017 一、编译protobuf 1、下载protobuf protobuf-3.11.2:https://github.com/google/protobuf/archive/v3.11…

基于SpringBoot+Vue的在线考试系统设计与实现(源码+LW+部署文档等)

博主介绍: 大家好,我是一名在Java圈混迹十余年的程序员,精通Java编程语言,同时也熟练掌握微信小程序、Python和Android等技术,能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架…

HarmonyOS学习路之方舟开发框架—学习ArkTS语言(状态管理 三)

Link装饰器:父子双向同步 子组件中被Link装饰的变量与其父组件中对应的数据源建立双向数据绑定。 概述 Link装饰的变量与其父组件中的数据源共享相同的值。 装饰器使用规则说明 Link变量装饰器 说明 装饰器参数 无 同步类型 双向同步。 父组件中State, Stor…

Linux常用命令——dpkg-deb命令

在线Linux命令查询工具 dpkg-deb Debian Linux下的软件包管理工具 补充说明 dpkg-deb命令是Debian Linux下的软件包管理工具,它可以对软件包执行打包和解包操作以及提供软件包信息。 语法 dpkg-deb(选项)(参数)选项 -c:显示软件包中的文件列表&am…

springBoot项目导入外部jar包

一、将外部的jar包复制到指定文件夹 二、修改pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocati…

专家论道: 唐贤香云纱塑造中国非遗国际品牌

自“香云纱染整技艺”入选第二批国家级非物质文化遗产以来&#xff0c;被誉为纺织界“软黄金”的香云纱&#xff0c;重新焕发青春&#xff0c;频频登上时尚舞台&#xff0c;以不一样的面貌展示在世人面前&#xff0c;成为服装设计师、消费者青睐的材质。而随着北京卫视播出的《…

Win10查询硬盘序列号

添加wmic命令 winR cmd命令 wmic diskdrive get model, serialnumber

java -jar指定外部配置文件

场景 spingboot项目部署jar时,需要时常修改配置,为了方便,将配置文件放到jar包外 操作步骤 在jar包同级目录下创建config文件夹(位置没有强制要求,为了方便而已) 在jar包同级目录下创建start.bat文件,并编辑内容 echo off :: 命令窗口标题 title yudibei_performance_tes…

Spring-1-透彻理解Spring XML的必备知识

学习目标 能够说出Spring的体系结构 能够编写IOC-DI快速入门 思考:为什么学习Spring而不是直接学习SpringBoot 1 Spring介绍 思考:我们为什么要学习Spring框架&#xff1f; 工作上面 Java拥有世界上数量最多的程序员 最多的岗位需求与高额薪资 95%以上服务器端还是要用Jav…

【Apifox】Apifox设置全局Token:

文章目录 一、获取登录Token和设置全局变量&#xff1a;二、设置全局参数&#xff1a;三、效果&#xff1a; 一、获取登录Token和设置全局变量&#xff1a; 二、设置全局参数&#xff1a; 三、效果&#xff1a;

关于会议OA需求分析与开发功能设计

前言&#xff1a;现如今&#xff0c;企业在会议管理方面对OA系统的需求越来越高。因为会议是企业内部沟通和协作的重要环节&#xff0c;一个高效的会议管理系统可以帮助企业提升会议效率、降低成本&#xff0c;并且提高内部信息共享的效果。 目录 一&#xff0c;以下是OA系统在…

活动目录密码更改

定期更改密码是一种健康的习惯&#xff0c;因为它有助于阻止使用被盗凭据的网络攻击&#xff0c;安全专家建议管理员应确保用户使用有效的密码过期策略更改其密码。 管理员可以通过电子邮件通知用户在密码即将过期时更改其密码&#xff0c;但在许多组织中&#xff0c;用户只能…

leetcode 435. 无重叠区间

2023.8.3 本题和引爆气球 这题非常类似&#xff0c;利用同样的思路可以解决&#xff0c;代码如下&#xff1a; class Solution { public:static bool cmp(vector<int>& a , vector<int>& b){if(a[0] b[0]) return a[1] < b[1];return a[0] < b[0];…