Pytorch 三小时极限入门教程

一、引言

在当今的人工智能领域,深度学习占据了举足轻重的地位。而 Pytorch 作为一款广受欢迎的深度学习框架,以其简洁、灵活的特性,吸引了大量开发者投身其中。无论是科研人员探索前沿的神经网络架构,还是工程师将深度学习技术落地到实际项目,Pytorch 都提供了强大的支持。本教程将带你从零基础开始,一步步深入了解 Pytorch 的核心知识,助你顺利踏上深度学习的征程。

二、Pytorch 基础环境搭建

安装 Anaconda

Anaconda 是一个强大的 Python 包管理器和环境管理器,方便我们创建独立的 Python 开发环境。首先,从 Anaconda 官方网站下载对应操作系统的安装包,一路默认安装即可。安装完成后,打开终端(Linux/Mac)或命令提示符(Windows),输入 conda --version 验证是否安装成功。

创建虚拟环境

使用 conda create -n pytorch_env python=3.8 创建一个名为 pytorch_env 的虚拟环境,这里指定 Python 版本为 3.8,你可以根据实际需求调整。激活虚拟环境,在 Linux/Mac 下使用 source activate pytorch_env,Windows 下使用 activate pytorch_env。

安装 Pytorch

访问 Pytorch 官方网站,根据你的系统配置(如 CUDA 是否可用)选择合适的安装命令。例如,如果你的电脑有 NVIDIA GPU 且支持 CUDA 11.3,安装命令可能为 conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch。如果没有 GPU,则选择 CPU 版本的安装命令,如 conda install pytorch torchvision torchaudio cpuonly -c pytorch。安装完成后,在 Python 交互式环境中输入 import torch,没有报错则说明安装成功。

三、张量(Tensor):深度学习的基石

张量的定义与创建

张量是 Pytorch 中最基本的数据结构,类似于 NumPy 中的数组,但具有更强的功能。可以使用 torch.tensor() 函数从 Python 列表或 NumPy 数组创建张量,例如:

import torch

import numpy as np

# 从列表创建张量

data_list = [1, 2, 3, 4]

tensor_from_list = torch.tensor(data_list)

# 从 NumPy 数组创建张量

np_array = np.array([5, 6, 7, 8])

tensor_from_numpy = torch.from_numpy(np_array)

还可以使用 torch.zeros()、torch.ones()、torch.rand() 等函数创建具有特定形状的全 0、全 1 或随机值张量。

张量的属性与操作

张量具有形状(shape)、数据类型(dtype)等属性。可以通过 .shape 和 .dtype 来访问,例如:

tensor = torch.rand(3, 4)

print(tensor.shape)

print(tensor.dtype)

张量支持丰富的数学运算,如加法、减法、乘法、除法等,操作符重载使得代码简洁直观:

a = torch.rand(2, 3)

b = torch.rand(2, 3)

c = a + b

d = a * b

同时,也有大量的函数可供调用,像 torch.sum()、torch.mean() 等用于统计计算。

四、自动求导(Autograd):神经网络训练的关键

自动求导原理简介

在深度学习中,模型训练的核心是反向传播算法,而 Pytorch 的自动求导机制极大地简化了这一过程。当创建一个张量时,如果设置 requires_grad=True,Pytorch 会记录该张量上的所有操作,构建一个计算图。在反向传播时,利用这个计算图自动计算梯度。

示例:简单函数求导

x = torch.tensor([2.], requires_grad=True)

y = x ** 2 + 3 * x

y.backward()

print(x.grad)

这里定义了一个简单的函数 ,对 x 求导后,x.grad 存储了梯度值,即 在 时的值 7。

 复杂模型中的应用

在构建神经网络时,模型参数都设置为 requires_grad=True。在每一次前向传播计算损失后,通过 loss.backward() 反向传播梯度,然后使用优化器(如 SGD、Adam 等)根据梯度更新参数,实现模型的训练。

五、神经网络模块(nn.Module):构建模型的利器

自定义神经网络

