PyTorch 快速入门

我们将通过一个简单的示例,快速了解如何使用 PyTorch 进行机器学习任务。PyTorch 是一个开源的机器学习库,它提供了丰富的工具和库,帮助我们轻松地构建、训练和测试神经网络模型。以下是本教程的主要内容:

一、数据处理

PyTorch 提供了两个基本的数据处理工具:torch.utils.data.DataLoadertorch.utils.data.DatasetDataset 用于存储样本及其对应的标签,而 DataLoader 则为 Dataset 提供了一个可迭代的包装器。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

PyTorch 还提供了许多特定领域的库,如 TorchText、TorchVision 和 TorchAudio,这些库中都包含了各种数据集。在本教程中,我们将使用 TorchVision 中的 FashionMNIST 数据集。每个 TorchVision Dataset 都包含两个参数:transformtarget_transform,分别用于修改样本和标签。

# 下载 FashionMNIST 数据集
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

我们将 Dataset 作为参数传递给 DataLoaderDataLoader 为我们的数据集提供了一个可迭代的包装器,并支持自动批处理、采样、洗牌和多进程数据加载。在这里,我们定义了一个大小为 64 的批次,即数据加载器可迭代对象的每个元素将返回一个包含 64 个特征和标签的批次。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

二、创建模型

在 PyTorch 中,我们通过创建一个继承自 nn.Module 的类来定义神经网络。我们在 __init__ 函数中定义网络的层,并在 forward 函数中指定数据如何通过网络传递。为了加速神经网络中的操作,我们将网络移动到加速器(如 CUDA、MPS、MTIA 或 XPU)上。如果当前加速器可用,我们将使用它;否则,我们将使用 CPU。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

三、优化模型参数

为了训练模型,我们需要一个损失函数和一个优化器。在一个训练循环中,模型会对训练数据集(以批次形式提供)进行预测,并将预测误差反向传播以调整模型的参数。

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

我们还需要检查模型在测试数据集上的性能,以确保它正在学习。

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程会在多个迭代(epoch)中进行。在每个 epoch 中,模型会学习参数以做出更好的预测。我们在每个 epoch 打印模型的准确率和损失;我们希望看到准确率随着每个 epoch 的增加而增加,损失则随着每个 epoch 的增加而减少。

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

四、保存模型

保存模型的一种常见方法是序列化内部状态字典(包含模型参数)。

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

五、加载模型

加载模型的过程包括重新创建模型结构,并将状态字典加载到其中。

model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

现在,我们可以使用这个模型来进行预测。

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

以上就是使用 PyTorch 进行机器学习任务的基本流程。

六、完整代码

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# 下载 FashionMNIST 数据集
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

# 定义神经网络模型
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# 检查是否有可用的加速器
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# 实例化模型并移动到加速器上
model = NeuralNetwork().to(device)
print(model)

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# 定义测试函数
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# 训练和测试模型
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

# 保存模型
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

# 加载模型
model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

# 使用模型进行预测
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962293.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Attention--人工智能领域的核心技术

1. Attention 的全称与基本概念 在人工智能(Artificial Intelligence,AI)领域,Attention 机制的全称是 Attention Mechanism(注意力机制)。它是一种能够动态分配计算资源,使模型在处理输入数据…

DeepSeek能执行程序吗?

1. 前言 大过年的,继续蹭DeepSeek的热点,前面考察了DeepSeek能否进行推理(DeekSeek能否进行逻辑推理),其实似乎没有结论,因为还没有到上难度,DeepSeek似乎就纠结在一些与推理无关的事情上了&am…

5.3.2 软件设计原则

文章目录 抽象模块化信息隐蔽与独立性衡量 软件设计原则:抽象、模块化、信息隐蔽。 抽象 抽象是抽出事物本质的共同特性。过程抽象是指将一个明确定义功能的操作当作单个实体看待。数据抽象是对数据的类型、操作、取值范围进行定义,然后通过这些操作对数…

STM32 TIM编码器接口测速

编码器接口简介: Encoder Interface 编码器接口 编码器接口可接收增量(正交)编码器的信号,根据编码器旋转产生的正交信号脉冲,自动控制CNT自增或自减,从而指示编码器的位置、旋转方向和旋转速度 每个高级定…

四.4 Redis 五大数据类型/结构的详细说明/详细使用( zset 有序集合数据类型详解和使用)

四.4 Redis 五大数据类型/结构的详细说明/详细使用( zset 有序集合数据类型详解和使用) 文章目录 四.4 Redis 五大数据类型/结构的详细说明/详细使用( zset 有序集合数据类型详解和使用)1. 有序集合 Zset(sorted set)2. zset 有序…

Spring AI 在微服务中的应用:支持分布式 AI 推理

