Transformer - Positional Encoding 位置编码 代码实现

Transformer - Positional Encoding 位置编码 代码实现

flyfish

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

       
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x +  self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

# 词嵌⼊维度是64维
d_model = 64
# 置0⽐率为0.1
dropout = 0.1
# 句⼦最⼤⻓度
max_len=60

x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)
                           
pe_result = pe(x)

print("pe_result:", pe_result)

绘图

import numpy as np
import matplotlib.pyplot as plt
# 创建⼀张15 x 5⼤⼩的画布
plt.figure(figsize=(15, 5))

pe = PositionalEncoding(d_model, 0, max_len)

y = pe(torch.zeros(1, max_len, d_model))


# 只查看3,4,5,6维的值.
plt.plot(np.arange(max_len), y[0, :, 3:7].data.numpy())

plt.legend(["dim %d"%p for p in [3,4,5,6]])

在这里插入图片描述

register_buffer 的测试

# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MLPNet (nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 128)
        self.fc2 =nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout1=nn.Dropout2d(0.2)
        self.dropout2=nn.Dropout2d(0.2)
    
        self.tmp = torch.randn(size=(1, 3))
        pe = torch.randn(size=(1, 3))
       
        
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)

print(torch.__version__)

root="mydir/"

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)


criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

epochs = 1
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0

    net.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
        
        optimizer.zero_grad()
 
        out = net(images)
      
        loss = criterion(out, labels)
       
        train_loss += loss.item()
        train_acc += (out.max(1)[1] == labels).sum().item()
      
        loss.backward()
    
        optimizer.step()
    
        avg_train_loss = train_loss / len(train_loader.dataset)
        avg_train_acc = train_acc / len(train_loader.dataset)

    net.eval()
    with torch.no_grad():
        for (images, labels) in test_loader:
            images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)
            out = net(images)
            loss = criterion(out, labels)
            val_loss += loss.item()
            acc = (out.max(1)[1] == labels).sum()
            val_acc += acc.item()
    avg_val_loss = val_loss / len(test_loader.dataset)
    avg_val_acc = val_acc / len(test_loader.dataset)
    print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
                   .format(epoch+1, epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
    
    
    


dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model.pt")
torch.save(net.state_dict(), model_save_path)

model = MLPNet()
model.load_state_dict(torch.load(model_save_path))


print(model.tmp)
print(model.pe)
# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MLPNet (nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 128)
        self.fc2 =nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.dropout1=nn.Dropout2d(0.2)
        self.dropout2=nn.Dropout2d(0.2)
    
        self.tmp = torch.randn(size=(1, 3))
        pe = torch.randn(size=(1, 3))
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)


    

dir_name = 'output'
if not os.path.exists(dir_name):
    os.mkdir(dir_name)


model_save_path = os.path.join(dir_name, "model.pt")



model = MLPNet()
model.load_state_dict(torch.load(model_save_path))


print(model.tmp)
print(model.pe)

从模型加载的pe值,从未改变

