MindSpore基础教程:使用 MindCV和 Gradio 创建一个图像分类应用

MindSpore基础教程:使用 MindCV和 Gradio 创建一个图像分类应用

官方文档教程使用已经弃用的MindVision模块,本文是对官方文档的更新
在这篇博客中,我们将探索如何使用 MindSpore 框架和 Gradio 库来创建一个基于深度学习的图像分类应用。我们将使用预训练的 ResNet50 模型,以 CIFAR-10 数据集为例进行训练,并通过 Gradio 接口进行图像分类预测。下面是一个简单、直观的指南,适用于希望将深度学习模型转换为交互式应用的开发者。

训练模型

环境设置

首先,我们需要设置 GPU 作为训练的目标设备。MindSpore 提供了一个便捷的方式来配置环境。

from mindspore import context
context.set_context(device_target="GPU")

解析参数

我们使用 argparse 来解析命令行参数。这样可以方便地在训练时调整参数,例如数据集路径、学习率和训练周期数。

import argparse
def parse_args():
    """
    解析命令行参数。

    返回:
        argparse.Namespace: 包含命令行参数的命名空间。
    """
    parser = argparse.ArgumentParser(description="训练 ResNet 模型",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrain_path', type=str, default='',
                        help='预训练文件的路径')
    parser.add_argument('--data_path', type=str, default='datasets/drizzlezyk/cifar10/',
                        help='训练数据的路径')
    parser.add_argument('--output_path', default='train/resnet/', type=str,
                        help='模型保存路径')
    parser.add_argument('--epochs', default=10, type=int, help='训练周期数')
    parser.add_argument('--lr', default=0.0001, type=int, help='学习率')
    return parser.parse_args()

创建数据集

使用 MindSpore 的 create_dataset 方法,我们可以轻松创建和预处理 CIFAR-10 训练数据集。

from mindcv.data import create_dataset, create_transforms, create_loader


def create_training_dataset(data_path, batch_size):
    """
    创建训练数据集。

    参数:
        data_path (str): 数据集的路径。
        batch_size (int): 批量大小。

    返回:
        Tuple[DataLoader, int]: 数据加载器和每个 epoch 的批次数量。
    """
    dataset_train = create_dataset(name='cifar10', root=data_path, split='train', shuffle=True)
    transform_train = create_transforms(dataset_name='cifar10', image_resize=224)
    train_loader = create_loader(dataset=dataset_train, batch_size=batch_size, is_training=True,
                                 num_classes=10, transform=transform_train)
    num_batches = train_loader.get_dataset_size()
    return train_loader, num_batches

模型训练

接下来,我们定义 train_model 函数来实现模型的训练逻辑。这包括模型的初始化、损失函数、优化器的设置,以及训练过程的启动。

from mindcv import create_model, create_loss, create_scheduler, create_optimizer
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_net

def train_model(args):
    """
    训练模型。

    参数:
        args (argparse.Namespace): 包含命令行参数的命名空间。
    """
    train_loader, num_batches = create_training_dataset(args.data_path, batch_size=32)

    net = create_model(model_name='resnet50', num_classes=10)

    if args.pretrain_path:
        param_dict = load_checkpoint(args.pretrain_path)
        load_param_into_net(net, param_dict)

    loss_fn = create_loss(name='CE', reduction='mean')

    lr_scheduler = create_scheduler(steps_per_epoch=num_batches, scheduler='constant', lr=args.lr)

    optimizer = create_optimizer(net.trainable_params(), opt='adam', lr=lr_scheduler)

    model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})

    checkpoint_config = CheckpointConfig(save_checkpoint_steps=num_batches, keep_checkpoint_max=10)
    checkpoint_callback = ModelCheckpoint(prefix='checkpoint_resnet', directory=args.output_path,
                                          config=checkpoint_config)

    model.train(args.epochs, train_loader,
                callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor(data_size=num_batches)])

构建 Gradio 接口

预测函数

在 Gradio 接口中,我们定义一个 predict_image 函数来处理图像输入并返回预测结果。

import gradio as gr
import numpy as np
from mindspore import Tensor
import cv2

