学习pytorch19 pytorch使用GPU训练

pytorch使用GPU进行训练

  • 1. 数据 模型 损失函数调用cuda()
  • 2. 使用谷歌免费GPU gogle colab 需要创建谷歌账号登录使用, 网络能访问谷歌
  • 3. 执行
  • 4. 代码

B站土堆学习视频: https://www.bilibili.com/video/BV1hE411t7RN/?p=30&spm_id_from=pageDriver&vd_source=9607a6d9d829b667f8f0ccaaaa142fcb

1. 数据 模型 损失函数调用cuda()

if torch.cuda.is_available():
    net = net.cuda()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()
if torch.cuda.is_available():
    imgs = imgs.cuda()
    targets = targets.cuda()

2. 使用谷歌免费GPU gogle colab 需要创建谷歌账号登录使用, 网络能访问谷歌

土堆YYDS, colab省事省力好多

3. 执行

运行报错
https://stackoverflow.com/questions/59013109/runtimeerror-input-type-torch-floattensor-and-weight-type-torch-cuda-floatte
You get this error because your model is on the GPU, but your data is on the CPU. So, you need to send your input tensors to the GPU.
意思是数据在cpu但是模型在gpu导致报错
查看代码测试数据做模型预测的时候,有个变量名写错了,改正后正常运行。
在这里插入图片描述
在这里插入图片描述

4. 代码

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import time
# from p24_model import *

# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoader

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)

# 2. 导入模型结构 创建模型
class Cifar10Net(nn.Module):
    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.net(x)
        return x

net = Cifar10Net()
if torch.cuda.is_available():
    net = net.cuda()

# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()

# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数

# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')

start_time = time.time()
print('start_time: ', start_time)
print(torch.cuda.is_available())
for i in range(epoch):
    # 训练步骤开始
    net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    for data in train_loader:
        imgs, targets = data  # 获取数据
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            targets = targets.cuda()
        output = net(imgs)    # 数据输入模型
        loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
        # 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
        optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
        loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数
        optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化
        # 优化一次 认为训练了一次
        total_train_step += 1
        if total_train_step % 100 == 0:
            print('训练次数: {}   loss: {}'.format(total_train_step, loss))
            end_time = time.time()
            print('训练100次需要的时间:', end_time-start_time)
        # 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】
        # print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))
        writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)

    # 利用现有模型做模型测试
    # 测试步骤开始
    total_test_loss = 0
    accuracy = 0
    net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用
    with torch.no_grad():
        for data in test_loader:
            imags, targets = data
            if torch.cuda.is_available():
                imags = imags.cuda()
                targets = targets.cuda()
            output = net(imags)
            loss = loss_fn(output, targets)
            total_test_loss += loss.item()
            # 计算测试集的正确率
            preds = (output.argmax(1)==targets).sum()
            accuracy += preds
    # writer.add_scalar('test-loss', total_test_loss, global_step=i+1)
    writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)
    writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)
    total_test_step += 1
    print("---------test loss: {}--------------".format(total_test_loss))
    print("---------test accuracy: {}--------------".format(accuracy/len(test_data)))
    # 保存每一个epoch训练得到的模型
    torch.save(net, './net_epoch{}.pth'.format(i))

writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/229521.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

机器学习算法(7)-朴素贝叶斯算法和K最近邻算法

一、说明 在在这篇文章中,我将解释两种机器学习算法,称为贝叶斯定理和 K 最近邻算法。贝叶斯定理以 18 世纪英国数学家托马斯贝叶斯的名字命名,是确定条件概率的数学公式。k 最近邻算法,也称为 KNN 或 k-NN,是一种非参…

【Pyqt】QObject::connect: Cannot queue arguments of type ‘QTextCursor‘

问题说明 文本框接收到新的数据 不会自动滚动,并提示警告 QObject::connect: Cannot queue arguments of type ‘QTextCursor’ (Make sure ‘QTextCursor’ is registered using qRegisterMetaType().) 原因 线程回来的槽函数里面 调用了ui的代码 我们不能通过线程…

测试文档---智力冲刺

文章目录 项目背景测试计划UI测试接口测试手工测试 测试总结 项目背景 项目描述:“智力冲刺”是一款网页小游戏,就像我们平时看到的网页游戏一样,前端页面负责展示游戏效果,后端服务器来实现游戏的逻辑。在我们的“智力冲刺”游戏…

【从零认识ECS云服务器 | 快速上线个人网站】三、对外发布网站

3.1 配置域名 用户是如何访问网站的呢? 用户在浏览器(IE、Chrome、FireFox等)上输入域名,如:http://www.aliyun.com ; 浏览器自动调用DNS(域名服务)将域名解析为IP地址,如:123.123…

P3 Qt 控件 —— pushButton

前言 🎬 个人主页:ChenPi 🐻推荐专栏1: 《C_ChenPi的博客-CSDN博客》✨✨✨ 🔥 推荐专栏2: 《Linux C应用编程(概念类)_ChenPi的博客-CSDN博客》✨✨✨ 🌺本篇简介 :这一章我们学一…

