基于Pytorch实现图像分类——基于jupyter

分类任务

  • 网络基本构建与训练方法,常用函数解
  • torch.nn.functional模块
  • nn.Module模块

MNIST数据集下载

from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

解压数据集

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

查阅数据

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

在这里插入图片描述

网络模型搭建

在这里插入图片描述

import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

在这里插入图片描述

常用函数介绍

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
bs = 64
bias = torch.zeros(10, requires_grad=True)

print(loss_func(model(xb), yb))

模型搭建

from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out  = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
net = Mnist_NN()
print(net)

Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)

for name, parameter in net.named_parameters():
    print(name, parameter,parameter.size())

dataset数据接口

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)

def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
from torch import optim
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))

def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/542013.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue3中使用webstocket

1.在项目中创建webstocket.ts文件 export default class SocketService {// 单例static instance null;static get Instance() {if (!this.instance) {this.instance new SocketService();}return this.instance;}// 和服务端连接的socket对象ws null;// 存储回调函数callB…

202206青少年软件编程(scratch图形化) 等级考试试卷(四级)

第1题:【 单选题】 执行下列程序, 说的内容是? ( ) A:使 B:命 C:初 D:心 【正确答案】: D 【试题解析】 : 注意标点符号也是一个字符, 连接后字符串是“牢记使命! 不忘初心, …

宝藏免费音乐软件LX music

欢迎来到我的博客,代码的世界里,每一行都是一个故事 宝藏免费音乐软件LX music 前言LX Music的特色功能:音乐播放的新境界安装与配置:在不同平台上使用LX Music下载页面 主题定制 本文将深入研究LX Music,一款备受欢迎…

pytorch车牌识别

目录 使用pytorch库中CNN模型进行图像识别收集数据集定义CNN模型卷积层池化层全连接层 CNN模型代码使用模型 使用pytorch库中CNN模型进行图像识别 收集数据集 可以去找开源的数据集或者自己手做一个 最终整合成 类别分类的图片文件 定义CNN模型 卷积层 功能:提…

opencv基础图行展示

"""试用opencv创建画布并显示矩形框(适用于目标检测图像可视化) """ # 创建一个黑色的画布,图像格式(BGR) img np.zeros((512, 512, 3), np.uint8)# 画一个矩形:给定左上角和右下角坐标&#xff0…

Redis入门到通关之Hash命令

文章目录 ⛄介绍⛄命令⛄RedisTemplate API❄️❄️添加缓存❄️❄️设置过期时间(单独设置)❄️❄️添加一个Map集合❄️❄️提取所有的小key❄️❄️提取所有的value值❄️❄️根据key提取value值❄️❄️获取所有的键值对集合❄️❄️删除❄️❄️判断Hash中是否含有该值 ⛄…

文献阅读:猕猴的单个基底外侧杏仁核神经元表现出与额叶皮层不同的连接模式

文献介绍 「文献题目」 Single basolateral amygdala neurons in macaques exhibit distinct connectional motifs with frontal cortex 「研究团队」 Peter H. Rudebeck(美国西奈山伊坎医学院) 「发表时间」 2023-10-18 「发表期刊」 Neuron 「影响因…

Springboot+Vue项目-基于Java+MySQL的母婴商城系统(附源码+演示视频+LW)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &…

Ubuntu去除烦人的顶部【活动】按钮

文章目录 一、需求说明二、打开 extensions 网站三、安装 GNOME Shell 插件四、安装本地连接器五、安装 Hide Activities Button 插件六、最终效果七、卸载本地连接器命令参考 本文所使用的 Ubuntu 系统版本是 Ubuntu 22.04 ! 一、需求说明 使用 Ubuntu 的过程中,屏…

【大语言模型】应用:10分钟实现搜索引擎

本文利用20Newsgroup这个数据集作为Corpus(语料库),用户可以通过搜索关键字来进行查询关联度最高的News,实现对文本的搜索引擎: 1. 导入数据集 from sklearn.datasets import fetch_20newsgroupsnewsgroups fetch_20newsgroups()print(fNu…

在Linux驱动中,如何确保中断上下文的正确保存和恢复?

大家好,今天给大家介绍在Linux驱动中,如何确保中断上下文的正确保存和恢复?,文章末尾附有分享大家一个资料包,差不多150多G。里面学习内容、面经、项目都比较新也比较全!可进群免费领取。 在Linux驱动中&am…

AI图书推荐:如何在课堂上使用ChatGPT 进行教育

ChatGPT是一款强大的新型人工智能,已向公众免费开放。现在,各级别的教师、教授和指导员都能利用这款革命性新技术的力量来提升教育体验。 本书提供了一个易于理解的ChatGPT解释,并且更重要的是,详述了如何在课堂上以多种不同方式…

STM32利用软件I2C通讯读MPU6050的ID号

今天的读ID号是建立在上篇文章中有了底层的I2C通讯的6个基本时序来编写的。首先需要完成的就是MPU6050的初始化函数 然后就是编写 指定地址写函数: 一:开始 二:发送 从机地址读写位(1:读 0&#xff1…

MySQL之索引失效、覆盖、前缀索引及单列、联合索引详细总结

索引失效 最左前缀法则 如果索引了多列(联合索引),要遵守最左前缀法则,最左前缀法则指的是查询从索引的最左列开始,并且不跳过索引中的列。如果跳跃某一列,索引将部分失效(后面的字段索引失效)。 联合索…

第1章 计算机网络体系结构

王道学习 【考纲内容】 (一)计算机网络概述 计算机网络的概念、组成与功能;计算机网络的分类; 计算机网络的性能指标 (二)计算机网络体系结构与参考模型 计算机网络分层结…

权威Scrum敏捷开发企业级实训/敏捷开发培训课程

课程简介 Scrum是目前运用最为广泛的敏捷开发方法,是一个轻量级的项目管理和产品研发管理框架。 这是一个两天的实训课程,面向研发管理者、项目经理、产品经理、研发团队等,旨在帮助学员全面系统地学习Scrum和敏捷开发, 帮助企业快速启动敏…

Ansys Workbench拓扑优化教程

很基础。 前言 进行拓扑优化的好处在于可以简化结构,满足力学性能的同时简化结构。 如赵州桥的一大一小的拱,就可以用拓扑优化优化出来,可见一千四百多年以前古人的智慧是多么丰富。 步骤 大体的步骤是需要 1.先导入模型(需…

练习题(2024/4/13)

1长度最小的子数组 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其总和大于等于 target 的长度最小的 连续 子数组 [numsl, numsl1, ..., numsr-1, numsr] ,并返回其长度。如果不存在符合条件的子数组,返回 0 。 示例 1&am…

ThreadX:怎么确定一个线程应该开多少内存

ThreadX:如何确定线程的大小 在实时操作系统(RTOS)ThreadX中,线程的大小是一个重要的参数。这个参数决定了线程的堆栈大小,也就是线程可以使用的内存空间。那么,我们应该如何确定一个线程需要多大的字节呢…

Asterisk 21.2.0编译安装经常遇到的问题和解决办法之卸载pjsip

目录 会安装也要会卸载make uninstallldconfig 会安装也要会卸载 有些人就只会装。 最常见的场景就是需要卸载之前版本的pjproject。 一般来说,其他版本的 pjproject 会被作为静态链接库安装。这些库跟 Asterisk可能不兼容。 因此,在安装正确版本的pjpro…