继承 nn.Module 类可以方便地自定义神经网络。首先在 __init__() 函数中定义模型的层结构,如全连接层 nn.Linear,卷积层 nn.Conv2d 等,然后在 forward() 函数中定义数据的前向传播路径。

import torch.nn as nn

class SimpleNet(nn.Module):

def __init__(self):

super(SimpleNet, self).__init__()

self.fc1 = nn.Linear(10, 20)

self.fc2 = nn.Linear(20, 1)

def forward(self, x):

x = torch.relu(self.fc1(x))

x = self.fc2(x)

return x

这里定义了一个简单的两层全连接神经网络,输入维度为 10,中间层维度为 20,输出维度为 1,中间使用 ReLU 作为激活函数。

预训练模型的使用与微调

Pytorch 提供了丰富的预训练模型,如 ResNet、VGG 等经典的图像分类模型。可以通过 torchvision.models 模块加载预训练模型,然后根据自己的任务需求,修改最后几层的结构并进行微调。例如:

import torchvision.models as models

resnet = models.resnet18(pretrained=True)

# 修改最后一层输出维度为自定义类别数

resnet.fc = nn.Linear(resnet.fc.in_features, 10)

这使得在数据量有限的情况下,也能利用预训练模型的强大特征提取能力,快速搭建高性能模型。

六、数据加载与预处理(DataLoader)

数据集类的构建

要使用自己的数据训练模型,需要构建自定义数据集类,继承 torch.utils.data.Dataset。在类中实现 __getitem__() 方法用于获取单个样本及其标签,__len__() 方法返回数据集的大小。例如,对于图像分类数据集:

from torch.utils.data import Dataset

import os

import cv2

class ImageDataset(Dataset):

def __init__(self, root_dir, transform=None):

self.root_dir = root_dir

self.image_files = os.listdir(root_dir)

self.transform = transform

def __getitem__(self, index):

image_path = os.path.join(self.root_dir, self.image_files[index])

image = cv2.imread(image_path)

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

label = int(self.image_files[index].split('.')[0])

if self.transform:

image = self.transform(image)

return image, label

def __len__(self):

return len(self.image_files)

数据加载器的使用

使用 torch.utils.data.DataLoader 将数据集封装成可迭代的数据加载器,方便在训练过程中批量获取数据。可以设置批量大小(batch_size)、是否打乱数据(shuffle)等参数,例如:

from torch.utils.data import DataLoader

dataset = ImageDataset(root_dir='data/images', transform=transforms.ToTensor())

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

在训练循环中,通过遍历数据加载器获取批量数据,送入模型进行训练。

七、模型训练与评估

训练循环

模型训练通常包括多个 epoch,每个 epoch 遍历一遍整个数据集。在每个 epoch 内,按批次获取数据,前向传播计算损失,反向传播更新参数。以下是一个简单的训练循环示例:

model = SimpleNet()

criterion = nn.MSELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(10):

running_loss = 0.0

for i, (inputs, labels) in enumerate(dataloader):

optimizer.zero_grad()

outputs = model(inputs)

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

running_loss += loss.item()

print(f'Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}')

评估指标与方法

根据任务不同,评估指标各异。对于分类任务,常用准确率(Accuracy),可以通过比较模型预测结果与真实标签计算得出:

correct = 0

total = 0

with torch.no_grad():

for inputs, labels in dataloader:

outputs = model(inputs)

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum().item()

accuracy = correct / total

print(f'Accuracy: {accuracy}')

对于回归任务,可能使用均方误差(MSE)、平均绝对误差(MAE)等指标。

八、模型保存与加载

保存模型

可以使用 torch.save() 保存模型的参数或整个模型结构,例如保存模型参数:

torch.save(model.state_dict(), 'model.pth')

若要保存整个模型,包括结构和参数:

torch.save(model, 'whole_model.pth')

加载模型

加载模型参数时,先创建模型实例,再使用 model.load_state_dict(torch.load('model.pth')) 加载。若加载整个模型,则直接 model = torch.load('whole_model.pth')。加载后,模型即可用于预测或继续训练。

九、可视化工具(TensorBoard)

安装与配置