def predict_image(img):
    # 创建模型实例
    net = create_model(model_name='resnet50', num_classes=NUM_CLASS)
    param_dict = load_checkpoint('/root/MyCode/pycharm/ResNet50/train/resnet/checkpoint_resnet-5_1563.ckpt')
    load_param_into_net(net, param_dict)

    # 封装模型为 Model 类实例
    model = Model(net)
    # 调整图像格式和大小
    img = cv2.resize(img, (224, 224))
    img = np.array(img, dtype=np.float32) / 255.0  # 归一化并确保数据类型为 Float32

    # 如果图像是 BGR 格式,转换为 RGB 格式
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 标准化处理
    img = (img - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # 转换维度 - 通道优先格式 (C, H, W)
    img = np.transpose(img, (2, 0, 1))

    # 添加批次维度 (N, C, H, W)
    img = np.expand_dims(img, axis=0)

    # 将图像数据转换为 MindSpore 张量
    img_tensor = Tensor(img, dtype=mindspore.float32)  # 显式指定数据类型

    # 预测图像
    output = model.predict(img_tensor)

    # 应用 Softmax 获取概率
    softmax = Softmax(axis=1)
    predict_probability = softmax(output).asnumpy()
    predict_probability = predict_probability[0]  # 获取批量中的第一个元素

    # 将预测概率映射到类别名称
    return {class_names[i]: float(predict_probability[i]) for i in range(NUM_CLASS)}

Gradio 界面

使用 Gradio,我们可以快速构建一个交互式界面。用户可以上传图片,模型将返回图像分类的预测结果。

image = gr.Image()
label = gr.Label(num_top_classes=NUM_CLASS)

gr.Interface(css=".footer {display:none !important}",
             fn=predict_image,
             inputs=image,
             live=False,
             description="Please upload a image in JPG, JPEG or PNG.",
             title='Image Classification by ResNet50',
             outputs=gr.Label(num_top_classes=NUM_CLASS, label="预测类别"),
             examples=['./example_img/airplane.jpg', './example_img/automobile.jpg', './example_img/bird.jpg',
                       './example_img/cat.jpg', './example_img/deer.jpg', './example_img/dog.jpg',
                       './example_img/frog.jpg', './example_img/horse.JPG', './example_img/ship.jpg',
                       './example_img/truck.jpg']
             ).launch(share=True)

image-20231121192446268

完整代码

import argparse

from mindcv import create_model, create_loss, create_scheduler, create_optimizer
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_net
from mindcv.data import create_dataset, create_transforms, create_loader
from mindspore import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint

# 设置GPU
from mindspore import context

context.set_context(device_target="GPU")


def parse_args():
    """
    解析命令行参数。

    返回:
        argparse.Namespace: 包含命令行参数的命名空间。
    """
    parser = argparse.ArgumentParser(description="训练 ResNet 模型",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrain_path', type=str, default='',
                        help='预训练文件的路径')
    parser.add_argument('--data_path', type=str, default='datasets/drizzlezyk/cifar10/',
                        help='训练数据的路径')
    parser.add_argument('--output_path', default='train/resnet/', type=str,
                        help='模型保存路径')
    parser.add_argument('--epochs', default=10, type=int, help='训练周期数')
    parser.add_argument('--lr', default=0.0001, type=int, help='学习率')
    return parser.parse_args()


def create_training_dataset(data_path, batch_size):
    """
    创建训练数据集。

    参数:
        data_path (str): 数据集的路径。
        batch_size (int): 批量大小。

    返回:
        Tuple[DataLoader, int]: 数据加载器和每个 epoch 的批次数量。
    """
    dataset_train = create_dataset(name='cifar10', root=data_path, split='train', shuffle=True)
    transform_train = create_transforms(dataset_name='cifar10', image_resize=224)
    train_loader = create_loader(dataset=dataset_train, batch_size=batch_size, is_training=True,
                                 num_classes=10, transform=transform_train)
    num_batches = train_loader.get_dataset_size()
    return train_loader, num_batches


def train_model(args):
    """
    训练模型。

    参数:
        args (argparse.Namespace): 包含命令行参数的命名空间。
    """
    train_loader, num_batches = create_training_dataset(args.data_path, batch_size=32)

    net = create_model(model_name='resnet50', num_classes=10)

    if args.pretrain_path:
        param_dict = load_checkpoint(args.pretrain_path)
        load_param_into_net(net, param_dict)

    loss_fn = create_loss(name='CE', reduction='mean')

    lr_scheduler = create_scheduler(steps_per_epoch=num_batches, scheduler='constant', lr=args.lr)

    optimizer = create_optimizer(net.trainable_params(), opt='adam', lr=lr_scheduler)

    model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})

    checkpoint_config = CheckpointConfig(save_checkpoint_steps=num_batches, keep_checkpoint_max=10)
    checkpoint_callback = ModelCheckpoint(prefix='checkpoint_resnet', directory=args.output_path,
                                          config=checkpoint_config)

    model.train(args.epochs, train_loader,
                callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor(data_size=num_batches)])


