详细分析Pytorch中的masked_fill基本知识(附Demo)

目录

  • 1. 基本知识
  • 2. Demo

1. 基本知识

基本的原理知识如下:

  1. 输入张量和掩码
    masked_fill 接受两个主要参数:一个输入张量和一个布尔掩码
    掩码的形状必须与输入张量相同,True 表示需要填充的位置,False 表示保持原值

  2. 掩码操作
    在执行 masked_fill 操作时,函数会检查掩码中每个元素的值
    如果掩码对应的位置为 True,则在输出张量中填充指定的值;
    如果为 False,则保留输入张量中对应位置的值

  3. 输出结果
    最终生成的新张量包含了在掩码位置上被替换的值,其余位置保持原样


在代码逻辑上

  1. 创建掩码
    mask 是一个布尔张量,标识了哪些位置需要填充:
[[False, True, False],
 [True, False, True],
 [False, False, True]]
  1. 执行 masked_fill
    当调用 tensor.masked_fill(mask, -1) 时,PyTorch 会遍历掩码中的每个元素:对于 mask 中的每个 True 值,tensor 在对应位置的值会被替换为 -1,对于 False 值,保持原值不变

masked_fill 操作是基于 C/C++ 的实现,因此在处理大规模数据时性能较高。常用于深度学习模型中的数据预处理,比如在填充序列、处理缺失值或标记特定条件的数据时

2. Demo

Demo 1: 基本用法

import torch

# 创建一个 3x3 的张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# 创建一个掩码,标记要填充的位置
mask = torch.tensor([[False, True, False],
                     [True, False, True],
                     [False, False, True]])

# 使用 masked_fill 填充掩码位置为 -1
result = tensor.masked_fill(mask, -1)

print("原始张量:")
print(tensor)
print("\n填充后的张量:")
print(result)

截图如下:

在这里插入图片描述

Demo 2: 与条件结合使用

import torch

# 创建一个随机张量
tensor = torch.randn(3, 3)

# 创建掩码:标记负值的位置
mask = tensor < 0

# 将负值位置填充为 0
result = tensor.masked_fill(mask, 0)

print("原始张量:")
print(tensor)
print("\n填充后的张量 (负值填充为 0):")
print(result)

截图如下:

在这里插入图片描述

Demo 3: 结合计算

import torch

# 创建一个张量
tensor = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])

# 创建掩码:标记大于 50 的位置
mask = tensor > 50

# 用 999 填充大于 50 的位置
result = tensor.masked_fill(mask, 999)

print("原始张量:")
print(tensor)
print("\n填充后的张量 (大于 50 的位置填充为 999):")
print(result)

截图如下:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/901290.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

TikTok运营对IP有什么要求?

TikTok在进行直播带货时&#xff0c;网络环境的配置尤为关键&#xff0c;网络质量直接影响到直播效果&#xff0c;因此选择稳定的IP地址很重要。那么&#xff0c;TikTok直播时该选择什么样的IP地址呢&#xff1f;接下来&#xff0c;我们来深入分析一下。 TikTok对IP地址的要求 …

HBuilder X 中Vue.js基础使用->计算属性的应用(三)

一、通过简单的计算属性&#xff1a;对两数进行加法&#xff0c;减法&#xff0c;乘法&#xff0c;除法运算 <template><div><h1>computed 计算属性</h1><el-input type"text" v-model"numOne" /> <el-input type"t…

容器化核心快速入门

概述 物理机&#xff1a;好比是独立的大船&#xff0c;独立发动机&#xff0c;独立船舱。所有资源共用。运水果的同时就不能运鱼&#xff08; 1964年&#xff09;虚拟机&#xff1a;相当于把大船进行改造&#xff0c;把大船的资源进行独立的拆分&#xff0c;独立的部分都有单独…

【Linux学习工具篇】之vim编辑器和gcc编译器

&#x1f4c3;博客主页&#xff1a; 小镇敲码人 &#x1f49a;代码仓库&#xff0c;欢迎访问 &#x1f680; 欢迎关注&#xff1a;&#x1f44d;点赞 &#x1f442;&#x1f3fd;留言 &#x1f60d;收藏 &#x1f30f; 任尔江湖满血骨&#xff0c;我自踏雪寻梅香。 万千浮云遮碧…

Jmeter使用js对入参使用MD5加密

新增前置处理器JSR223 PreProcessor 注意: 加密的js文件需要放到jmtere的bin目录下,不需要使用给包围,如下图即可(这里不是真实的加密方法,需要自己引入加密算法) 脚本中不要使用let需要使用var 可以先尝试最简单的脚本在使用复杂的脚本 load方法用来加载js文件,不同的jmet…

qt 滚动条 美化

qt QScrollBar 滚动条分为竖直与水平滚动条&#xff0c;两者设置上类似&#xff0c;但也有一些不同&#xff0c;下面主要讲述美化及注意事项。 一、竖直滚动条 竖直滚动条分为7个部分&#xff1a; sub-line、 up-arrow 、sub-page、 hanle、 add-line、 dow-arrow、 add-pag…

