原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4

在这里插入图片描述

文章目录

  • 原型网络进行分类的基本流程
  • 一、原始代码---计算欧氏距离,设计原型网络(计算原型+开始训练)
  • 二、每一行代码的详细解释
  • 总结


原型网络进行分类的基本流程

利用原型网络进行分类,基本流程如下:

1.对于每一个样本使用编码的方式fφ (),学习到每一个样本的编码表示(信息抽取)。
2.学习到每一个样本的编码表示之后,对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。
3.当一个新的数据样本被输入到网络中的时候,对于这个样本使用fφ(),生成其编码表示。
4.计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。
5.在计算出所有的分类之间的距离之后,使用softmax的方式将距离转换成概率的形式。

一、原始代码—计算欧氏距离,设计原型网络(计算原型+开始训练)

def eucli_tensor(x,y):	#计算两个tensor的欧氏距离,用于loss的计算
	return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)

class Protonets(object):
	def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):
		#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型
		self.input_shape = input_shape
		self.outDim = outDim
		self.batchSize = 1
		self.Ns = Ns
		self.Nq = Nq
		self.Nc = Nc
		if trainval == False:
			#若训练一个新的模型,初始化CNN和中心点
			self.center = {}
			self.model = CNNnet(input_shape,outDim)
		else:
			#否则加载CNN模型和中心点
			self.center = {}
			self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#'''修改,存储模型的文件名'''
			self.load_center(log_data+'model_center_'+str(step)+'.csv')	#'''修改,存储中心的文件名'''
	
	def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点
		center = 0
		for i in range(self.Ns):
			data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
			data = Variable(torch.from_numpy(data))
			data = self.model(data)[0]	#将查询点嵌入另一个空间
			if i == 0:
				center = data
			else:
				center += data
		center /= self.Ns
		return center
	
	def train(self,labels_data,class_number):	#网络的训练
		#Select class indices for episode
		class_index = list(range(class_number))
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			self.center[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(self.center[label])	#list
			sample['xq'].append(query_set)
		#优化器
		optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)
		optimizer.zero_grad()
		protonets_loss = self.loss(sample)
		protonets_loss.backward()
		optimizer.step()

二、每一行代码的详细解释

def eucli_tensor(x, y):
    return -1 * torch.sqrt(torch.sum((x - y) * (x - y))).view(1)

这是一个函数,用于计算两个张量(tensor)之间的欧氏距离(Euclidean Distance)。它通过计算两个张量差的平方和的平方根,并乘以-1。最后通过 view(1) 将结果转换成一个形状为 (1,) 的张量。

class Protonets(object):
    def __init__(self, input_shape, outDim, Ns, Nq, Nc, log_data, step, trainval=False):
        self.input_shape = input_shape
        self.outDim = outDim
        self.batchSize = 1
        self.Ns = Ns
        self.Nq = Nq
        self.Nc = Nc
        if trainval == False:
            self.center = {}
            self.model = CNNnet(input_shape, outDim)
        else:
            self.center = {}
            self.model = torch.load(log_data + 'model_net_' + str(step) + '.pkl')
            self.load_center(log_data + 'model_center_' + str(step) + '.csv')

这是一个 Protonets 类的定义,它有一个构造函数 __init__,用于初始化类的属性。其中的参数含义如下:

  • input_shape:输入数据的形状。
  • outDim:输出维度。
  • Ns:支持集(support set)的数量。
  • Nq:查询集(query set)的数量。
  • Nc:每次迭代所选类别数。
  • log_data:模型和中心的存储位置。
  • step:训练的步数。
  • trainval:是否重新开始训练模型。

根据 trainval 的取值,分为两种情况进行初始化:

  1. trainval=False:表示训练一个新的模型。此时,初始化一个空的中心字典 self.center,并创建一个名为 CNNnet 的模型对象 self.model,其输入形状为 input_shape,输出维度为 outDim
  2. trainval=True:表示加载已经训练好的模型和中心。同样,初始化一个空的中心字典 self.center。然后通过 torch.load 加载之前训练保存的模型文件 log_data + 'model_net_' + str(step) + '.pkl',并将其赋给 self.model。接着调用 load_center 方法加载之前训练保存的中心文件 log_data + 'model_center_' + str(step) + '.csv'

总结

这段代码是一个用于实现 Protonets 算法的类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/161358.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

信号完整性分析基础知识之有损传输线、上升时间衰减和材料特性(十):有损传输线在时域中的表现

如果高频衰减大于低频衰减,随着信号传输,上升时间将会增加。上升时间通常定义为边沿在最终值的 10% 到 90% 之间过渡的时间。这假设信号的边缘轮廓看起来有点高斯分布,中间是最快的斜率区域。对于该波形,10%−90% 的上升时间是有意…

MIB 6.1810实验Xv6 and Unix utilities(4)primes

难度: hard/moderate Write a concurrent prime sieve program for xv6 using pipes and the design illustrated in the picture halfway down this page and the surrounding text. This idea is due to Doug McIlroy, inventor of Unix pipes. Your solution should be in …

让你的Mac体验更便捷,快速启动工具Application Wizard为你助力!

亲爱的Mac用户们,你是否经常感到在繁琐的软件启动过程中浪费了太多时间?你是否希望能够以更快的速度找到并启动你所需的应用程序?如果是的话,那么不要犹豫,让我们来介绍一款强大的软件快速启动工具——Application Wiz…

23年宁波职教中心CTF竞赛-决赛

