TrustGeo代码理解(一)main.py

代码链接:https://github.com/ICDM-UESTC/TrustGeo

一、导入各种模块和数据库

# -*- coding: utf-8 -*-
import torch.nn

from lib.utils import *
import argparse, os
import numpy as np
import random
from lib.model import *
import copy
from thop import profile
import pandas as pd

整体功能是准备运行一个 PyTorch 深度学习模型的环境,具体的功能实现需要查看 lib.utils、lib.model 中的代码,以及整个文件的后续部分。

1、# -*- coding: utf-8 -*-:指定脚本的字符编码为UTF-8。

2、import torch.nn:导入 PyTorch 的神经网络模块,用于定义和训练神经网络。

3、from lib.utils import *从 lib.utils 模块中导入所有内容,这可能包括一些工具函数或辅助函数,用于该脚本或项目的其他部分。

4、import argparse, os:导入 argparse 模块用于解析命令行参数,os 模块用于与操作系统交互。

5、import numpy as np:导入 NumPy 库,用于进行科学计算,特别是多维数组的处理。

6、import random:导入 random 模块,用于生成伪随机数。

7、from lib.model import *:从 lib.model 模块中导入所有内容,这可能包括定义神经网络模型的类等。

8、import copy:导入 copy 模块,用于复制对象,通常用于创建对象的深拷贝

9、from thop import profile从 thop 模块中导入 profile 函数,该函数用于计算 PyTorch 模型的 FLOPs(浮点运算数)和参数数量。(在代码链接中没有找到)

10、import pandas as pd:导入 Pandas 库,用于数据处理和分析,通常用于处理表格型数据。

二、参数初始化(通过命令行参数)

parser = argparse.ArgumentParser()
# parameters of initializing
parser.add_argument('--seed', type=int, default=2022, help='manual seed')
parser.add_argument('--model_name', type=str, default='TrustGeo')
parser.add_argument('--dataset', type=str, default='New_York', choices=["Shanghai", "New_York", "Los_Angeles"],
                    help='which dataset to use')

这部分代码的目的是通过命令行参数设置一些初始化的参数,例如随机数种子、模型名称和数据集名称。这使得在运行脚本时可以通过命令行参数来指定这些参数的值。

1、parser = argparse.ArgumentParser():创建一个 argparse.ArgumentParser 对象,用于解析命令行参数。

2、# parameters of initializing:注释,表示接下来是初始化参数的部分。

3、parser.add_argument('--seed', type=int, default=2022, help='manual seed'):添加一个命令行参数,名称为 '--seed',表示随机数种子,类型为整数,默认值为 2022,help 参数是在命令行中输入 --help 时显示的帮助信息。

4、parser.add_argument('--model_name', type=str, default='TrustGeo'):添加一个命令行参数,名称为 '--model_name',表示模型的名称,类型为字符串,默认值为 'TrustGeo'。

5、parser.add_argument('--dataset', type=str, default='New_York', choices=["Shanghai", "New_York", "Los_Angeles"], help='which dataset to use'):添加一个命令行参数,名称为 '--dataset',表示数据集的名称,类型为字符串,默认值为 'New_York',choices 参数指定了可选的值为 ["Shanghai", "New_York", "Los_Angeles"],用户只能从这三个值中选择。help 参数是在命令行中输入 --help 时显示的帮助信息。

三、训练过程参数设置

# parameters of training
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--lambda1', type=float, default=7e-3)
parser.add_argument('--lr', type=float, default=5e-3)
parser.add_argument('--harved_epoch', type=int, default=5) 
parser.add_argument('--early_stop_epoch', type=int, default=50)
parser.add_argument('--saved_epoch', type=int, default=200)  

这部分代码的目的是设置一些训练过程中的超参数,例如优化器的动量参数、学习率、权重参数等。这些参数在训练过程中会影响模型的更新和收敛速度。

1、# parameters of training:注释,表示接下来是训练参数的部分。

