什么是Pytorch?

在这里插入图片描述

当谈及深度学习框架时,PyTorch 是当今备受欢迎的选择之一。作为一个开源的机器学习库,PyTorch 为研究人员和开发者们提供了一个强大的工具来构建、训练以及部署各种深度学习模型。你可能会问,PyTorch 是什么,它有什么特点,以及如何使用它呢?

什么是 PyTorch?

PyTorch 是一个基于 Python 的机器学习库,专注于强大的张量计算(tensor computation)和动态计算图(dynamic computation graph)。与其他框架相比,它的一个显著特点就是动态计算图,这意味着你可以在运行时定义和修改计算图,从而更灵活地构建复杂的模型。PyTorch 由 Facebook 的人工智能研究小组开发,已经得到了广泛的认可和采用。

PyTorch 的特点

  1. 动态计算图: PyTorch 的动态计算图使得模型构建和调试变得更加直观。你可以像编写 Python 代码一样编写神经网络结构,而不需要事先定义静态图。

  2. 张量操作: PyTorch 提供了丰富的张量操作功能,它们类似于 NumPy 数组,但是可以在 GPU 上运行以加速计算,适用于大规模的数据处理和深度学习任务。

  3. 自动求导: PyTorch 自动处理了求导过程,无需手动计算梯度。这使得训练模型变得更加方便和高效。

  4. 模块化设计: PyTorch 的模块化设计使得构建复杂的神经网络变得简单。你可以通过组合不同的模块来创建自己的模型。

如何使用 PyTorch?

让我们通过一个简单的示例来看看如何使用 PyTorch 来构建一个基本的神经网络:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的神经网络类
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 创建神经网络实例、损失函数和优化器
net = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)

# 加载数据并进行训练
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss}")
print("Finished Training")

分析环节:

可能会有很多小伙伴不明白,我会进行整个代码的详细分析,逐行解释每个部分的作用和功能。

import torch
import torch.nn as nn
import torch.optim as optim

这部分代码导入了PyTorch库的必要模块,包括torchtorch.nn以及torch.optimtorch是PyTorch的核心模块,提供了张量等基本数据结构和操作;torch.nn提供了神经网络相关的类和函数;torch.optim提供了各种优化器,用于更新神经网络的参数。

# 定义一个简单的神经网络类
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

这部分定义了一个简单的神经网络类SimpleNN,该类继承自nn.Module,是PyTorch中自定义神经网络的一种标准做法。网络有两个全连接层(线性层):fc1fc2forward方法定义了前向传播过程,首先通过fc1进行线性变换,然后使用ReLU激活函数,最后通过fc2输出。

# 创建神经网络实例、损失函数和优化器
net = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)

在这部分,我们实例化了刚刚定义的SimpleNN类,创建了一个神经网络netnn.CrossEntropyLoss()是交叉熵损失函数,适用于多类别分类问题。optim.SGD是随机梯度下降优化器,用于更新网络的权重和偏置。

# 加载数据并进行训练
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss}")
print("Finished Training")

这部分是训练过程的主体。我们使用一个外层循环进行多次训练迭代(5次),每次迭代中,我们遍历训练数据集,计算并更新网络的参数。

  • for epoch in range(5)::外层循环迭代5次,表示5个训练轮次。

  • running_loss = 0.0:用于记录每个训练轮次的累计损失。

  • for i, data in enumerate(trainloader, 0)::遍历训练数据集。enumerate函数用于同时获取数据的索引i和数据本身data

  • inputs, labels = data:将数据拆分为输入和标签。

  • optimizer.zero_grad():清零梯度,准备进行反向传播。

  • outputs = net(inputs):将输入数据输入神经网络,得到输出。

  • loss = criterion(outputs, labels):计算输出和真实标签之间的损失。

  • loss.backward():进行反向传播,计算梯度。

  • optimizer.step():使用优化器更新网络的参数。

  • running_loss += loss.item():累计损失。

  • print(f"Epoch {epoch+1}, Loss: {running_loss}"):打印每个轮次的训练损失。

  • print("Finished Training"):训练完成后打印提示。

整个代码实现了对一个简单的神经网络的训练过程,通过反向传播更新网络参数,使得模型能够逐渐拟合训练数据,从而实现分类任务。

案例分析

我们要说个典型案例:使用 PyTorch 进行图像分类。通过构建神经网络模型、加载数据集、定义损失函数和优化器,可以训练出一个能够识别不同类别的图像的分类器。

我们将创建了一个卷积神经网络(CNN)模型,加载CIFAR-10数据集,通过定义损失函数和优化器,进行模型的训练。这个模型可以用来对CIFAR-10数据集中的图像进行分类,识别不同的物体类别。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 步骤 2:加载和预处理数据集
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# 使用 torchvision 加载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# 创建一个 DataLoader,用于批量加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