TensorBoard 是一个强大的可视化工具,用于监控模型训练过程。使用 pip install tensorboard 安装,在 Pytorch 代码中引入相关模块:

from torch.utils.tensorboard import SummaryWriter

创建一个 SummaryWriter 实例,指定日志目录,如 writer = SummaryWriter('logs')。

可视化训练过程

在训练过程中,可以使用 writer.add_scalar() 记录损失、准确率等指标随 epoch 的变化:

for epoch in range(10):

# 训练代码...

writer.add_scalar('Loss', running_loss / len(dataloader), epoch)

writer.add_scalar('Accuracy', accuracy, epoch)

writer.close()

运行 tensorboard --logdir=logs 命令后,在浏览器中打开相应地址,即可查看可视化图表,直观了解模型训练动态。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/950723.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2025_0105_生活记录

3号去内蒙看了流星雨。还记得上次看流星的时间是2018年,也是冬天,大家在雁栖湖校区的操场上仰望星空。那个时候幸运的看到了一颗流星,便迅速地在心里许愿。这次看到了三颗流星,我也许了愿,希望实现。 24年走过了十多个…

(四)ROS通信编程——服务通信

前言 学完了话题通信其实操作流程基本都已经很熟悉了,因此服务通讯的学习就会流畅许多。 服务通信也是ROS中一种极其常用的通信模式,服务通信是基于请求响应模式的,是一种应答机制。也即: 一个节点A向另一个节点B发送请求,B接收…

《繁星路》V1.8.3(Build16632266)官方中文学习版

《繁星路》官方中文版https://pan.xunlei.com/s/VODae2_2Z3QyMF02I5y321uHA1?pwdqgsh# 作为一款星际模拟游戏,完美融合了硬科幻元素与基地建设玩法,体验改造行星的恢弘与壮阔。化身人工意识AMI,遵照基本指示推进火星改造的各项工作&#xf…

学习threejs,导入wrl格式的模型

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.VRMLLoader wrl模型加…

Element-plus、Element-ui之Tree 树形控件回显Bug问题。

需求&#xff1a;提交时&#xff0c;需要把选中状态和半选中状态 的数据id提交。如图所示&#xff1a; 数据回显时&#xff0c;会出现代码如下&#xff1a; <template><el-tree ref"treeRef" :data"data" show-checkbox node-key"id" …

专家混合(MoE)大语言模型:免费的嵌入模型新宠

专家混合&#xff08;MoE&#xff09;大语言模型&#xff1a;免费的嵌入模型新宠 今天&#xff0c;我们深入探讨一种备受瞩目的架构——专家混合&#xff08;Mixture-of-Experts&#xff0c;MoE&#xff09;大语言模型&#xff0c;它在嵌入模型领域展现出了独特的魅力。 一、M…

cursor试用出现:Too many free trial accounts used on this machine 的解决方法

文章精选推荐 1 JetBrains Ai assistant 编程工具让你的工作效率翻倍 2 Extra Icons&#xff1a;JetBrains IDE的图标增强神器 3 IDEA插件推荐-SequenceDiagram&#xff0c;自动生成时序图 4 BashSupport Pro 这个ides插件主要是用来干嘛的 &#xff1f; 5 IDEA必装的插件&…

若依 ruoyi-vue HandlerInterceptor 拦截器 文件接口自定义权限

文件资源添加自定义权限 package com.huida.framework.config;import com.huida.framework.interceptor.FileInterceptor; import com.huida.framework.interceptor.RequestInterceptor; import org.springframework.beans.factory.annotation.Autowired; import org.springfr…

网络安全-web渗透环境搭建-BWAPP(基础篇)

01--所需系统环境&#xff1a; 虚拟主机系统部署&#xff08;vmware&#xff0c;虚拟主机创建、虚拟主机网络配置&#xff08;桥接&#xff0c;便于网络中多个主机都能访问虚拟主机&#xff09;、虚拟软件功能&#xff0c;快照、克隆、镜像文件加载&#xff0c;ova文件制作&am…

【计算机网络】课程 实验四 配置快速生成树协议(RSTP)

