深度学习实战模拟——softmax回归(图像识别并分类)

目录

1、数据集:

2、完整代码


1、数据集:

1.1 Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

1.2 Fashion‐MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像和测试数据 集(test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。测试数据集 不会用于训练,只用于评估模型性能。

以下函数用于在数字标签索引及其文本名称之间进行转换。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

2、完整代码

import torch
import torchvision
import pylab
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt
from d2l import torch as d2l
import time

batch_size = 256
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
num_epochs = 5


class Accumulator:
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())


def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(len(y_hat)), y])


def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp/partition


def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)


def get_dataloader_workers():
    """使用一个进程来读取的数据"""
    return 1


def get_fashion_mnist_labels(labels):
    """返回Fashion-MNIST数据集的文本标签"""
    #共10个类别
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """画一系列图片"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    for i, (img, label) in enumerate(zip(imgs, titles)):
        xloc, yloc = i//num_cols, i % num_cols
        if torch.is_tensor(img):
            # 图片张量
            axes[xloc, yloc].imshow(img.reshape((28, 28)).numpy())
        else:
            # PIL图片
            axes[xloc, yloc].imshow(img)
        # 设置标题并取消横纵坐标上的刻度
        axes[xloc, yloc].set_title(label)
        plt.xticks([], ())
        axes[xloc, yloc].set_axis_off()
    pylab.show()


def load_data_fashion_mnist(batch_size, resize=None):
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = transforms.ToTensor()
    if resize:
        trans.insert(0, transforms.Resize(resize))
    mnist_train = torchvision.datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))


def evaluate_accuracy(net, data_iter):
    """计算在指定数据集上模型的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]


def updater(batch_size):
    lr = 0.1
    return d2l.sgd([W, b], lr, batch_size)


def train_epoch_ch3(net, train_iter, loss, updater):
    if isinstance(net, torch.nn.Module):
        net.train()
    metric = Accumulator(3)
    for X, y in train_iter:
        y_hat = net(X)
        lo = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            lo.backward()
            updater.step()
            metric.add(float(lo)*len(y), accuracy(y_hat, y), y.size().numel())
        else:
            lo.sum().backward()
            updater(X.shape[0])
        metric.add(float(lo.sum()), accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]


class Animator:  #@save
    """绘制数据"""
    def __init__(self, legend=None):
        self.legend = legend
        self.X = [[], [], []]
        self.Y = [[], [], []]

    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)

    def show(self):
        plt.plot(self.X[0], self.Y[0], 'r--')
        plt.plot(self.X[1], self.Y[1], 'g--')
        plt.plot(self.X[2], self.Y[2], 'b--')
        plt.legend(self.legend)
        plt.xlabel('epoch')
        plt.ylabel('value')
        plt.title('Visual')
        plt.show()


