原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列8

在这里插入图片描述

文章目录

  • 前言
  • 一、原始代码
  • 二、对每一行代码的解释:
  • 总结


前言

这是该系列原型网络的最后一段代码及其详细解释,感谢各位的阅读!


一、原始代码

if __name__ == '__main__':
    ##载入数据
    labels_trainData, labels_testData = load_data()  # labels_trainData是字典,是key:value形式
    class_number_train = max(list(labels_trainData.keys())) #963
    class_number_test = max(list(labels_testData.keys())) #658

    wide = labels_trainData[0][0].shape[0]  # 105      #二维张量,shape[0]代表行数,shape[1]代表列数
    length = labels_trainData[0][0].shape[1]  # 105

    for label in labels_trainData.keys():
        labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])

    for label in labels_testData.keys():
        labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])

    ##初始化模型
    protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)  # '''根据需求修改类的初始化参数,参数含义见protonets_net.py'''

    ##训练prototypical_network
    for n in range(100):  ##随机选取x个类进行一个episode的训练
        protonets.train(labels_trainData, class_number_train)
        if n % 2 == 0 and n != 0:  # 每5次存储一次模型,并测试模型的准确率,训练集的准确率和测试集的准确率被存储在model_step_eval.txt中
            torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
            protonets.save_center('./log/model_center_' + str(n) + '.csv')
            test_accury = protonets.evaluation_model(labels_testData, class_number_test)
            print(test_accury)
            str_data = str(n) + ',' + str('       test_accury     ') + str(test_accury) + '\n'
            with open('./log/model_step_eval.txt', "a") as f:
                f.write(str_data)
        print(n)

二、对每一行代码的解释:

  1. if __name__ == '__main__':
    这是一个Python的惯用写法,表示当脚本直接被运行时(而不是被作为模块导入时),才会执行下面的代码块。

  2. labels_trainData, labels_testData = load_data()
    调用 load_data() 函数加载数据,并将返回的标签训练数据和标签测试数据保存到 labels_trainDatalabels_testData 变量中。

  3. class_number_train = max(list(labels_trainData.keys()))
    获取标签训练数据中的最大键(即最大类别数),并将其保存到 class_number_train 变量中。

  4. class_number_test = max(list(labels_testData.keys()))
    获取标签测试数据中的最大键(即最大类别数),并将其保存到 class_number_test 变量中。

  5. wide = labels_trainData[0][0].shape[0]
    获取标签训练数据中第一个样本的宽度,并将其保存到 wide 变量中。

  6. length = labels_trainData[0][0].shape[1]
    获取标签训练数据中第一个样本的长度,并将其保存到 length 变量中。

  7. for label in labels_trainData.keys():
    遍历标签训练数据中的所有键。

  8. labels_trainData[label] = np.reshape(labels_trainData[label], [-1, 1, wide, length])
    对每个标签训练数据进行重塑,将其形状改为 [-1, 1, wide, length],其中 -1 表示自动计算维度大小。

  9. for label in labels_testData.keys():
    遍历标签测试数据中的所有键。

  10. labels_testData[label] = np.reshape(labels_testData[label], [-1, 1, wide, length])
    对每个标签测试数据进行重塑,将其形状改为 [-1, 1, wide, length]

  11. protonets = Protonets((1, wide, length), 10, 5, 5, 60, './log/', 50)
    创建一个 Protonets 类的实例,传入模型的初始化参数。

  12. for n in range(100):
    从0到99的循环中,执行以下代码块。

  13. protonets.train(labels_trainData, class_number_train)
    调用 protonets 实例的 train() 方法进行模型训练,传入标签训练数据和类别数。

  14. if n % 2 == 0 and n != 0:
    如果 n 是偶数且不为0,则执行以下代码块。

  15. torch.save(protonets.model, './log/model_net_' + str(n) + '.pkl')
    保存模型到 './log/model_net_' + str(n) + '.pkl' 的文件路径。

  16. protonets.save_center('./log/model_center_' + str(n) + '.csv')
    调用 protonets 实例的 save_center() 方法,将模型的中心点保存到 './log/model_center_' + str(n) + '.csv'

  17. test_accury = protonets.evaluation_model(labels_testData, class_number_test)
    调用 protonets 实例的 evaluation_model() 方法,对模型进行评估并返回测试准确率,将其保存到 test_accury 变量中。

  18. print(test_accury)
    打印测试准确率。

  19. str_data = str(n) + ',' + str(' test_accury ') + str(test_accury) + '\n'
    构建一个字符串以保存到文件中。

  20. with open('./log/model_step_eval.txt', "a") as f:
    打开一个文件,以追加模式写入。