实验四 配置快速生成树协议&#xff08;RSTP&#xff09; 一、实验目的 1&#xff0e;理解快速生成树协议RSTP的工作原理。 2&#xff0e;掌握如何在交换机上配置快速生成树。 二、实验分析与设计 【背景描述】 某学校为了开展计算机教学和网络办公&#xff0c;建立了一个计…

API架构风格的深度解析与选择策略:SOAP、REST、GraphQL与RPC

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

目标检测文献阅读-DETR:使用Transformer进行端到端目标检测

目录 摘要 Abstract 1 引言 2 DETR结构 2.1 Backbone 2.2 Encoder 2.3 Decoder 2.4 FFN 3 目标检测集合预测损失 3.1 二分图匹配损失 3.2 损失函数 总结 摘要 本周阅读的论文题目是《End-to-End Object Detection with Transformers》(使用Transformer进行端到端目…

服务器双网卡NCCL通过交换机通信

1、NCCL变量设置 export CUDA_DEVICE_MAX_CONNECTIONS1 export NCCL_SOCKET_IFNAMEeno2 export NCCL_IB_DISABLE0 #export NCCL_NETIB export NCCL_IB_HCAmlx5_0,mlx5_1 export NCCL_IB_GID_INDEX3 export NCCL_DEBUGINFOGPUS_PER_NODE4MASTER_ADDR192.168.1.2 MASTER_PORT600…

B树及其Java实现详解

文章目录 B树及其Java实现详解一、引言二、B树的结构与性质1、节点结构2、性质 三、B树的操作1、插入操作1.1、插入过程 2、删除操作2.1、删除过程 3、搜索操作 四、B树的Java实现1、节点类实现2、B树类实现 五、使用示例六、总结 B树及其Java实现详解 一、引言 B树是一种多路…

数据分析思维(八):分析方法——RFM分析方法

数据分析并非只是简单的数据分析工具三板斧——Excel、SQL、Python&#xff0c;更重要的是数据分析思维。没有数据分析思维和业务知识&#xff0c;就算拿到一堆数据&#xff0c;也不知道如何下手。 推荐书本《数据分析思维——分析方法和业务知识》&#xff0c;本文内容就是提取…

微信小程序用的SSL证书有什么要求吗?

微信小程序主要建立在手机端使用&#xff0c;然而手机又涉及到各种系统及版本&#xff0c;所以对SSL证书也有要求&#xff0c;如果要小程序可以安全有效的访问需要满足以下要求&#xff1a; 1、原厂SSL证书&#xff08;原厂封&#xff09;。 2、DV单域名或者DV通配符。 3、兼…

手动安装 Maven 依赖到本地仓库

文章目录 手动安装 Maven 依赖到本地仓库1. 下载所需的 JAR 文件2. 安装 JAR 文件到本地仓库3. 验证安装4. 在项目中使用该依赖 手动安装 Maven 依赖到本地仓库 遇到的问题&#xff1a; idea导入一个新的工程&#xff0c;发现pom文件中的一些依赖死活下载不下来&#xff0c;这…

VSCode Live Server 插件安装和使用

VSCode Live Server是一个由Ritwick Dey开发的Visual Studio Code扩展插件&#xff0c;它提供了一个带有实时重载功能的本地开发服务器。在VSCode中安装和使用Live Server插件进行实时预览和调试Web应用程序。这将大大提高前端开发效率&#xff0c;使网页设计和开发变得更为流畅…

UART串口数据分析

串口基础知识详细介绍&#xff1a; 该链接详细介绍了串并行、单双工、同异步、连接方式 https://blog.csdn.net/weixin_43386810/article/details/127156063 该文章将介绍串口数据的电平变化、波特率计算、脉宽计算以及数据传输量的计算。 捕获工具&#xff1a;逻辑分析仪&…

Internet协议原理

文章目录 考试说明Chapter 0: 本书介绍Chapter 1: Introduction And Overview 【第1章&#xff1a;引言与概述】Chapter 2: Overview Of Underlying Network Technologies 【第2章&#xff1a;底层网络技术的回顾】Chapter 3: Internetworking Concept And Architectural Model…