def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """训练模型"""
    animator = Animator(legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        train_loss, train_acc = train_metrics
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
        print(f'epoch: {epoch+1},train_loss:{train_loss:.4f}, train_acc:{train_acc:.4f}, test_acc:{test_acc:.4f}')
    animator.show()


def predict_ch3(net, test_iter, n=12):
    """预测标签"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    show_images(
        X[0:n].reshape((n, 28, 28)), 2, int(n/2), titles=titles[0:n])


if __name__ == '__main__':
    train_iter, test_iter = load_data_fashion_mnist(batch_size)
    train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
    predict_ch3(net, test_iter)

分类效果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/468124.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PSCA系统控制集成之复位层次结构

PPU 提供以下对复位控制的支持。 • 复位信号Reset signals:PPU 提供冷复位和热复位输出信号。PPU 还为实现部分保留的电源域管理提供了额外的热复位输出信号。 • 电源模式控制Power mode control:PPU 硬件适当地管理每个支持的电源模式转换的复位信号…

高端资源素材源码图库下载平台整站系统,附带整站源码

推荐高端图库素材下载站的响应式模板和完整的整站源码,适用于娱乐网资源网。 该模板支持移动端,并集成了支付宝接口。 页面设计精美,不亚于大型网站的美工水准,并且用户体验也非常人性化。主要用户分为两类:下载用户…

Python大数据实践:selenium爬取京东评论数据

准备工作 selenium安装 Selenium是广泛使用的模拟浏览器运行的库,用于Web应用程序测试。 Selenium测试直接运行在浏览器中,就像真正的用户在操作一样,并且支持大多数现代 Web 浏览器。 #终端pip安装 pip install selenium #清华镜像安装 p…

探索编程迷宫:选择你的职业赛道

在现代科技的浪潮中,程序员的职业赛道就像是一座迷宫,充满着前端的美丽花园,后端的黑暗洞穴,以及数据科学的神秘密室。这个迷宫中,每一条通道都充满了挑战和机遇,而每一个行走其中的人都在寻找着属于自己的…

<商务世界>《第15课 投标文件一般包含的子文件》

1 响应文件封面 招标文件中的响应文件封面是投标人或参与者在提交响应文件时所使用的封面设计。这个封面不仅仅是文件的外包装,更是投标人形象和专业素质的直观展示,对于给招标方留下良好的第一印象至关重要。 首先,响应文件封面通常会包含…

Linux---基本操作命令之用户管理命令

1.1useradd 添加新用户 root用户:/root 普通用户:/home/ 创建的用户还是david,只是在dave文件夹下 1.2 passwd 设置密码 给用户tony设置密码: 123456 1.3 id 查看用户是否存在 查看有没有这个用户:id 名字 gid:用…

Css提高——Css3的新增选择器

目录 1、Css3新增选择器列举 2、属性选择器 2.1、语法 2.2、代码: 2.3、效果图 3、结构伪类选择器 3.1、语法 3.2、代码 3.3、效果图 3.4、nth:child(n)的用法拓展 nth-child(n)与nth-of-type&#x…

mac os 配置两个github账号

1. 清空git全局配置的username和email git config --global --unset user.name git config --global --unset user.emailgit config --list 可以查看是否清空了 2. 定义两个标识符,这两个标识符以后会被用来代替“github.com”来使用。 假设两个账号的邮箱地址分别是a@gmai…

docker init 生成Dockerfile和docker-compose.yml —— 筑梦之路

官网:https://docs.docker.com/engine/reference/commandline/init/ 简介 docker init是一个命令行实用程序,可帮助初始化项目中的 Docker 资源。.dockerignore它根据项目的要求创建 Dockerfile、Compose 文件。这简化了为项目配置 Docker 的过程&#…

关于前端的学习

目录 前言: 1.初识HTML: 1.1超文本: 1.2标记语言: 2.关于html的基本框架: 3.HTML基本文字标签: 3.1.h标题标签: 3.3 文本内容: 3.4换行的和分割的: 3.5 特殊文字标签: 3.5.1表面上看着三对的结果呈现都是一样的: 3.5.2但是其背后的效果其实是不一样的: 3.6转义字符:…

CentOS 7 编译安装 Nginx

CentOS 7 编译安装 Nginx 背景下载 Nginx 源码包安装依赖包编译添加环境变量添加守护查考文献 背景 一开始使用 docker 搭建了一个 web 服务器,但是由于 docker 不太方便的部署 TLS 证书,故使用 Nginx 做反向代理,实现 https 连接。 下载 N…

MySQL—数据库导入篇

什么是数据库? 数据库是干啥的? 数据库(Database)是按照数据结构来组织、存储和管理数据的仓库。 MySQL属于哪一类数据库? MySQL是一种关系型数据库。所谓的关系型数据库,是建立在关系模型基础上的数据库&a…

File文件对象

在计算机系统中,文件是非常重要的存储方式。Files(java.nio.file.Files)提供了多种方法来处理文件系统中的文件。比直接使用File文件要方便。 Files工具类:读取指定文件中的所有文本 package study1;import java.io.IOException; import ja…

LLM预备知识、工具篇——LLM+LangChain+web UI的架构解析

目录 【常见名词】一、LLM的低资源模型微调二、向量数据库1、Milvus(v2.1.4):云原生自托管向量数据库(Ubuntu下)1)安装(Docker Compose方式):2)管理工具(仅支持Milvus 2.…

JVM的双亲委派模型和垃圾回收机制

jvm的作用是解释执行java字节码.java的跨平台就是靠jvm实现的.下面看看一个java程序的执行流程. 1. jvm中的内存区域划分 jvm也是一个进程,进程在运行过程中,要行操作系统申请一些资源.这些内存空间就支撑了后续java程序的执行. jvm从系统申请了一大块内存,这块内存在java程序使…

在Visual Studio中调试 .NET源代码

前言 在我们日常开发过程中常常会使用到很多其他封装好的第三方类库(NuGet依赖项)或者是.NET框架中自带的库。如果可以设置断点并在NuGet依赖项或框架本身上使用调试器的所有功能,那么我们的源码调试体验和生产效率会得到大大的提升。今天我…

医疗器械经营许可证办理流程及申请流程有哪些?

1、证书内容差异: 1.医疗器械经营许可证应当载明许可证号码、法定代表人、负责人、住所、经营范围、仓库地址、发证部门、日期及有效期、公司名称等事项。 2.医疗器械生产经营管理注册证书应当载明编号、公司产品名称、法定代表人、住所、经营活动场所、业务发展方…

从WAF到WAAP的研究

对于需要保护Web应用程序和API的企业来说,从WAF到WAAP的转变已成为一种必然趋势。采用WAAP平台可以更为全面和高效地保护Web应用程序和API的安全,同时避免了高昂的维护成本和攻击绕过WAF的风险。 网络安全领域的发展趋势是从WAF到WAAP的转变。WAF作为传…

2024年发布jar到国外maven中央仓库最新教程

2024年发布jar到国外maven中央仓库最新教程 文章目录 1.国外sonatype仓库的版本1.1老OSSHR账号注册说明1.2新账号注册说明 2.新账号注册(必选)3.新账号登录创建Namespace3.1创建Namespace的名字的格式要求(必选)3.2发布一个静态网站(可选&…

java数据结构与算法刷题-----LeetCode1005. K 次取反后最大化的数组和(这就不是简单题)

java数据结构与算法刷题目录(剑指Offer、LeetCode、ACM)-----主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/123063846 卷来卷去,把简单题都卷成中等题了 文章目录 1. 排序后从小到大…