SpringBoot最佳实践之 - 项目中统一记录正常和异常日志

1. 前言 此篇博客是本人在实际项目开发工作中的一些总结和感悟。是在特定需求背景下&#xff0c;针对项目中统一记录日志(包括正常和错误日志)需求的实现方式之一&#xff0c;并不是普适的记录日志的解决方案。所以阅读本篇博客的朋友&#xff0c;可以参考此篇博客中记录日志的…

使用JUC包的AtomicXxxFieldUpdater实现更新的原子性

写在前面 本文一起来看下使用JUC包的AtomicXxxxFieldUpdater实现更新的原子性。代码位置如下&#xff1a; 当前有针对int&#xff0c;long&#xff0c;ref三种类型的支持。如果你需要其他类型的支持的话&#xff0c;也可以照葫芦画瓢。 1&#xff1a;例子 1.1&#xff1a;普…

构建中小企业设备管理平台:Spring Boot应用

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

Android 开发 调节声音 SeekBar自定义样式

效果图 xml布局 mipmap/seekbar图片随意一张图都可以&#xff0c;这里我的图就不贴出来了 <SeekBarandroid:id"id/seekBar"android:layout_marginLeft"8dp"android:layout_width"377dp"android:layout_height"8dp"android:layou…

沈阳乐晟睿浩科技有限公司抖音小店领域的强者

在当今数字化浪潮的推动下&#xff0c;电子商务以其便捷性、高效性和广泛的覆盖面&#xff0c;成为了推动经济发展的新引擎。而抖音小店&#xff0c;作为短视频平台上的新兴电商形态&#xff0c;更是凭借其庞大的用户基础、精准的内容推送机制以及独特的购物体验&#xff0c;迅…

方形件组批优化问题

在中国制造 2025 目标背景之下&#xff0c;发展环境保护型、资源节约型的智能制造业已成为制造行业的当务之急。为了应对客户提出的各式各样的产品需求、订单组批难且产品质量 要求高的问题&#xff0c;使用数学模型辅助企业对定制化产品进行组批优化具有重要意义。本文通 过…

2024.7最新子比主题zibll7.9.2开心版源码+授权教程

授权教程&#xff1a; 1.进入宝塔搭建一个站点 绑定 api.zibll.com 域名 并上传 index.php 文件 2.设置伪静态 3.开启SSL证书&#xff0c;找一个能用的域名证书&#xff0c;将密钥(KEY)和证书(PEM格式)复制进去即可 4.在宝塔文件地址栏中输入 /etc 找到 hosts文件并打开&a…

hcia复习篇

计算机网络&#xff1a; 云技术&#xff1a; 云储存---将数据通过计算机网络传输并储存在第三方服务器。&#xff08;百度网盘&#xff09; 云计算---分布式计算。&#xff08;即共享硬件资源&#xff09; 计算机技术&#xff1a; 文字、图片、视频等---抽象文字。 抽象语言…

【Nginx系列】499错误

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Java应用程序的测试覆盖率之设计与实现(一)-- 总体设计

一、背景 作为测试,如何保证开发人员提交上来的代码都被测试覆盖到,是衡量测试质量的一个重要指标。 本系列文章将要说一说,如何搭建一套测试覆盖率的系统。 包括以下内容: jacoco agent采集执行覆盖率数据jacoco climaven集成jacoco:jacoco-maven-pluginant集成jacoco:…

Linux: network: wireshark IO图的一个问题

今天遇到一个问题&#xff0c;发现wireshark画的IO图&#xff0c;前几秒没有数据&#xff0c;但是根据Raw的pcap看&#xff0c;是有包的&#xff0c;这就迷惑了。 经同事提醒&#xff0c;这个IO在设置了多个画图filter的时候&#xff0c;可能导致开始前几秒没有输出。如下图 这…

HexForge:一款用于扩展安全汇编和十六进制视图的IDA插件

关于HexForge HexForge是一款用于扩展安全汇编和十六进制视图的IDA插件&#xff0c;在该工具的帮助下&#xff0c;广大研究人员可以方便地直接从 IDA Pro 界面数据解码、解密或执行安全数据审计任务。 功能介绍 1、从 IDA 的反汇编或十六进制视图复制原始十六进制&#xff1b;…

安康旅游指南:基于SpringBoot的网站开发实践

第一章 绪论 1.1 研究现状 时代的发展&#xff0c;我们迎来了数字化信息时代&#xff0c;它正在渐渐的改变着人们的工作、学习以及娱乐方式。计算机网络&#xff0c;Internet扮演着越来越重要的角色&#xff0c;人们已经离不开网络了&#xff0c;大量的图片、文字、视频冲击着我…

spring整合使用xml方式整合Druid数据源连接池

1.普通的JDBC数据库连接使用 DriverManager 来获取&#xff0c;每次向数据库建立连接的时候都要将 Connection加载到内存中&#xff0c;再验证用户名和密码(得花费0.05s&#xff5e;1s的时间)。需要数据库连接的时候&#xff0c;就向数据库要求 一个&#xff0c;执行完成后再断…