# 步骤 3:定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 输入通道数为3,输出通道数为6,卷积核大小为5x5
        self.pool = nn.MaxPool2d(2, 2)  # 最大池化,窗口大小为2x2
        self.conv2 = nn.Conv2d(6, 16, 5)  # 输入通道数为6,输出通道数为16,卷积核大小为5x5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 全连接层,输入维度为16x5x5,输出维度为120
        self.fc2 = nn.Linear(120, 84)  # 全连接层,输入维度为120,输出维度为84
        self.fc3 = nn.Linear(84, 10)  # 全连接层,输入维度为84,输出维度为10(类别数)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 使用ReLU激活函数
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # 将张量展平,以适应全连接层
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建神经网络实例
net = Net()

# 步骤 4:定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适用于分类问题
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # 使用随机梯度下降进行优化

# 步骤 5:训练神经网络模型
for epoch in range(2):  # 进行两个 epoch 的训练
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()  # 梯度归零,防止累加
        outputs = net(inputs)  # 前向传播,得到预测结果
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播,计算梯度
        optimizer.step()  # 更新参数
        running_loss += loss.item()  # 累加损失
        if i % 2000 == 1999:
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")  # 打印损失
            running_loss = 0.0
print("Finished Training")  # 训练完成

案例通过加载 CIFAR-10 数据集,构建一个简单的卷积神经网络,定义损失函数和优化器,并进行模型训练。训练过程中,我们采用了随机梯度下降(SGD)优化算法,使用交叉熵损失函数来优化分类任务。每个 epoch 的训练过程会在控制台输出损失值,以便我们监控训练的进展情况。

总结而言,PyTorch 是一个功能强大且易用的深度学习框架,适用于各种机器学习和深度学习任务。它的动态计算图、张量操作和自动求导等特性使得模型的构建和训练变得更加高效和灵活。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/81701.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springboot、java实现调用企业微信接口向指定用户发送消息

因为项目的业务逻辑需要向指定用户发送企业微信消息&#xff0c;所以在这里记录一下 目录 引入相关依赖创建配置工具类创建发送消息类测试类最终效果 引入相关依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-…

【Python机器学习】实验14 手写体卷积神经网络(PyTorch实现)

文章目录 LeNet-5网络结构&#xff08;1&#xff09;卷积层C1&#xff08;2&#xff09;池化层S1&#xff08;3&#xff09;卷积层C2&#xff08;4&#xff09;池化层S2&#xff08;5&#xff09;卷积层C3&#xff08;6&#xff09;线性层F1&#xff08;7&#xff09;线性层F2 …

数据可视化-canvas-svg-Echarts

数据可视化 技术栈 canvas <canvas width"300" height"300"></canvas>当没有设置宽度和高度的时候&#xff0c;canvas 会初始化宽度为 300 像素和高度为 150 像素。切记不能通过样式去设置画布的宽度与高度宽高必须通过属性设置&#xff0c;…

Gateway网关路由以及predicates用法(项目中使用场景)

1.Gatewaynacos整合微服务 服务注册在nacos上&#xff0c;通过Gateway路由网关配置统一路由访问 这里主要通过yml方式说明&#xff1a; route: config: #type:database nacos yml data-type: yml group: DEFAULT_GROUP data-id: jeecg-gateway-router 配置路由&#xff1a;…

Liunx系统编程:进程信号的概念及产生方式

目录 一. 进程信号概述 1.1 生活中的信号 1.2 进程信号 1.3 信号的查看 二. 信号发送的本质 三. 信号产生的四种方式 3.1 按键产生信号 3.2 通过系统接口发送信号 3.2.1 kill -- 向指定进程发送信号 3.2.2 raise -- 当自身发送信号 3.2.3 abort -- 向自身发送进程终止…

使用 Elasticsearch 轻松进行中文文本分类

本文记录下使用 Elasticsearch 进行文本分类&#xff0c;当我第一次偶然发现 Elasticsearch 时&#xff0c;就被它的易用性、速度和配置选项所吸引。每次使用 Elasticsearch&#xff0c;我都能找到一种更为简单的方法来解决我一贯通过传统的自然语言处理 (NLP) 工具和技术来解决…

基于Python的HTTP代理爬虫开发初探

前言 随着互联网的发展&#xff0c;爬虫技术已经成为了信息采集、数据分析的重要手段。然而在进行爬虫开发的过程中&#xff0c;由于个人或机构的目的不同&#xff0c;也会面临一些访问限制或者防护措施。这时候&#xff0c;使用HTTP代理爬虫可以有效地解决这些问题&#xff0…

麦肯锡发布《2023科技趋势展望报告》,生成式AI、下一代软件开发成为趋势,软件测试如何贴合趋势?

近日&#xff0c;麦肯锡公司发布了《2023科技趋势展望报告》。报告列出了15个趋势&#xff0c;并把他们分为5大类&#xff0c;人工智能革命、构建数字未来、计算和连接的前沿、尖端工程技术和可持续发展。 类别一&#xff1a;人工智能革命 生成式AI 生成型人工智能标志着人工智…

元宇宙电商—NFG系统:区块链技术助力商品确权。