1. 引言 在现代企业中,微服务架构 已成为开发复杂系统的主流方式,而 AI 模型推理 也越来越多地被集成到业务流程中。如何在分布式微服务架构下高效地集成 Spring AI,使多个服务可以协同完成 AI 任务,并支持分布式 AI 推理&#x…

使用Ollama和Open WebUI快速玩转大模型:简单快捷的尝试各种llm大模型,比如DeepSeek r1

Ollama本身就是非常优秀的大模型管理和推理组件,再使用Open WebUI更加如虎添翼! Ollama快速使用指南 安装Ollama Windows下安装 下载Windows版Ollama软件:Release v0.5.7 ollama/ollama GitHub 下载ollama-windows-amd64.zip这个文件即可…

EasyExcel写入和读取多个sheet

最近在工作中,作者频频接触到Excel处理,因此也对EasyExcel进行了一定的研究和学习,也曾困扰过如何处理多个sheet,因此此处分享给大家,希望能有所帮助 目录 1.依赖 2. Excel类 3.处理Excel读取和写入多个sheet 4. 执…

《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》

DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance)订阅 已识别 - 已识别问题,并且正在实施修复。 1月 29, 2025 - 20:57 CST 更新 - 我们将继续监控任何其他问题。 1月 28, 2025 - 22&am…

安卓(android)饭堂广播【Android移动开发基础案例教程(第2版)黑马程序员】

一、实验目的(如果代码有错漏,可查看源码) 1.熟悉广播机制的实现流程。 2.掌握广播接收者的创建方式。 3.掌握广播的类型以及自定义官博的创建。 二、实验条件 熟悉广播机制、广播接收者的概念、广播接收者的创建方式、自定广播实现方式以及有…

分享|借鉴传统操作系统中分层内存系统的理念(虚拟上下文管理技术)提升LLMs在长上下文中的表现

《MemGPT: Towards LLMs as Operating Systems》 结论: 大语言模型(LLMs)上下文窗口受限问题的背景下, 提出了 MemGPT,通过类操作系统的分层内存系统的虚拟上下文管理技术, 提升 LLMs 在复杂人物&#…

games101-作业3

由于此次试验需要加载模型,涉及到本地环节,如果是windows系统,需要对main函数中的路径稍作改变: 这么写需要: #include "windows.h" 该段代码: #include "windows.h" int main(int ar…

Spring Boot 日志:项目的“行车记录仪”

一、什么是Spring Boot日志 (一)日志引入 在正式介绍日志之前,我们先来看看上篇文章中(Spring Boot 配置文件)中的验证码功能的一个代码片段: 这是一段校验用户输入的验证码是否正确的后端代码&#xff0c…

【大厂AI实践】OPPO:大规模知识图谱及其在小布助手中的应用

导读:OPPO知识图谱是OPPO数智工程系统小布助手团队主导、多团队协作建设的自研大规模通用知识图谱,目前已达到数亿实体和数十亿三元组的规模,主要落地在小布助手知识问答、电商搜索等场景。 本文主要分享OPPO知识图谱建设过程中算法相关的技…

机器学习周报-文献阅读

文章目录 摘要Abstract 1 相关知识1.1 WDN建模1.2 掩码操作(Masking Operation) 2 论文内容2.1 WDN信息的数据处理2.2 使用所收集的数据构造模型2.2.1 Gated graph neural network2.2.2 Masking operation2.2.3 Training loss2.2.4 Evaluation metrics 2…

工具的应用——安装copilot

一、介绍Copilot copilot是一个AI辅助编程的助手,作为需要拥抱AI的程序员可以从此尝试进入,至于好与不好,应当是小马过河,各有各的心得。这里不做评述。重点在安装copilot的过程中遇到了一些问题,然后把它总结下&…

后盾人JS--闭包明明白白

延伸函数环境生命周期 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> <…

顺启逆停程序

两台电机用Q0.0和Q0.1表示&#xff0c;分别有自身的启动和停止按钮&#xff0c;第一台电机启动后&#xff0c;第二台电机才能启动。停止时&#xff0c;第二台电机停止后&#xff0c;第一台电机才能停止。 1. 按下按钮SB1&#xff0c;接触器KM1线圈得电吸合&#xff0c;主触点…

登录授权流程

发起一个网络请求需要&#xff1a;1.请求地址 2.请求方式 3.请求参数 在检查中找到request method&#xff0c;在postman中设置同样的请求方式将登录的url接口复制到postman中&#xff08;json类型数据&#xff09;在payload中选择view parsed&#xff0c;将其填入Body-raw中 …

CUDA学习-内存访问

一 访存合并 1.1 说明 本部分内容主要参考: 搞懂 CUDA Shared Memory 上的 bank conflicts 和向量化指令(LDS.128 / float4)的访存特点 - 知乎 1.2 share memory结构 图1.1 share memory结构 放在 shared memory 中的数据是以 4 bytes(即 32 bits)作为 1 个 word,依…