2、parser.add_argument('--beta1', type=float, default=0.9):添加一个命令行参数,名称为 '--beta1',表示 Adam 优化器的第一个动量(momentum)参数,类型为浮点数,默认值为 0.9。

3、parser.add_argument('--beta2', type=float, default=0.999):添加一个命令行参数,名称为 '--beta2',表示 Adam 优化器的第二个动量参数,类型为浮点数,默认值为 0.999。

4、parser.add_argument('--lambda1', type=float, default=7e-3):添加一个命令行参数,名称为 '--lambda1',表示某个权重参数,类型为浮点数,默认值为 7e-3。

5、parser.add_argument('--lr', type=float, default=5e-3):添加一个命令行参数,名称为 '--lr',表示学习率,类型为浮点数,默认值为 5e-3。

6、parser.add_argument('--harved_epoch', type=int, default=5):添加一个命令行参数,名称为 '--harved_epoch',表示某个 epoch 的值,类型为整数,默认值为 5。

7、parser.add_argument('--early_stop_epoch', type=int, default=50):添加一个命令行参数,名称为 '--early_stop_epoch',表示早停(early stop)的 epoch 数,类型为整数,默认值为 50。

8、parser.add_argument('--saved_epoch', type=int, default=200):  添加一个命令行参数,名称为 '--saved_epoch',表示保存模型的 epoch 数,类型为整数,默认值为 200。

四、模型参数设置

# parameters of model
parser.add_argument('--dim_in', type=int, default=30, choices=[51, 30], help="51 if Shanghai / 30 else")

opt = parser.parse_args()
print("Learning rate: ", opt.lr)
print("Dataset: ", opt.dataset)

这部分代码的目的是解析命令行参数,并打印出学习率和数据集名称。--dim_in 参数用于指定输入维度,可以选择是 51 或者 30。

1、# parameters of model:注释,表示接下来是训模型参数的部分。

2、parser.add_argument('--dim_in', type=int, default=30, choices=[51, 30], help="51 if Shanghai / 30 else"):添加一个命令行参数,名称为 '--dim_in',表示输入的维度,类型为整数,默认值为 30。choices 参数指定了可选的值为 [51, 30],用户只能从这两个值中选择。help 参数是在命令行中输入 --help 时显示的帮助信息。

3、opt = parser.parse_args():使用 argparse 解析命令行参数,将结果存储在 opt 变量中

4、print("Learning rate: ", opt.lr):打印学习率,即 opt 对象中的 lr 属性

5、print("Dataset: ", opt.dataset):打印数据集名称,即 opt 对象中的 dataset 属性

五、设置随机种子数

if opt.seed:
    print("Random Seed: ", opt.seed)
    random.seed(opt.seed)
    torch.manual_seed(opt.seed)
torch.set_printoptions(threshold=float('inf'))

这一部分的目的是确保在使用随机数的场景中,每次运行程序得到的随机结果是可复现的。通过设置相同的随机数种子,可以使得每次运行得到相同的随机数序列。

1、如果 opt 对象中的 seed 属性存在(不为 0 或 False 等假值),则执行以下操作:

  • 打印随机数种子的信息。
  • 使用 random 模块设置 Python 内建的随机数生成器的种子。
  • 使用 PyTorch 的 torch 模块设置随机数种子。

2、torch.set_printoptions(threshold=float('inf')):设置 PyTorch 的打印选项,将打印的元素数量限制设置为无穷大,即不限制打印的元素数量。这样可以确保在打印张量时,所有元素都会被打印出来,而不会被省略。

六、过滤所有警告信息

warnings.filterwarnings('ignore')

过滤掉所有警告信息,将警告信息忽略。这通常用于在代码中避免显示一些不影响程序执行的警告信息,以保持输出的清晰。在某些情况下,警告信息可能是有用的,但如果明确知道这些警告对程序执行没有影响,可以选择忽略它们。

七、动态选择运行环境

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("device:", device)
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