总结

原型网络(Prototypical Network)是一种用于小样本学习的模型,由Jake Snell等人于2017年提出。它是一种基于元学习(meta-learning)的方法,主要用于解决在具有少量标记样本的情况下进行分类任务的问题。

传统的深度学习模型在处理小样本学习时通常表现不佳,因为它们需要大量的标记样本来进行训练。然而,在现实世界中,我们往往只有少量标记样本可用。原型网络通过引入一个用于表示类别的中心向量(原型)的概念,解决了这个问题。

原型网络的功能和优势如下:

  1. 小样本学习:原型网络适用于具有少量标记样本的分类任务,可以在只有几个样本可用时进行准确的分类。

  2. 元学习能力:原型网络通过学习类别的原型向量,能够在遇到新类别时进行快速学习,从而实现元学习的目标。

  3. 欧氏距离度量:原型网络使用欧氏距离来度量样本与原型之间的相似性,从而进行分类推断。这种度量方式非常直观和可解释,使得模型更易于理解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/163663.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【AI】行业消息精选和分析(23-11-19)

行业动态 1、对标GPTs,微软连夜发布100多项更新!微软CEO:Copilot时代来了 2、英伟达联手微软推出AI代工服务 3、全新雅虎搜索将于 2024 年上线,未来还会推出更多 AI 和高级功能 4、Instagram 推出定制 AI 贴纸和滤镜功能&#xff…

适用于全部安卓手机的 5 大免费 Android 数据恢复

您是否面临这样一种情况,即在Android设备上丢失了一些重要文件,但不知道应该选择哪种Android数据恢复来取回它们? 如果您以前从未备份过Android数据,则很难解决问题。 本文将介绍排名前5位的免费Android数据恢复软件。 您可以获…

【教3妹学编程-算法题】三个无重叠子数组的最大和

2哥 : 3妹,咋啦?一副苦大仇深的样子? 3妹:不开心呀不开心,羽生结弦宣布离婚。 2哥 : 羽生什么? 3妹:羽生结弦! 2哥 : 什么结弦? 3妹:羽生结弦!&am…

理论与实践相结合之Cisco Packet Tracer网络模拟器安装教程

简介 Packet Tracer是由思科设计的跨平台可视化仿真工具,它允许用户创建网络拓扑以模仿计算机网络和使用命令行界面来模拟配置思科路由器和交换机。Packet Tracer的用户界面为拖放式,允许用户根据自己的需要添加和删除模拟的网络设备。 Packet Tracer很…

安卓手机投屏到电视,跨品牌、跨地域同样可以实现!

在手机网页上看到的视频,也可以投屏到电视上看! 长时间使用手机,难免脖子会酸。这时候,如果你将手机屏幕投屏到大电视屏幕,可以减缓脖子的压力,而且大屏的视觉体验更爽。 假设你有一台安卓手机,…

odoo17前端js框架的演化

odoo17发布了,从界面上看,变化还是很明显的,比16更漂亮了,本来以为源码不会发生太大的变化,结果仔细一瞧,变化也不小。 1、打包好的文件数量和大小发生了变化 打包好的文件从两个变成了一个,在…

在excel中设置图表的标题

已经在excel做好了一个图,默认是没有标题的: 现在来设置一个标题。 双击图表,进入编辑状态: 右键,选择“插入标题”: 输入标题:

0基础编程教学,编程零基础该学什么,中文编程工具下载

0基础编程教学,编程零基础该学什么,中文编程工具下载 给大家分享一款中文编程工具,零基础轻松学编程,不需英语基础,编程工具可下载。 这款工具不但可以连接部分硬件,而且可以开发大型的软件,象…

vscode pandas无法使用