CSGO游戏盲盒开箱源码 盲盒对战、幸运开箱、积分商城、Fl盲盒

源码介绍: CSGO游戏盲盒开箱源码 盲盒对战、幸运开箱、积分商城、Fl盲盒。这个是一个新花样玩法的盲盒程序。 仅供学习,请勿商用!请购买正版程序运营。 代码下载:百度网盘

一篇文章带你快速入门 Vue 核心语法

一篇文章带你快速入门 Vue 核心语法 一、为什么要学习Vue 1.前端必备技能 2.岗位多,绝大互联网公司都在使用Vue 3.提高开发效率 4.高薪必备技能(Vue2Vue3) 二、什么是Vue 概念:Vue (读音 /vjuː/,类似于 view) …

OpenVINS学习1——数据集配置与运行

前言 OpenVINS是基于MSCKF的开源VIO算法,有非常详细的官网文档可以学习使用,将来一段时间的主要实践工作,就是深度掌握这份开源代码。 https://docs.openvins.com/ 一、环境配置与Euroc数据集运行 我的环境是Ubuntu20.04,ROS1&a…

Python与ArcGIS系列(十六)重复节点检测

目录 0 简述1 实例需求2 arcpy开发脚本0 简述 在处理gis线图层和面图层数据时,有时候会遇到这种情况:数据存在重复节点或伪重复节点(两个节点距离极小),往往我们需要对这种数据进行检测标注或进行修改。本篇将介绍如何利用arcpy及arcgis的工具箱实现这个功能。 1 实例需求…

Axure原型图表组件库,数据可视化元件(Axure9大屏组件)

针对Axure制作的大屏图表元件库,帮助产品经理更高效地制作高保真图表原型,是产品经理必备元件工具。现分享完整的组件库,大家一起学习。 本组件库的图表模块,已包含所有常用的图表,以下为部分组件截图示意。文末可下载…

微信头像上传失败?照片尺寸修改方法大揭秘

微信上传头像失败是大家经常遇到的问题之一,其中的原因有可能是因为网络问题,也有可能是因为图片本身的问题,比如图片尺寸不符合平台要求,这时候,我们只需要通过给照片尺寸修改一下就可以解决,那么怎么修改…

LinuxBasicsForHackers笔记 -- 文件系统和存储设备管理

设备目录/dev Linux 有一个特殊的目录,其中包含代表每个连接设备的文件:相应命名的 /dev 目录。 /dev中有很多设备列表。 特别令人感兴趣的是设备 sda1、sda2、sda3、sdb 和 sdb1,它们通常是硬盘驱动器及其分区以及 USB 闪存驱动器及其分区…

Go实现http同步文件操作 - 增删改查

http同步文件操作 - 增删改查 http同步文件操作 - 增删改查1. 前置要求1.1. 构建结构体 文件名 文件内容1.1.1. 页面结构体1.1.2. 为Page结构体绑定方法:Save1.1.3. 对Page结构体支持页面内容查看方法,同时提供页面文件是否存在的方法 1.2. 简单验证上面…

如何快速完成企业私有云部署

快解析赋能企业私有云部署 ​ 很多企业形成了以总部为中心的多点生产体系结构,并借助网络化办公工具搭建跨区域协同办公系统,满足总部与分支机构间的信息互通,进而促进异地业务的信息共享,提高办公处理效率和综合管理水平。 北…

参与创作①周年啦~

写在前面 今天看了消息才知道,原来开始创作已经一年了。此篇无干货,纯白话,纯记录。 机缘 参与CSDN创作已经一年有余,犹记得第一篇博文是为了整理好所学内容,方便自己复习。没想到后面也陆陆续续发了些其他内容&…

python的websocket方法教程

WebSocket是一种网络通信协议,它在单个TCP连接上提供全双工的通信信道。在本篇文章中,我们将探讨如何在Python中使用WebSocket实现实时通信。 websockets是Python中最常用的网络库之一,也是websocket协议的Python实现。它不仅作为基础组件在…

SIMULINK 单相半波整流电路仿真

SIMULINK 单相半波整流电路仿真 去除毛刺,改步长 光滑如下:

嵌入版python作为便携计算器(安装及配置ipython)

今天用别的电脑调试C,需要计算反三角函数时发现没有趁手工具,忽然想用python作为便携计算器放在U盘,遂想到嵌入版python 懒得自己配可以直接下载,使用方法见第4节 1,下载embeddable python(嵌入版python&…

PCTFWP

web Sign In 查看源代码,访问提示文件就可得到flag easyupload 先上传一个图片马,会提示特洛伊文件,于是就改文件名,重新上传就可以得到flag 类似: [MoeCTF 2022]what are y0u uploading? 先做我 ?po…

SRR参考设计中SRR_DSS_mmWaveTask函数源码分析

相应的说明都已注释的形式写出来。 上面两个中断是在初始化的时候注册的,在SRR_DSS_initTask函数中,如下所示 后面会接着分析函数MmwDemo_processChirp和MmwDemo_interFrameProcessing。