这部分代码的目的是根据硬件环境动态选择运行模型的设备,并选择相应的 PyTorch 张量类型。如果有可用的 GPU,就使用 GPU 运行模型和 GPU 张量类型;否则,使用 CPU 运行模型和 CPU 张量类型。

1、device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'):创建一个 PyTorch 设备对象,表示运行模型的设备。如果 CUDA 可用(即有可用的 GPU),则使用 'cuda:0' 表示第一个 GPU,否则使用 'cpu' 表示 CPU。

2、print("device:", device):打印设备的信息,即使用的是 GPU 还是 CPU。

3、cuda = True if torch.cuda.is_available() else False:根据 CUDA 是否可用设置一个布尔值,表示是否使用 GPU。如果 CUDA 可用,则 cuda

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/250611.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

devc++如何建立一个c++项目?devc++提示源文件未编译?

打开devc APP后是这样的界面; 点击文件-> 新建->项目,这一点应该不难,主要是最后这个选择什么? 这样即可。 devc提示源文件未编译? 点击工具->编译选项; 如果不能解决,那就是可能路径…

NNDL 循环神经网络-梯度爆炸实验 [HBU]

目录 6.2.1 梯度打印函数 6.2.2 复现梯度爆炸现象 6.2.3 使用梯度截断解决梯度爆炸问题 【思考题】梯度截断解决梯度爆炸问题的原理是什么? 总结 前言: 造成简单循环网络较难建模长程依赖问题的原因有两个:梯度爆炸和梯度消失。 循环…

代码随想录算法训练营第53天| 1143.最长公共子序列 1035.不相交的线 53. 最大子序和 动态规划

JAVA代码编写 1143.最长公共子序列 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。 一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情…

软件测试面试八股文(答案解析+视频教程)

1、B/S架构和C/S架构区别 B/S 只需要有操作系统和浏览器就行,可以实现跨平台,客户端零维护,维护成本低,但是个性化能力低,响应速度较慢。 C/S响应速度快,安全性强,一般应用于局域网中&#xf…

【华为数据之道学习笔记】3-10元数据管理架构及策略

元数据管理架构包括产生元数据、采集元数据、注册元数据和运 维元数据。 产生元数据: 制定元数据管理相关流程与规范的落地方案,在IT产品开发过程中实现业务元数据与技术元数据的连接。 采集元数据: 通过统一的元模型从各类IT系统中自动采集元…

Linux下FFmepg使用

1.命令行录一段wav,PCM数据 ffmpeg -f alsa -i hw:0,0 xxx.wav//录制 ffplay out.wav//播放ffmpeg -f alsa -i hw:0,0 -ar 16000 -channels 1 -f s16le 1.pcm ffplay -ar 16000 -channels 1 -f s16le 1.pcm -ar freq 设置音频采样率 -ac channels 设置通道 缺省为1 2.将pcm…

002.Java实现两数相加

题意 给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个表示两数之和的新链表。 示例 输入:l1[2,4,3],l2[5,6,4] 输出…

【从零开始学习JVM | 第七篇】深入了解 堆回收

前言: Java堆作为内存管理中最核心的一部分,承担着对象实例的存储和管理任务。堆内存的高效使用对于保障程序的性能和稳定性至关重要。因此,深入理解Java堆回收的原理、机制和优化策略,对于Java开发人员具有重要的意义。 本文旨在…

springcloud-分布式缓存

文章目录 一.Redis持久化1.RDB持久化2.AOF持久化 二.Redis主从1.搭建主从架构2.全量同步3.增量同步 三.Redis哨兵1.哨兵的作用和原理2.搭建哨兵架构3.RedisTemplate的哨兵模式 四.Redis分片集群1.搭建分片集群2.散列插槽3.集群伸缩4.故障转移5.RedisTemplate访问分片集群 为什么…

树莓派(Raspberry Pi)4B密码忘记了,怎么办?

树莓派长时间不用,导致密码忘记了,这可咋整? 第1步:取出SD卡 将树莓派关机,移除sd卡,使用读卡器,插入到你的电脑。 第2步:编辑 cmdline.txt 在PC上打开SD卡根目录,启动…