一、代码内容 import csv csv_reader csv.reader(open("data.csv")) for row in csv_reader:print(row) print(row[2]) 二、错误提示 ModuleNotFoundError: No module named pandas 三、安装pandas 然后我安装pandas,因为我的python的版本是python …

Dynamsoft Barcode Reader新框架将医疗视觉提升到新水平

Dynamsoft Vision 框架将医疗保健领域的计算机视觉提升到新的水平 引入图像捕获、内容理解、结果解析和交互式工作流程的聚合 SDK,以简化复杂的流程。 温哥华 – 2023 年 7 月 17 日 – Dynamsoft™ 发布了 Dynamsoft Barcode Reader SDK C Edition v10.0.0。更新…

Kotlin 知识体系

Kotlin 知识体系 1、Kotlin 文档2、Kotlin 基础3、桌面应用程序4、Android 与 iOS 应用程序 1、Kotlin 文档 Kotlin 是一门现代但已成熟的编程语言,旨在让开发人员更幸福快乐。 它简洁、安全、可与 Java 及其他语言互操作,并提供了多种方式在多个平台间复…

基于STM32单片机数字电压表自动切换量程及源程序

一、系统方案 1、本设计采用这STM32单片机作为主控器。 2、液晶1602显示。 3、内部ADC采集电压0-12V,自动切换档位。 二、硬件设计 原理图如下: 三、单片机软件设计 1、首先是系统初始化 u8 i; u16 a,b,c,d; u16 adcx; float adc; unsigned char datas…

(免费领源码)python+django+mysql线上兼职平台系统83320-计算机毕业设计项目选题推荐

摘 要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对线上兼职等问题,对线上兼职进行…

后端技术知识点内容-全部内容-面试宝典-后端面试知识点

文章目录 -2 flink-1 linux of viewlinux查看占用cup最高的10个进程的命令; 〇、分布式锁 & 分布式事务0-1分布式锁--包含CAP理论模型概述分布式锁:分布式锁应该具备哪些条件:分布式锁的业务场景: 分布式锁的实现方式有&#…

【算法挨揍日记】day22——面试题 17.16. 按摩师、213. 打家劫舍 II

面试题 17.16. 按摩师 面试题 17.16. 按摩师 题目描述: 一个有名的按摩师会收到源源不断的预约请求,每个预约都可以选择接或不接。在每次预约服务之间要有休息时间,因此她不能接受相邻的预约。给定一个预约请求序列,替按摩师找…

实用小算法

开头提醒: 打开自己本地任意一个SpringBoot项目,复制代码到test包下跟着敲。 后面几篇文章不再提醒,希望大家养成习惯。看10篇文章,不如自己动手做一次。 我们不执着于一天看多少篇,但求把每一篇都搞懂,…

python 计算最大回撤

1. 什么是最大回撤 最大回撤是评估金融产品收益的一个非常重要的风险指标,它指的是在选定历史周期内任一历史时点往后推,产品净值走到最低点时的收益率回撤幅度的最大值。 以上图为例, 最大回撤 ( V a l u e A − V a l u e B ) V a l u e …

《2020年最新面经》—字节跳动Java社招面试题

文章目录 前言:一面:01、Java基础知识答疑,简单概述一下?02、倒排索引了解吗?使用Java语言怎么实现倒排?03、详细讲解一下redis里面的哈希表,常用的Redis哈希表命名有哪些,举例说明其…

MongoDB相关基础操作(库、集合、文档)

文章目录 一、库的相关操作1、查看数据库2、查看当前库3、创建数据库4、删除数据库 二、集合的相关操作1、查看库中所有集合2、创建集合2.1、显示创建2.2、隐式创建 3、删除集合 三、文档的相关操作1、插入文档1.1、插入单条文档1.2、插入多条文档1.3、脚本方式 2、查询文档3、…

<Linux>权限管理|权限分类|权限设置|权限掩码|粘滞位

文章目录 Linux权限的概念Linux权限管理a. 文件访问者的分类b. 文件类型和访问权限c. 文件权限表示方法d. 文件权限的设置权限掩码file指令粘滞位 权限总结权限作业 Linux权限的概念 Linux下有两种用户:超级用户(root)和普通用户。 超级用户:可以在Lin…