tensor([[0.0566, 0.8944, 0.0873]])
tensor([[ 0.2529,  0.5227, -0.2610]])
tensor([[ 0.4632, -0.2602, -1.0032]])
tensor([[-0.3486,  0.8183, -1.3838]])
tensor([[ 0.7163,  0.5574, -0.0848]])
tensor([[-0.3415, -0.9013, -1.6136]])
tensor([[ 0.5490,  1.7691, -1.1375]])
tensor([[-0.3486,  0.8183, -1.3838]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/511480.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ArcGis研究区边界提取

ArcGis研究区边界提取 *0* 引言*1* 有的步骤0 引言 GRACE数据处理前要先确定研究范围,而大多情况下所选的研究区都是有特殊意义的,比如常年干旱、经济特区、降水丰富等,这些区域往往有精确的边界,那就要从大的区块中将研究范围抠出来,获取相应坐标,以量化区域重力变化。那…

基于MiniExcel的三种常用导出Excel方法(固定列导出、动态列导出、按模板导出)

为了方便代码编写和测试,把很多代码都放在一个class里面,实际开发根据需要放到对应的目录下即可。 1.使用nuget下载安装miniexcel; 2.编写对应的测试接口,具体代码如下: using Microsoft.AspNetCore.Authorization; using Micr…

etcdctl配置

下载etcdctl压缩文件 wget https://github.com/etcd-io/etcd/releases/download/v3.5.0/etcd-v3.5.0-linux-amd64.tar.gztar -xvf etcd-v3.5.0-linux-amd64.tar.gzsudo mv ./etcd-v3.5.0-linux-amd64/etcdctl /usr/local/bin/ chmod x /usr/local/bin/etcdctl 配置etcdctl 配…

路由器拨号失败解决方法

目录 一、遇到问题 二、测试 三、解决方法 (一)路由器先单插wan口设置 (二)mac地址替换 (三)更改路由器DNS 一、遇到问题 1 .在光猫使用桥接模式,由路由器进行拨号的时候,出现…

网络播放解码器 SIP-7102网络音频解码播放器

网络播放解码器 SIP-7102网络音频解码播放器 SIP-7102是一款支持SIP协议的网络音频解码播放器,具有10/100M以太网接口,其接收网络的音频数据,通过设备的DSP及放大电路,可以直接连接两路15W的扬声器。可以用于公共广播、报警系统、…

CSS 清除浮动

浮动副作用 当元素设置float浮动后,该元素就会脱离文档流并向左/向右浮动。浮动也会产生一些问题,所以在使用的时候还需要清除浮动,这样操作起来还是很麻烦的,在没有浮动的时候就不要乱使用,因为还需要清除掉&#xff…

windows系统kafka小白入门篇——下载安装,环境配置,入门代码书写

目录 1. kafka 下载 2. 修改配置文件 2.1 文件夹内容 2.2 创建一个 data 空文件夹 2.3 修改 zookeeper.properties 配置文件 2.4 修改 server.properties 配置文件 2.5 创建 "zk.cmd" windows脚本文件 2.6 创建 "kfk.cmd" windows脚本文件 3. 启动…

docker搭建CI/CD环境配置过程中的常见问题

一、Jenkins 1、pull镜像问题 docker pull jenkins/jenkins:lts Using default tag: latest Trying to pull repository docker.io/library/centos ... Get https://registry-1.docker.io/v2/library/centos/manifests/latest: Get https://auth.docker.io/token?scoperepo…

小剧场短剧影视小程序源码,附带系统搭建教程

安装教程 linux/win任选 PHP版本:7.3/7.2(测试时我用的7.2要安装sg扩展 不会的加QQ295526639) 批量替换域名http://video.owoii.com更换为你的 批量替换域名http://120.79.77.163:1更换为你的 这两个都替换你的 /extend/yzf/lib/epay.config.…

论文笔记 - :DIGGING INTO OUTPUT REPRESENTATION FOR MONOCULAR 3D OBJECT DETECTION

Title: 深入研究单目 3D 物体检测的输出表示 Abstract 单目 3D 对象检测旨在从单个图像中识别和定位 3D 空间中的对象。最近的研究取得了显着的进展,而所有这些研究都遵循基于 LiDAR 的 3D 检测中的典型输出表示。 然而,在本文中,我们认为…

分享OpenTiny总结VUE目录结构

OpenTiny总结VUE目录结构,在设计规范和基础组件的基础上, 继续向上构建,提炼出典型模板/业务组件/配套设计资源,进一步提升企业级中后台产品设计研发过程中的『用户』和『设计者』的体验。 Vue 项目目录结构 ├── node_modules…

js实现websocket断线重连功能

在项目开发中我们可能经常要使用websocket技术,当连接发生断线后,如果不进行页面刷新将不能正常接收来自服务端的推送消息。为了有效避免这种问题,我们需要在客户端做断线重连处理。当网络或服务出现问题后,客户端会不断检测网络状…

ubuntu-server部署hive-part3-安装mysql

参照 https://blog.csdn.net/qq_41946216/article/details/134345137 操作系统版本:ubuntu-server-22.04.3 虚拟机:virtualbox7.0 部署mysql 下载上传 下载地址 https://downloads.mysql.com/archives/community/ 以root用户上传,/usr/loc…

Transformer模型-softmax的简明介绍

今天介绍transformer模型的softmax softmax的定义和目的: softmax:常用于神经网络的输出层,以将原始的输出值转化为概率分布,从而使得每个类别的概率值在0到1之间,并且所有类别的概率之和为1。这使得Softmax函数特别适…

利用IP地址判断羊毛用户:IP数据云提供IP风险画像

在当今数字化社会,互联网已经成为人们日常生活和商业活动中不可或缺的一部分。然而,随着网络的普及,网络欺诈行为也日益猖獗,其中包括了羊毛党这一群体。羊毛党指的是利用各种手段获取利益、奖励或者优惠而频繁刷取优惠券、注册账…

微信小程序自定义弹窗组件

业务背景&#xff1a;弹窗有时字体较多&#xff0c;超过7个字&#xff0c;不适用wx.showToast. 组件代码 <view class"toast-box {{isShow? show:}}" animation"{{animationData}}"><view class"toast-content" ><view class&q…

【FTP,EMail】

文章目录 FTPFTP&#xff1a;文件传输协议FTP: 控制连接与数据连接分开FTP命令、响应 EMail电子邮件&#xff08;EMail&#xff09;邮件服务器EMail: SMTP [RFC 2821]SMTP&#xff1a;总结 FTP FTP&#xff1a;文件传输协议 向远程主机上传输文件或从远程主机接收文件。客户/服…

抖音视频关键词批量下载工具|视频爬虫采集软件

抖音视频批量提取工具&#xff0c;搜索即下载&#xff0c;轻松获取所需视频&#xff01; 正文&#xff1a; 想要轻松获取抖音上的精彩视频吗&#xff1f;现在&#xff0c;有了我们的抖音视频批量提取工具&#xff0c;一切变得简单易行&#xff01;Q:290615413无论是针对特定关…

美摄科技AI智能图像矫正解决方案

图像已经成为了企业传播信息、展示产品的重要媒介&#xff0c;在日常拍摄过程中&#xff0c;由于摄影技巧的限制和拍摄环境的复杂多变&#xff0c;许多企业面临着图像内容倾斜、构图效果不佳等挑战&#xff0c;这无疑给企业的形象展示和信息传递带来了不小的困扰。 美摄科技深…

55、美国德克萨斯大学奥斯汀分校、钱德拉家族电气与计算机工程系:通过迁移学习解决BCI个体差异性[不得不说,看技术还得是老美]

2024年2月5日跨被试最新文章&#xff1a; 德州州立大学奥斯汀分校研究团队最近的一项研究成果&#xff0c;通过非侵入式的脑机接口&#xff0c;可以让被试不需要任何校准就可以使用脑机接口设备&#xff0c;这意味着脑机接口具备了大规模被使用的潜力。 一般来说&#xff0c;…