Kotlin ArrayList类型toTypedArray转换Array

Kotlin ArrayList类型toTypedArray转换Array data class Point(val x: Float, val y: Float)fun array_test(points: ArrayList<Array<Point>>) {points.forEachIndexed { idx, ap ->ap.forEach {print("$idx $it ")}println()} }fun main(args: Arra…

2697. 字典序最小回文串

2697. 字典序最小回文串 难度: 简单 来源: 每日一题 2023.12.13 给你一个由 小写英文字母 组成的字符串 s &#xff0c;你可以对其执行一些操作。在一步操作中&#xff0c;你可以用其他小写英文字母 替换 s 中的一个字符。 请你执行 尽可能少的操作 &#xff0c;使 s 变…

RTX 40 SUPER发布时间定了!价格也有了

快科技12月16日消息&#xff0c;NVIDIA RTX 40 SUPER系列显卡基本确定将在2024年1月8日正式发布&#xff0c;也就是CES 2024大展期间&#xff0c;随后在1月中下旬陆续解禁上市。 RTX 4070 SUPER 1月16日解禁公版/原价丐版&#xff0c;1月17日解禁高价高配版&#xff0c;上市开…

鸿蒙开发编辑器设置

首先需要知道如何打开设置页面&#xff0c;以下所有设置都需要在设置界面中进行修改&#xff0c;有三种方式可以打开&#xff0c; 1、编辑器左上角file菜单下的Setting菜单。 2、编辑器右上角的设置按钮 3、按快捷键 ctrlalts 注意不要和其他软件案件重复。 一、设置每次打开…

制作一个简单 的maven plugin

流程 首先&#xff0c; 你需要创建一个Maven项目&#xff0c;推荐用idea 创建项目 会自动配置插件 pom.xml文件中添加以下配置&#xff1a; <project> <!-- 项目的基本信息 --> <groupId>com.example</groupId> <artifactId>my-maven-plugi…

腾讯云服务器优惠活动大全页面_全站搜优惠合集

腾讯云推出优惠全站搜页面 https://curl.qcloud.com/PPrF9NFe 在这个页面可以一键查询所需云服务器、轻量应用服务器、数据库、存储、CDN、网络、安全、大数据等云产品优惠活动大全&#xff0c;活动打开如下图&#xff1a; 腾讯云优惠全站搜 腾讯云优惠全站搜页面 txybk.com/go…

springboot笔记

尚硅谷SpringBoot3零基础教程&#xff0c;springboot入门到实战_哔哩哔哩_bilibili SpringBOOT 只会扫描在主程序下的包!!!!!!!!!!!!写在其他包上面会有问题 //SpringBootApplication(scanBasePackages "com") //也可以自己设置扫描路径 &#xff33;&#xff50…

【Qt开发流程】之UDP

概述 UDP (User Datagram Protocol)是一种简单的传输层协议。与TCP不同&#xff0c;UDP不提供可靠的数据传输和错误检测机制。UDP主要用于那些对实时性要求较高、对数据传输可靠性要求较低的应用&#xff0c;如音频、视频、实时游戏等。 UDP使用无连接的数据报传输模式。在传…

白日门引擎传奇手游架设教程-GM的成长之路

准备工具 服务器一台&#xff08;Windows系统&#xff09;白日门引擎服务端版本一个 前言&#xff1a; 此次教程使用的是版本是一个决战斗罗的一个版本、服务器使用的是驰网科技的游戏高频系列服务器。 教程开始 在我们拿到版本之后、我们需要先把版本解压到服务器D盘的根目录…

Mac安装Typora实现markdown自由

一、什么是markdown Markdown 是一种轻量级标记语言&#xff0c;创始人为约翰格鲁伯&#xff08;John Gruber&#xff09;。 它允许人们使用易读易写的纯文本格式编写文档&#xff0c;然后转换成有效的 XHTML&#xff08;或者HTML&#xff09;文档。这种语言吸收了很多在电子邮…