pytorch:to()、device()、cuda()将Tensor或模型移动到指定的设备上

将Tensor或模型移动到指定的设备上:tensor.to(‘cuda:0’)

  • 最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行
  • 在做高维特征运算的时候,采用GPU无疑是比用CPU效率更高,如果两个数据中一个加了.cuda()或者.to(device),而另外一个没有加,就会造成类型不匹配而报错。

1. Tensor.to(device)

功能:将Tensor移动到指定的设备上。

以下代码将Tensor移动到GPU上:
device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

1.1 修改dtype

a = tensor.to(torch.float64).

  • tensor.dtype : torch.float32
  • a.dtype : torch.float64

1.2 改变device:用字符串形式给出

a = tensor.to('cuda:0').

  • tensor.device : device(type=‘cpu’)
  • a.device : device(type=‘cuda’, index=0)

1.3 改变device:用torch.device给出

cuda0 = torch.device('cuda:0') .
b = tensor.to(cuda0) .

  • tensor.device : device(type=‘cpu’)
  • b.device : device(type=‘cuda’, index=0)

1.4 同时改变device和dtype

c = tensor.to('cuda:0',torch.float64) .
other = torch.randn((), dtype=torch.float64, device=cuda0) .
d = tensor.to(other, non_blocking=True) .

  • tensor.device:device(type=‘cpu’)
  • d :tensor([], device=‘cuda:0’, dtype=torch.float64))

2. model.to(device)

功能:将模型移动到指定的设备上。

使用以下代码将模型移动到GPU上:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        
model = Net()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

如果有多个GPU,使用以下方法:

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model,device_ids=[0,1,2])
  
model.to(device)

将由GPU保存的模型加载到GPU上。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由GPU保存的模型加载到CPU上。
torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)

将由CPU保存的模型加载到GPU上。
torch.load()函数中的map_location参数设置为torch.device('cuda')

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

参考:PyTorch之Tensor.to(device)和model.to(device)

3. .to(device) 和.cuda()的区别

  • .to(device) 可以指定CPU 或者GPU
  • .cuda() 只能指定GPU

图参考:pytorch中.to(device) 和.cuda()的区别
在这里插入图片描述

官方文档:CUDA SEMANTICS

with torch.cuda.device(1):
    # allocates a tensor on GPU 1
    a = torch.tensor([1., 2.], device=cuda)
 
    # transfers a tensor from CPU to GPU 1
    b = torch.tensor([1., 2.]).cuda()
    # a.device and b.device are device(type='cuda', index=1)
 
    # You can also use ``Tensor.to`` to transfer a tensor:
    b2 = torch.tensor([1., 2.]).to(device=cuda)
    # b.device and b2.device are device(type='cuda', index=1)
  • 两个方法都可以达到同样的效果,在pytorch中,即使是有GPU的机器,它也不会自动使用GPU,而是需要在程序中显示指定。
  • 调用model.cuda(),可以将模型加载到GPU上去。这种方法不被提倡,而建议使用model.to(device)的方式,这样可以显示指定需要使用的计算资源,特别是有多个GPU的情况下。

4. CUDA相关信息查询

import torch
print('CUDA版本:',torch.version.cuda)
print('Pytorch版本:',torch.__version__)
print('显卡是否可用:','可用' if(torch.cuda.is_available()) else '不可用')
print('显卡数量:',torch.cuda.device_count())
print('当前显卡的CUDA算力:',torch.cuda.get_device_capability(0))
print('当前显卡型号:',torch.cuda.get_device_name(0))
>>>
CUDA版本: 11.7
Pytorch版本: 1.13.1
显卡是否可用: 可用
显卡数量: 1
当前显卡的CUDA算力: (8, 6)
当前显卡型号: NVIDIA GeForce RTX 3060 Laptop GPU

参考:https://blog.csdn.net/weixin_43845386/article/details/131723010

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/246146.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智慧城市/一网统管建设:人员危险行为检测算法,为城市安全保驾护航

随着人们压力的不断增加,经常会看见在日常生活中由于小摩擦造成的大事故。如何在事故发生时进行及时告警,又如何在事故发生后进行证据搜索与事件溯源?旭帆科技智能视频监控人员危险行为/事件检测算法可以给出答案。 全程监控,有源…

智能优化算法应用:基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.探路者算法4.实验参数设定5.算法结果6.参考文…

人工智能导论习题集(5)

第四章:搜索策略 题1题2 题1 题2 参考文献:https://blog.csdn.net/m0_64372178/article/details/134890561

平头哥玄铁 E906 编译与使用

1、E906 概述 与玄铁 E902 一样, E906 核的 RTL 代码也是开源的,开源仓库地址:https://github.com/T-head-Semi/opene906 玄铁 E906 是一款完全可综合的中端 MCU 级处理器,兼容 RV32IMA[F][D]C[P] 指令集,提供可观的…

C# WPF上位机开发(动态添加控件)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 写图形界面软件的时候,我们经常会遇到一种情况。那就是图形界面上面,显示的控件可能是不定的。有可能多,也有可…

JupyterNotebook VS JupyterLab 如果jupyter安装成功,点击jupyterlab即可进入lab环境

简介 JupyterNotebook 是一个款以网页为基础的交互计算环境,可以创建Jupyter的文档,支持多种语言,包括Python, Julia, R等等。一般来说,如果是使用R语言的话,使用Rstudio居多,使用Python的话,使…