在国内&#xff0c;以“数字藏品”之名崛起以来&#xff0c;其与NFT的对比就从未停歇。从上链模式到数据主权&#xff0c;从炒作需求到实际应用&#xff0c;从售卖形式到价值属性&#xff0c;在各种抽丝剥茧般的比较中&#xff0c;围绕两者孰优孰劣的讨论不绝于耳。 NFT的每一…

机器学习知识点总结:什么是EM(最大期望值算法)

什么是EM(最大期望值算法) 在现实生活中&#xff0c;苹果百分百是苹果&#xff0c;梨百分白是梨。 生活中还有很多事物是概率分布&#xff0c;比如有多少人结了婚&#xff0c;又有多少人有工作&#xff0c; 如果我们想要调查人群中吸大麻者的比例呢&#xff1f;敏感问题很难得…

React如何配置env环境变量

React版本&#xff1a; "react": "^18.2.0" 1、在package.json平级目录下创建.env文件 2、在‘.env’文件里配置环境变量 【1】PUBLIC_URL 描述&#xff1a;编译时文件的base-href 官方描述&#xff1a; // We use PUBLIC_URL environment variable …

解决C#报“MSB3088 未能读取状态文件*.csprojAssemblyReference.cache“问题

今天在使用vscode软件C#插件&#xff0c;编译.cs文件时&#xff0c;发现如下warning: 图(1) C#报cache没有更新 出现该warning的原因&#xff1a;当前.cs文件修改了&#xff0c;但是其缓存文件*.csprojAssemblyReference.cache没有更新&#xff0c;需要重新清理一下工程&#x…

clickhouse-监控配置

一、概述 监控是运维的一大利器&#xff0c;要想运维好clickhouse,首先就要对其进行监控&#xff0c;clickhouse有几种监控数据的方式&#xff0c;一种是系统本身监控&#xff0c;一种是通过exporter来监控&#xff0c;下面分别描述一下 二、系统自带监控 我下面会对监控做一…

三角形添加数--夏令营

题目 tips&#xff1a; 1.本题不要求正三角形输出&#xff0c;只要输出左下三角即可 2.这种输入三角形的&#xff0c;都是可以理解为左下三角形的模型&#xff0c;然后去写f[i][j]f[i-1][j]f[i-1][j1]&#xff0c;写行列 3.还有双重for循环输入输出三角形&#xff0c;注意第二…

linux 搭建 nexus maven私服

目录 环境&#xff1a; 下载 访问百度网盘链接 官网下载 部署 &#xff1a; 进入目录&#xff0c;创建文件夹,进入文件夹 将安装包放入nexus文件夹&#xff0c;并解压​编辑 启动 nexus,并查看状态.​编辑 更改 nexus 端口为7020,并重新启动&#xff0c;访问虚拟机7020…

【Spring专题】Spring之Bean的生命周期源码解析——阶段二(二)(IOC之属性填充/依赖注入)

目录 前言阅读准备阅读指引阅读建议 课程内容一、依赖注入方式&#xff08;前置知识&#xff09;1.1 手动注入1.2 自动注入1.2.1 XML的autowire自动注入1.2.1.1 byType&#xff1a;按照类型进行注入1.2.1.2 byName&#xff1a;按照名称进行注入1.2.1.3 constructor&#xff1a;…

如何解决使用npm出现Cannot find module ‘XXX\node_modules\npm\bin\npm-cli.js’错误

遇到问题&#xff1a;用npm下载组件时出现Cannot find module ‘D&#xff1a;software\node_modules\npm\bin\npm-cli.js’ 问题&#xff0c;导致下载组件不能完成。 解决方法&#xff1a;下载缺少的npm文件即可解决放到指定node_modules目录下即可解决。 分析问题&#xff1…

【自创】关于前端js的“嵌套地狱”的遍历算法

欢迎大家关注我的CSDN账号 欢迎大家关注我的哔哩哔哩账号&#xff1a;卢淼儿的个人空间-卢淼儿个人主页-哔哩哔哩视频 此saas系统我会在9月2号之前&#xff0c;在csdn及哔哩哔哩上发布成套系列教学视频。敬请期待&#xff01;&#xff01;&#xff01; 首先看图 这是我们要解…

Unity进阶–通过PhotonServer实现联网登录注册功能(客户端)–PhotonServer(三)

文章目录 Unity进阶–通过PhotonServer实现联网登录注册功能(客户端)–PhotonServer(三)前情提要客户端部分 Unity进阶–通过PhotonServer实现联网登录注册功能(客户端)–PhotonServer(三) 前情提要 单例泛型类 using System.Collections; using System.Collections.Generic; …

探索高级UI、源码解析与性能优化,了解开源框架及Flutter,助力Java和Kotlin筑基,揭秘NDK的魅力!

课程链接&#xff1a; 链接: https://pan.baidu.com/s/13cR0Ip6lzgFoz0rcmgYGZA?pwdy7hp 提取码: y7hp 复制这段内容后打开百度网盘手机App&#xff0c;操作更方便哦 --来自百度网盘超级会员v4的分享 课程介绍&#xff1a; &#x1f4da;【01】Java筑基&#xff1a;全方位指…