Pytorch教程(代码逐行解释)

0、配准环境教程

1、开始导入相应的包

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

torch是pytorch的简写
torch.utils.data import DataLoader 是用于读取数据的迭代器
torchvision是视觉处理包,datasets导入的是视觉相关的数据集
transforms 是用于图像变换的。

2、下载数据集(准备数据集)

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

datasets.FashionMNIST,指的是一个数据集,这个数据集用于服饰的识别。FashionMNIST是一个非常流行的图像分类数据集,其中包含10个类别的70000个28x28灰度图像。
当然,pytorch还有很多其他的数据集格式。例如以下的数据集。其他数据集可点击这个连接在这里插入图片描述

3、加载数据集

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

DataLoader是PyTorch中一个非常有用的模块,它主要用于批量加载数据,特别是当数据集非常大时,DataLoader可以极大地提高数据加载速度并减少内存占用。
DataLoader的主要功能包括:
批量处理数据:DataLoader可以将数据划分为多个批次(batch),每个批次包含一定数量的数据样本,然后一次处理一个批次的数据,这样可以大大减少内存占用。
数据打乱:通过设置shuffle=True参数,DataLoader可以在每个epoch开始时随机打乱数据集的顺序,这样可以增加模型的泛化能力。
batch_size 指的是每次读取的数据的大小,这里设置一次读取64张

4、创建训练的模型

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

super().init()表示调用父类(nn.Module)的 init() 方法
self.flatten = nn.Flatten(),这行代码的作用主要是在神经网络模型中的作用是将输入数据从多维(例如二维或三维)转化为一维,这个操作通常被称为"flatten"。
在这个例子中,该模型预期的输入是一个形状为[batch_size, 28, 28]的张量,即一个包含多个(这里是28*28=784个)特征值的数据集。nn.Flatten()层将这个三维数据转化为一维数组,以便后续的线性层(nn.Linear)能以更高效的方式进行操作。

nn.Sequential 是 PyTorch 中一个用于创建顺序神经网络模型的模块。它是一个有序的容器,可以包含任意数量的其他模块。当你将数据输入到 nn.Sequential 模型时,数据会按照你在容器中定义的顺序通过每个模块。

5、设置优化器以及损失函数

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

损失函数还有很多种,其他的参考点击这个链接
优化器也有很多种,如ASGD,ADAM等等,其他的参考这个链接

6、模型的训练

定义训练的过程

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

从数据集中,每次取一个图像个标签进行训练,然后反向传播,梯度优化,完成训练。
item():.item()是用来从张量中提取标量值的方法。当你调用.item()方法时,如果张量中只有一个元素,那么这个元素会被返回;如果张量中有多个元素,则会抛出一个错误。

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

correct += (pred.argmax(1) == y).type(torch.float).sum().item():解释:
(pred.argmax(1) == y):首先,这行代码通过argmax(1)获取了每个样本的预测类别。然后,它将预测类别与真实类别进行比较(==)。这将返回一个布尔型的张量,表示每个样本的预测是否正确。
(pred.argmax(1) == y).type(torch.float):接下来,这行代码将布尔型的张量转换为浮点型。在PyTorch中,布尔型的张量会自动转换为浮点型。
(pred.argmax(1) == y).type(torch.float).sum():然后,这行代码计算了所有样本中预测正确的总数。这是通过调用sum()函数实现的,该函数会返回一个张量中所有元素的和。
correct += …:最后,这行代码将预测正确的总数加到了变量correct上。+=是一个累加操作符,它将左侧的变量与右侧的表达式结果相加。

7、定义训练的轮次

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

8、保存模型

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

model.state_dict():解释:
model.state_dict()函数返回一个包含模型所有参数的字典,torch.save()函数则将这个字典保存到磁盘上的一个文件。

9、加载模型

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))

10、模型的测试

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

所有的完整代码:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/145649.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

抖斗音_快块手直播间获客助手+采集脚本+引流软件功能介绍

软件功能: 支持同时采集多个直播间,弹幕,关*注,礼*物,进直播间,部分用户手*号,粉*丝团采集 不支持采集匿*名直播间 设备需求: 电脑(win10系统) 文章分享者&#xff1…

线程有哪些状态

线程的生命周期 线程在Java中有以下几种状态: 新建(New):初始化状态就绪(Runnable):可运行、运行状态阻塞(Blocked):等待状态,无时限等待&#…

【k8s集群搭建(一):基于虚拟机的linux的k8s集群搭建_超详细_解决并记录全过程步骤以及自己的踩坑记录】

虚拟机准备3台Linux系统 k8s集群安装 每一台机器需要安装以下内容: docker:容器运行环境 kubelet:控制机器中所有资源 bubelctl:命令行 kubeladm:初始化集群的工具 Docker安装 安装一些必要的包,yum-util 提供yum-config-manager功能,另两…

dalle3:Improving image generation with better captions

文生图——DALL-E 3 —论文解读——第一版-CSDN博客文章浏览阅读236次。本文主要是DALLE 3官方第一版技术报告(论文)的解读。 一句话省流版,数据方面,训练时使用95%模型(CoCa)合成详细描述caption 5%原本人…

【Python】【应用】Python应用之一行命令搭建http、ftp服务器

🐚作者简介:花神庙码农(专注于Linux、WLAN、TCP/IP、Python等技术方向)🐳博客主页:花神庙码农 ,地址:https://blog.csdn.net/qxhgd🌐系列专栏:Python应用&…

R语言——taxize(第一部分)

ropensci 系列之 taxize (中译手册) taxize 包1. taxize支持的网络数据源简介目前支持的API:针对Catalogue of Life(COL) 2. 浅尝 taxize 的一些使用例子2.1. **从NCBI上获取唯一的分类标识符**2.2. **获取分类信息**2…