消息队列 - RabbitMQ

消息队列 - RabbitMQ 1. 初识 MQ1.1 同步调用1.2 异步调用1.3.技术选型 2. RabbitMQ2.1 安装2.2 收发信息2.2.1 交换机(Exchange)2.2.2 队列2.2.3 绑定关系2.2.4 发送消息 2.3 数据隔离 相关文章: SpringAMQP 快速入门 RabbitMQ 如何保证消息可靠性 1. 初识 MQ 微服务一旦拆分…

beyond compare文件夹比较时候文本乱码问题解决

“格式”,在下面的左侧编码重写和右侧编码覆盖选择 GB2312/UTF-8/GBK,这个可以根据自己喜好和文本自身的encode选择。

C++ 指针基础

指针的定义 #define _CRT_SECURE_NO_WARNINGS #include <stdio.h>int main() {int a 10;printf("%p\n", &a); // &a 表示取&#xff08;a&#xff09;的地址 000000842FD9FB64// 存放指针&#xff08;地址&#xff09;的变量就是指针变量&#xff0c…

Java中的链表

文章目录 前言一、链表的概念及结构二、单向不带头非循坏链表的实现2.1打印链表2.2求链表的长度2.3头插法2.4尾插法2.5任意位置插入2.6查找是否包含某个元素的节点2.7删除第一次出现这个元素的节点2.8删除包含这个元素的所以节点2.9清空链表单向链表的测试 三、双向不带头非循坏…

spring 笔记三 Spring与Web环境集成

文章目录 Spring与Web环境集成ApplicationContext应用上下文获取方式导入Spring集成web的坐标置ContextLoaderListener监听器通过工具获得应用上下文对象SpringMVC概述SpringMVC快速入门 Spring与Web环境集成 ApplicationContext应用上下文获取方式 应用上下文对象是通过new …

leetcode面试经典150题——36 旋转图像

题目&#xff1a; 旋转图像 描述&#xff1a; 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像&#xff0c;这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1&#xff1a; 输入&#x…

中国经济增长:全球复苏的引擎

近年来&#xff0c;中国经济以其强劲的增长势头成为全球经济的重要引擎。中国的经济崛起不仅对自身国家发展具有重要意义&#xff0c;而且也对全球经济复苏和稳定有着积极影响。本文将从多个角度探讨中国经济增长对全球经济的影响及其作为全球复苏的引擎。 首先&#xff0c;中国…

《使用ThinkPHP6开发项目》 - ThinkPHP6使用JWT生成Token

《使用ThinkPHP6开发项目》 - 登录接口一-CSDN博客 《使用ThinkPHP6开发项目》 - 登录接口二-CSDN博客 《使用ThinkPHP6开发项目》 - 登录接口三【表单验证】-CSDN博客 登录接口成功后返回Token&#xff0c;便于其他需要验证用户登录的接口携带Token进行请求校验 这里Think…

基于ssm网络安全宣传网站设计论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本网络安全宣传网站就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信息…

SI24R03国产自主可控RISC-V架构MCU低功耗2.4GHz收发芯片SoC

目录 RISC-V架构的优势SI24R03/04特性射频收发器模块特征MCU 模块特征 其他特征 RISC-V架构的优势 相对于目前主流的英特尔X86架构及ARM等架构来说&#xff0c;RISC-V架构具有指令精简、模块化、可扩展、开源、免费等优点。RISC-V的基础指令集只有40多条&#xff0c;加上其他基…

爬虫的分类

爬虫的分类 网络爬虫按照系统结构和实现技术&#xff0c;大致可分为4类&#xff0c;即通用网络爬虫、聚焦网络爬虫、增量网络爬虫和深层次网络爬虫。 1.通用网络爬虫&#xff1a;搜索引擎的爬虫 比如用户在百度搜索引擎上检索对应关键词时&#xff0c;百度将对关键词进行分析…

比特币价格创新高:加密货币的崛起与未来

一、引言 近年来&#xff0c;比特币的价格一路上涨&#xff0c;引起了全球投资者和市场的广泛关注。作为最早一批区块链技术应用案例之一&#xff0c;比特币的成功带动了整个加密货币市场的兴起。本文将探讨比特币价格创新高的原因、加密货币的崛起以及未来发展趋势。 二、比特…

ChatGPT热门项目

1.智能GPT 项目地址&#xff1a;智能GPT&#xff1a;你只要提供OpenAI的API Key&#xff0c;那么它就可以根据你设定的目标&#xff0c;采用Google搜索、浏览网站、执行脚本等方式 主要语言&#xff1a;Python 推荐理由&#xff1a;这是由开发者Significant Gravitas推出的项目…

【C++练级之路】【Lv.4】类和对象(下)(初始化列表,友元,static成员,编译器的优化)

目录 一、再谈构造函数1.1 构造函数体赋值1.2 初始化列表1.3 explicit关键字 二、static成员2.1 概念2.2 特性 三、友元3.1 引入3.2 友元函数3.2.1 概念3.2.2 特性 3.3 友元类3.3.1 概念3.3.2 特性 四、内部类4.1 概念4.2 特性 五、匿名对象六、编译器的优化6.1 传参优化6.1.1 …