Web 拳拳组合 进去页面之后查看源码,发现一段注释,写着小明喜欢10的幂次方,那就是10、100、1000、10000 返回页面,在点击红色叉叉的时候抓包,修改count的值为10、100、1000、10000 然后分别获得以下信息 ?count1…

Spring面试题:(八)Spring事务

Spring事务概述 Spring事务基于数据库,基于数据库的事务封装了统一的接口。 编程式事务和声明式事务。 声明式事务分为Xml声明式或者注解声明式 实现事务相关的三个类 事务管理器 事务定义 事务状态 XML声明式事务的使用方法 导入坐标配置目标类配置切面 导入…

JS判断是否存在某个元素(includes、indexOf、find、findeIndex、some)(every 数组内所有值是否相同)

方法一:array.includes(searcElement[,fromIndex]) 此方法判断数组中是否存在某个值,如果存在返回true,否则返回false。 searchElement:需要查找的元素,必选。fromIndex:可选,从该索引处开始查…

浏览器页面被恶意控制时的解决方法

解决360流氓软件控制浏览器页面 提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、接受360安全卫士的好意(尽量不要选)二、拒绝360安全卫士的好意(强烈推荐)第…

【Vue渲染】 条件渲染 | v-if | v-show | 列表渲染 | v-for

目录 前言 v-if和v-show的区别和联系 v-show和v-if如何选择 条件渲染|v-if|v-show v-if v-if v-else v-if v-else-if v-else template v-show 列表渲染|v-for v-for 前言 本文介绍Vue渲染,包含条件渲染v-if和v-show的区别和联系以及列表渲染v-for v-if和…

“腾易视连”构建汽车生态新格局 星选计划赋能创作者价值提升

11月16日,在2023年广州国际车展前夕,以“腾易视连,入局视频号抓住增长新机会”为主题的腾易创作者大会在广州隆重举办。此次大会,邀请行业嘉宾、媒体伙伴、生态伙伴、视频号汽车领域原生达人等共济一堂,结合汽车行业数…

轻量级的资源授权:基于 OAuth 规范

了解 OAuth 感觉 OAuth 太负盛名了,以至于后来在 OIDC 反而难以企及前辈 OAuth。倒是大家谈论比较多的是 JWT(例如https://www.cnblogs.com/lyzg/p/6132801.html),——实际谈 JWT 就是在实现 OIDC,反而 OIDC 大家不怎…

Android跨进程通信,IPC,RPC,Binder系统,C语言应用层调用

文章目录 Android跨进程通信,IPC,RPC,Binder系统,C语言应用层调用()1.概念2.流程3.bctest.c3.1 注册服务,打开binder驱动3.2 获取服务 4.binder_call Android跨进程通信,IPC&#xf…

组件插槽,生命周期,轮播图组件的封装,自定义指令的封装等详解以及axios的卖座案例

3.组件插槽 3-1组件插槽 注意 插槽内容可以访问到父组件的数据作用域,因为插槽内容本身就是在父组件模版中定义的 插槽内容无法访问子组件的数据.vue模版中的表达式只能访问其定义时所处的作用域,这和JavaScript的词法作用域是一致的,换言之: 父组件模版的表达式只能访问父组…

计算机毕业设计选题推荐-掌心办公微信小程序/安卓APP-项目实战

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

模块一、任务一.数据分析概述

一、module1 预测未来-总统大选 样本偏差 二、module2 优化现状-化妆品销售 1、数据分析师从业务类型上划分 2、目标:总销量 达到 目标销量 3、固定基本流程 (1)确定 一、目标值节节升高,是否合理?根据什么定的&…

zabbix-proxy分布式监控

Zabbix是一款开源的企业级网络监控软件,可以监测服务器、网络设备、应用程序等各种资源的状态和性能指标。在大型环境中,如果只有一个Zabbix Server来监控所有的节点,可能会遇到性能瓶颈和数据处理难题。 为了解决这个问题,Zabbi…

爱拖延怎么办?如何改变拖延症?

拖延症是我们日常生活中多见的问题,也是不怎么受重视的问题,大多数人都会认为拖延不是什么大问题,办事拖拉怎么也不可能和心理疾病扯上关系。这里小猫测试网分不同情况来讨论。 偶尔的拖延没什么关系,建议忘掉这种偶然性拖延&…

[C国演义] 第二十一章

第二十一章 最长公共子序列不相交的线 最长公共子序列 力扣链接 单个数组的子序列问题 – dp[i] -- 以nums[i] 为结尾的所有子序列中, xxx xxx. 然后状态转移方程根据 最后一个位置的归属问题进行讨论 两个数组的子序列问题 – 以小见大, 分别分析nums1中的一个区间 和 nums…

山西电力市场日前价格预测【2023-11-19】

1.日前价格预测 预测说明: 如上图所示,预测明日(2023-11-19)山西电力市场全天平均日前电价为591.63元/MWh。其中,最高日前电价为1500.00元/MWh,预计出现在16:45~20:45。最低日前电价为268.57元/MWh&#x…

JAVAEE 初阶 多线程基础(一)

多线程基础 一.线程的概念二.为什么要有线程三.进程和线程的区别和关系四.JAVA的线程和操作系统线程的关系五.第一个多线程程序1.继承Thread类 一.线程的概念 一个线程就是一个 “执行流”. 每个线程之间都可以按照顺讯执行自己的代码. 多个线程之间 “同时” 执行着多份代码 同…