【LeetCode刷题-滑动窗口】-- 643.子数组最大平均数I

643.子数组最大平均数I 方法&#xff1a;滑动窗口 class Solution {public double findMaxAverage(int[] nums, int k) {int n nums.length;int winSum 0;//先求出第一个窗口的和for(int i 0;i<k;i){winSum nums[i];}//通过遍历求出除了第一窗口的和int res winSum;fo…

算法通关村第十六关青铜挑战——原来滑动窗口如此简单!

大家好&#xff0c;我是怒码少年小码。 从本篇开始&#xff0c;我们就要开始算法的新篇章了——四大思想&#xff1a;滑动窗口、贪心、回溯、动态规划。现在&#xff0c;向我们迎面走来的是——滑动窗口思想&#xff01;&#x1f61d; 滑动窗口思想 概念 在数组双指针里&am…

SQL使用

--天空会的像哭过&#xff0c;离开你以后 并没有更自由 SQL进行数据的删除 一、删除delete 语法 delete [from] 表名称 where 条件数据删除&#xff0c;不能删除某一列&#xff0c;因为删除是对记录而言 2.1 删除是一条一条删除&#xff0c;每次删除都会将操作写入日志文件 删…

记录第一次利用CVE-2023-33246漏洞实现RocketMQ宿主机远程代码执行的兴奋

我依然记得自己第一次发现xss漏洞时候的兴奋: 我也记得自己第一次发现sql输入时候的快乐: 直到最近我终于收获了人生的第一个远程代码执行漏洞的利用&#xff08;RCE:remote code execute&#xff09;&#xff0c;虽然这个漏洞的危害远超过了前两个&#xff0c;但是快乐不如前…

RocketMQ(二):基础API

Spring源码系列文章 RocketMQ(一)&#xff1a;基本概念和环境搭建 RocketMQ(二)&#xff1a;基础API 目录 一、RocketMQ快速入门1、生产者发送消息2、消费者接受消息3、代理者位点和消费者位点 二、消费模型特点1、同一个消费组的不同消费者&#xff0c;订阅主题必须相同2、不…

伊朗黑客对以色列科技行业发起恶意软件攻击

最近&#xff0c;安全研究人员发现了一场由“Imperial Kitten”发起的新攻击活动&#xff0c;目标是运输、物流和科技公司。 “Imperial Kitten”又被称为“Tortoiseshell”、“TA456”、“Crimson Sandstorm”和“Yellow Liderc”&#xff0c;多年来一直使用“Marcella Flore…

加密磁盘密钥设置方案浅析 — LUKS1

虚拟化加密磁盘密钥设置方案浅析 前言元数据分析元数据格式整体格式头部格式加密算法密码校验key slot格式其它字段 流程验证 前言 我们在虚拟化加密磁盘密钥设置方案浅析 — TKS1中介绍了加密磁盘密钥设置方案&#xff0c;TKS1对密钥设置(Linux Unified Key Setup)的流程和方…

模拟散列表(哈希表拉链法)

维护一个集合&#xff0c;支持如下几种操作&#xff1a; I x&#xff0c;插入一个整数 x&#xff1b;Q x&#xff0c;询问整数 x 是否在集合中出现过&#xff1b; 现在要进行 N 次操作&#xff0c;对于每个询问操作输出对应的结果。 输入格式 第一行包含整数 N&#xff0c;…

举报“将我的电脑控作己用者”!

既然“麻辣800727”都说是“街子电信”干的&#xff0c;那么&#xff0c;我现在就正式举报&#xff1a;请依法管理宽带网&#xff0c;你国营的也不可以随意侵犯用户的人权&#xff0c;更不可以将自己变成法外之地&#xff01; 请公开答复&#xff0c;并改正&#xff0c;否则把…

机器学习线性代数知识补充

线性代数知识补充 正交矩阵与正交变换方阵特征值与特征向量相似矩阵对角化二次型正定二次型 正交矩阵与正交变换 方阵特征值与特征向量 相似矩阵 对角化 二次型 正定二次型

H5游戏源码分享-超级染色体小游戏

H5游戏源码分享-超级染色体小游戏 游戏玩法 不断地扩大发展同颜色的色块 用最少的步数完成游戏 <!DOCTYPE html> <html><head><meta charset"UTF-8"><meta name"viewport"content"widthdevice-width,user-scalableno,init…

应届裁员,天胡开局——谈谈我的前端一年经历

应届裁员&#xff0c;天胡开局——谈谈我的前端一年经历 许久没有更新了&#xff0c;最近一个月都在忙&#xff0c;没错&#xff0c;正如题目所说&#xff0c;裁员然后找工作… 这周刚重新上班&#xff0c;工作第二天&#xff0c;感慨良多&#xff0c;记录些什么吧。 去年十…

学习samba

文章目录 一、samba介绍二、samba的主要进程三、配置文件四、例子 一、samba介绍 1、SMB&#xff08;Server Message Block&#xff09;协议实现文件共享&#xff0c;也称为CIFS&#xff08;Common Internet File System&#xff09;。 2、是Windows和类Unix系统之间共享文件的…

【Linux】gitee仓库的注册使用以及在Linux上远程把代码上传到gitee上的方法

君兮_的个人主页 即使走的再远&#xff0c;也勿忘启程时的初心 C/C 游戏开发 Hello,米娜桑们&#xff0c;这里是君兮_&#xff0c;今天为大家介绍一个在实际工作以及项目开发过程中非常实用的网站gitee&#xff0c;并教如何正确的使用这个网站以及常见问题的解决方案&#xf…