if __name__ == '__main__':
    train_model(parse_args())
import gradio as gr
import numpy as np
from mindspore import Tensor
from mindspore.nn import Softmax
import cv2
from typing import Type, Union, List, Optional
from mindspore import nn
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindcv.models import create_model
import mindspore

print(mindspore.__version__)

NUM_CLASS = 10
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def predict_image(img):
    # 创建模型实例
    net = create_model(model_name='resnet50', num_classes=NUM_CLASS)
    param_dict = load_checkpoint('/root/MyCode/pycharm/ResNet50/train/resnet/checkpoint_resnet-5_1563.ckpt')
    load_param_into_net(net, param_dict)

    # 封装模型为 Model 类实例
    model = Model(net)
    # 调整图像格式和大小
    img = cv2.resize(img, (224, 224))
    img = np.array(img, dtype=np.float32) / 255.0  # 归一化并确保数据类型为 Float32

    # 如果图像是 BGR 格式,转换为 RGB 格式
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 标准化处理
    img = (img - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # 转换维度 - 通道优先格式 (C, H, W)
    img = np.transpose(img, (2, 0, 1))

    # 添加批次维度 (N, C, H, W)
    img = np.expand_dims(img, axis=0)

    # 将图像数据转换为 MindSpore 张量
    img_tensor = Tensor(img, dtype=mindspore.float32)  # 显式指定数据类型

    # 预测图像
    output = model.predict(img_tensor)

    # 应用 Softmax 获取概率
    softmax = Softmax(axis=1)
    predict_probability = softmax(output).asnumpy()
    predict_probability = predict_probability[0]  # 获取批量中的第一个元素

    # 将预测概率映射到类别名称
    return {class_names[i]: float(predict_probability[i]) for i in range(NUM_CLASS)}


image = gr.Image()
label = gr.Label(num_top_classes=NUM_CLASS)

gr.Interface(css=".footer {display:none !important}",
             fn=predict_image,
             inputs=image,
             live=False,
             description="Please upload a image in JPG, JPEG or PNG.",
             title='Image Classification by ResNet50',
             outputs=gr.Label(num_top_classes=NUM_CLASS, label="预测类别"),
             examples=['./example_img/airplane.jpg', './example_img/automobile.jpg', './example_img/bird.jpg',
                       './example_img/cat.jpg', './example_img/deer.jpg', './example_img/dog.jpg',
                       './example_img/frog.jpg', './example_img/horse.JPG', './example_img/ship.jpg',
                       './example_img/truck.jpg']
             ).launch(share=True)

总结

通过 MindSpore 和 Gradio,我们可以不仅训练强大的深度学习模型,还可以将这些模型转化为交互式应用,使非专业人士也能轻松体验 AI 的魅力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/172734.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

京东优惠券查询API接口接入方案,item_search_coupon - 京东优惠券查询接口演示

要接入京东优惠券查询API接口(item_search_coupon),您可以按照以下步骤进行操作: 注册并获取API密钥:首先,您需要在京东开放平台上注册并获取API密钥。这将为您提供唯一的标识符和密钥,用于访问…

博主都在用的网站,一键制作电子杂志

​随着互联网的发展,越来越多的人开始使用电子杂志来展示自己的作品或宣传自己的品牌。而制作电子杂志的工具也越来越多,其中一些工具非常受欢迎,被许多博主使用。今天,我们就来介绍一款博主都在用的网站,它可以帮助你…

【python】Python生成GIF动图,多张图片转动态图,pillow

pip install pillow 示例代码: from PIL import Image, ImageSequence# 图片文件名列表 image_files [car.png, detected_map.png, base64_image_out.png]# 打开图片 images [Image.open(filename) for filename in image_files]# 设置输出 GIF 文件名 output_g…

svn文件不显示红色感叹号

如下图所示,受svn版本控制的文件不显示下图中红色感叹号和绿色对号时, 可以试着如下操作 空白处单击右键,具体操作如下图

CF 1894A 学习笔记 思维 题意理解分析

原题 A. Secret Sport time limit per test 3 seconds memory limit per test 512 megabytes input standard input output standard output Lets consider a game in which two players, A and B, participate. This game is characterized by two positive integer…

Spring源码-5.aop代理

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码🔥如果感觉博主的文章还不错的话,请👍三连支持&…

4.2、Linux进程(1)

个人主页:Lei宝啊 愿所有美好如期而遇 目录 基本概念 描述进程-PCB task_struct-PCB的一种 task_struct内容分类 查看进程 通过系统调用获取进程标识符 前言 进入进程前,我建议读一读这两篇文章,他们都是进程的前导知识。 操作系统…

Joern安装与使用

环境准备 Joern需要在Linux环境中运行,所以在Windows系统中需要借助WSL或虚拟机安装。 JDK安装 Joern的运行需要JAVA环境的支持,本次采用的是JDK17,其他版本建议看一下Joern官方文档。 apt install openjdk-17-jre-headless 配置JAVA环境变…

shell脚本之条件语句

条件语句 linux测试 test 测试 测试表达式是否成立(用echo $? 检测是否正确) 语法:test [选项] [文件名] 选项作用-e测试文件是否存在-r查看文件有无读的权限-d测试是否为目录-f测试是否为文件-w测试当前用户有无写的权限-x测试是否有执…

EANet:用于医学图像分割的迭代边缘注意力网络

EANet: Iterative edge attention network for medical image segmentation EANet:用于医学图像分割的迭代边缘注意力网络背景贡献实验方法Dynamic scale-aware context module(动态规模感知上下文模块)Edge attention preservation module&a…

【Java】java | CacheManager | redisCacheManager

一、说明 1、查询增加缓存,使用Cacheable注解 2、项目中已经用到了ehcache,现在需求是两个都用 二、备份配置 1、redisConfig增加代码 Bean("redisCacheManage")Primarypublic CacheManager redisCacheManager(RedisConnectionFactory fact…

Matlab通信仿真系列——图形处理函数

微信公众号上线,搜索公众号小灰灰的FPGA,关注可获取相关源码,定期更新有关FPGA的项目以及开源项目源码,包括但不限于各类检测芯片驱动、低速接口驱动、高速接口驱动、数据信号处理、图像处理以及AXI总线等 本节目录 一、plot函数 (1)绘制一…

SystemV

一、共享内存 1、直接原理 进程间通信的本质是:先让不同的进程,看到同一份资源!! 我们要把这句话奉若圭臬一般 到了共享内存了支持双向通信能读也能写,但是一般都是一个读一个写 要想通信先看到同一个份资源&#xff0…

Lifecyle的原理

1、Lifecycle是典型的观察者模式,被观察者的继承关系如上图所示。 2、LifeCycleRegistry是Lifecycle的子类。 3、观察者通过LifeCycle对象的addObserver注册监听生命周期的变化,通过removeObserver移除监听生命周期的变化。 4、Activity或Fragment的生命…

HDFS的Shell操作

文章目录 一、HDFS的Shell介绍二、了解HDFS常用Shell命令(一)三种Shell命令方式(二)FileSystem Shell文档(三)常用HDFS的Shell命令 三、HDFS常用命令操作实战(一)创建目录&#xff0…

深度学习之基于Pytorch的昆虫分类识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介系统架构技术亮点 二、功能三、系统四. 总结 一项目简介 # 深度学习基于 Pytorch 的昆虫分类识别系统介绍 深度学习在图像分类领域取得了显著的成就&#…

windows上 adb devices有设备 wsl上没有

终于解决了!!!! TAT,尝试了很多种办法。 比如WSL中的adb和Windows中的adb版本必须一致,一致也没用,比如使用 ln 建立链接也没用。 这个解决办法的前提是windows中的abd是好用的。 ●在windows…

计算机显示msvcp140.dll丢失的解决方法,实测有效的5个方法分享

在日常的电脑操作中,常常遭遇某些错误讯息,如“缺少xxx.dll文件”,这些dll文件即为动态链接库文件,内含诸多可执行的程序码及数据。当启动某款应用时,系统将会自动调用与其相关的dll文件,其中msvcp140.dll便…

pycharm 控制台中文乱码处理

今天使用pycharm,发现控制台输出又中文乱码了,看网上很多资料说把编码改为UTF-8,设置为并未生效,特此在此记录下本地设置。 1. 修改文件编码:Setting -> Editor ->File Encodings,修改配置如下: 2. …