原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!由于工作量大,准备整8个系列完事,-----系列5

在这里插入图片描述

文章目录

  • 前言
  • 一、原始程序---计算原型,开始训练,计算损失
  • 二、每一行代码的详细解释
    • 2.1 粗略分析
    • 2.2 每一行代码详细分析


前言

承接系列4,此部分属于原型类中的计算原型,开始训练,计算损失函数。


一、原始程序—计算原型,开始训练,计算损失

def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
		center = 0
		for i in range(self.Ns):
			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
			data = Variable(torch.from_numpy(data))
			data = self.model(data)[0]	#将查询点嵌入另一个空间
			if i == 0:
				center = data
			else:
				center += data
		center /= self.Ns
		return center
	
	def train(self,labels_data,class_number):	#网络的训练
		#Select class indices for episode
		class_index = list(range(class_number))
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			self.center[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(self.center[label])	#list
			sample['xq'].append(query_set)
		#优化器
		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
		optimizer.zero_grad()
		protonets_loss = self.loss(sample)
		protonets_loss.backward()
		optimizer.step()
	
	def loss(self,sample):	#自定义loss
		loss_1 = autograd.Variable(torch.FloatTensor([0]))
		for i in range(self.Nc):
			query_dataSet = sample['xq'][i]
			for n in range(self.Nq):
				data = np.reshape(query_dataSet[n], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
				data = Variable(torch.from_numpy(data))
				data = self.model(data)[0]	#将查询点嵌入另一个空间
				#查询点与每个中心点逐个计算欧氏距离
				predict = 0
				for j in range(self.Nc):
					center_j = sample['xc'][j]
					if j == 0:
						predict = eucli_tensor(data,center_j)
					else:
						predict = torch.cat((predict, eucli_tensor(data,center_j)), 0)
				#为loss叠加
				loss_1 += -1*F.log_softmax(predict,dim=0)[i]
		loss_1 /= self.Nq*self.Nc
		return loss_1

二、每一行代码的详细解释

2.1 粗略分析

第一个函数 compute_center(self,data_set) 用于计算支持集中心点的坐标。输入参数 data_set 是一个 numpy 对象,代表支持集。该函数中用了一个 for 循环遍历了每一个支持集中的样本,将其嵌入到另一个空间,并计算其总和来求得所有样本的中心点。最后返回计算出的中心点的坐标。

第二个函数 train(self,labels_data,class_number) 是网络的训练函数。其中 labels_data 是标签数据,class_number 是类别数。首先从 class_number 中随机选取出 Nc 个类,对于每个选出来的类,从其标签数据 D_set 中随机选取出支持集和查询集,并将支持集传入 compute_center() 函数计算中心点。接着将计算出的中心点和查询集存储在样本字典 sample 中。最后使用 Adam 优化器对模型进行优化,并计算损失(调用了 loss 函数),将反向传播得到的梯度更新到模型中。

第三个函数def loss(self,sample)是一个自定义的损失函数,它的作用是计算样本的损失值。在这个损失函数中,使用了欧氏距离和softmax函数。

2.2 每一行代码详细分析

def compute_center(self,data_set): - 这是一个方法,用于计算给定数据集(支持集)的中心点。

2-4. center = 0 - 初始化中心点的变量为0。

5-8. for i in range(self.Ns): - 遍历数据集中的每个数据点。

9-14. 这部分代码将数据集中的每个数据点重塑为适应模型输入的形状,并将其转换为PyTorch的Variable。然后,使用模型将查询点嵌入另一个空间。

if i == 0: - 如果这是第一个数据点,则将查询点设置为中心点。

16-19. 否则,将查询点添加到中心点。

center /= self.Ns - 计算中心点,这是所有数据点的平均值。

return center - 返回计算得到的中心点。

接下来是 train 方法:

23-24. 从给定的标签数据中选择类别索引并随机洗牌。选择特定数量的类别(self.Nc)。

25-30. 对于所选类别中的每一个,从其数据中随机选择支持集和查询集。

31-33. 使用 compute_center 方法计算每个类的中心点,并将其存储在列表中。同时将查询集也存储在列表中。

34-37. 初始化优化器,这里使用Adam优化算法,学习率设置为0.001。然后清空梯度缓存。

38-42. 计算损失函数值,该损失函数是根据自定义的损失函数计算的。然后进行反向传播以计算梯度。

optimizer.step() - 使用优化器更新模型的参数。

最后是自定义的损失函数 loss

45-46. 初始化一个张量 loss_1 为0,它用于累计损失值。

47-52. 对于每个类别(self.Nc),遍历查询集中的每个数据点。对于每个查询点,将其嵌入到另一个空间中,并计算它与每个中心点之间的欧氏距离。

53-57. 将所有的距离组合在一起,并使用softmax函数将其转换为概率值。然后,对于每个查询点,累加其与所有中心点的负对数似然损失值。

loss_1 /= self.Nq*self.Nc - 将损失值除以查询集中的数据点数量和类别数量以获得平均损失值。

return loss_1 - 返回计算得到的损失值。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/161471.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis持久化机制详解

使用缓存的时候,我们经常需要对内存中的数据进行持久化也就是将内存中的数据写入到硬盘中。大部分原因是为了之后重用数据(比如重启机器、机器故障之后恢复数据),或者是为了做数据同步(比如 Redis 集群的主从节点通过 …

链式队列的基本操作与实现(数据结构与算法)

链队列的表示与实现如下图&#xff1a; 代码如下&#xff1a; #include<iostream> using namespace std;#define MAXQSIZE 100 //最大队列长度 typedef int QElemType; //typedef struct Qnode {QElemType data;struct Qnode* next; }QNode, *QueuePtr; //队列结点类型…

python基础练习题库实验2

题目1 编写一个程序&#xff0c;要求用户输入产品代码、产品名称、产品尺寸和产品价格。 然后使用字符串格式来显示产品信息&#xff0c;就像下面的示例一样。 请注意&#xff0c;价格必须使用两位十进制数字显示。 代码 product_code input("Enter product code: &q…

10-19 HttpServletResponse

相应的对象 web开发模型&#xff1a;基于请求与相应的模型 一问一答的模型 Response对象:响应对象,封装服务器给客户端的相关的信息 顶级接口: ServletResponse 父接口:HttpServletResponse response对象的功能分为以下四种:(都是服务器干的事注意) 设置响应头信息; 发送状态码…

[内存泄漏][PyTorch](create_graph=True)

PyTorch保存计算图导致内存泄漏 1. 内存泄漏定义2. 问题发现背景3. pytorch中关于这个问题的讨论 1. 内存泄漏定义 内存泄漏&#xff08;Memory Leak&#xff09;是指程序中已动态分配的堆内存由于某种原因程序未释放或无法释放&#xff0c;造成系统内存的浪费&#xff0c;导致…

Vite Vue3+Element Plus框架布局

App根组件&#xff1a;框架布局 <template><el-container class"layout-container-demo" style"height: 98vh"><!-- 菜单栏 --><el-aside width"200px"><el-scrollbar><!-- router:是否启用 vue-router 模式。…

4、FFmpeg命令行操作8

生成测试文件 找三个不同的视频每个视频截取10秒内容 ffmpeg -i 沙海02.mp4 -ss 00:05:00 -t 10 -codec copy 1.mp4 ffmpeg -i 复仇者联盟3.mp4 -ss 00:05:00 -t 10 -codec copy 2.mp4 ffmpeg -i 红海行动.mp4 -ss 00:05:00 -t 10 -codec copy 3.mp4 如果音视…

IDEA创建文件添加作者及时间信息

前言 当使用IDEA进行软件开发时&#xff0c;经常需要在代码文件中添加作者和时间信息&#xff0c;以便更好地维护和管理代码。 但是如果每次都手动编辑 以及修改那就有点浪费时间了。 实践 其实我们可以将注释日期 作者 配置到 模板中 同时配置上动态获取内容 例如时间 这样…

记录一些涉及到界的题

文章目录 coppersmith的一些相关知识题1 [N1CTF 2023] e2Wrmup题2 [ACTF 2023] midRSA题3 [qsnctf 2023]浅记一下 coppersmith的一些相关知识 上界 X c e i l ( 1 2 ∗ N β 2 d − ϵ ) X ceil(\frac{1}{2} * N^{\frac{\beta^2}{d} - \epsilon}) Xceil(21​∗Ndβ2​−ϵ) …

【机器学习Python实战】线性回归

&#x1f680;个人主页&#xff1a;为梦而生~ 关注我一起学习吧&#xff01; &#x1f4a1;专栏&#xff1a;机器学习python实战 欢迎订阅&#xff01;后面的内容会越来越有意思~ ⭐内容说明&#xff1a;本专栏主要针对机器学习专栏的基础内容进行python的实现&#xff0c;部分…

ThinkPHP 系列漏洞

目录 2、thinkphp5 sql注入2 3、thinkphp5 sql注入3 4、 thinkphp5 SQL注入4 5、 thinkphp5 sql注入5 6、 thinkphp5 sql注入6 7、thinkphp5 文件包含漏洞 8、ThinkPHP5 RCE 1 9、ThinkPHP5 RCE 2 10、ThinkPHP5 rce3 11、ThinkPHP 5.0.X 反序列化漏洞 12、ThinkPHP…

字符串函数详解

一.字母大小写转换函数. 1.1.tolower 结合cppreference.com 有以下结论&#xff1a; 1.头文件为#include <ctype.h> 2.使用规则为 #include <stdio.h> #include <ctype.h> int main() {char ch A;printf("%c\n",tolower(ch));//大写转换为小…

vscode编写verilog的插件【对齐、自动生成testbench文件】

vscode编写verilog的插件&#xff1a; 插件名称&#xff1a;verilog_testbench,用于自动生成激励文件 安装教程&#xff1a;基于VS Code的Testbench文件自动生成方法——基于VS Code的Verilog编写环境搭建SP_哔哩哔哩_bilibili 优化的方法&#xff1a;https://blog.csdn.net…

数据结构与算法-哈夫曼树与图

&#x1f31e; “永远积极向上&#xff0c;永远豪情满怀&#xff0c;永远热泪盈眶&#xff01;” 哈夫曼树与图 &#x1f388;1.哈夫曼树&#x1f52d;1.1树与二叉树的转换&#x1f52d;1.2森林与二叉树的转换&#x1f52d;1.3哈夫曼树&#x1f50e;1.3.1哈夫曼树的概念&#x…

Web之CSS笔记

Web之HTML、CSS、JS 二、CSS&#xff08;Cascading Style Sheets层叠样式表&#xff09;CSS与HTML的结合方式CSS选择器CSS基本属性CSS伪类DIVCSS轮廓CSS边框盒子模型CSS定位 Web之HTML笔记 二、CSS&#xff08;Cascading Style Sheets层叠样式表&#xff09; Css是种格式化网…

传输层——TCP协议

文章目录 一.TCP协议二.TCP协议格式1.序号与确认序号2.窗口大小3.六个标志位 三.确认应答机制&#xff08;ACK&#xff09;四.超时重传机制五.连接管理机制1.三次握手2.四次挥手 六.流量控制七.滑动窗口八.拥塞控制九.延迟应答十.捎带应答十一.面向字节流十二.粘包问题十三.TCP…

【有源码】基于asp.net的旅游度假村管理系统C#度假村美食住宿一体化平台源码调试 开题 lw ppt

&#x1f495;&#x1f495;作者&#xff1a;计算机源码社 &#x1f495;&#x1f495;个人简介&#xff1a;本人七年开发经验&#xff0c;擅长Java、Python、PHP、.NET、微信小程序、爬虫、大数据等&#xff0c;大家有这一块的问题可以一起交流&#xff01; &#x1f495;&…

OpenAI 解雇了首席执行官 Sam Altman

Sam Altman 已被 OpenAI 解雇&#xff0c;原因是担心他与董事会的沟通和透明度&#xff0c;可能会影响公司的发展。该公司首席技术官 Mira Murati 将担任临时首席执行官&#xff0c;但 OpenAI 可能会从科技行业寻找新的首席执行官来领导未来的产品开发。Altman 的解雇给 OpenAI…

YOLOv8优化与量化(1000+ FPS性能)

YOLO家族又添新成员了&#xff01;作为目标检测领域著名的模型家族&#xff0c;you only look once (YOLO) 推 出新模型的速度可谓是越来越快。就在刚刚过去的1月份&#xff0c;YOLO又推出了最新的YOLOv8模型&#xff0c;其模型结构和架构上的创新以及所提供的性能提升&#xf…