mindspore mindcv图像分类算法;模型保存与加载

参考:
https://www.mindspore.cn/tutorials/en/r1.3/save_load_model.html
https://github.com/mindspore-lab/mindcv/blob/main/docs/zh/tutorials/finetune.md

1、mindspore mindcv图像分类算法

import os
from mindcv.utils.download import DownLoad
import os
import mindspore as ms


os.environ['DEVICE_ID']='0'
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0)  ##指定cpu
#ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=0)  ##需要使用才能npu加速



dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
root_dir = "./"

if not os.path.exists(os.path.join(root_dir, 'data/Canidae')):
    DownLoad().download_and_extract_archive(dataset_url, root_dir)
    
    
    
##加载数据

from mindcv.data import create_dataset, create_transforms, create_loader

num_workers = 8

# 数据集目录路径
data_dir = "./data/Canidae/"

# 加载自定义数据集
dataset_train = create_dataset(root=data_dir, split='train', num_parallel_workers=num_workers)
dataset_val = create_dataset(root=data_dir, split='val', num_parallel_workers=num_workers)



# 定义和获取数据处理及增强操作
trans_train = create_transforms(dataset_name='ImageNet', is_training=True)
trans_val = create_transforms(dataset_name='ImageNet',is_training=False)

loader_train = create_loader(
    dataset=dataset_train,
    batch_size=16,
    is_training=True,
    num_classes=2,
    transform=trans_train,
    num_parallel_workers=num_workers,
)
loader_val = create_loader(
    dataset=dataset_val,
    batch_size=5,
    is_training=True,
    num_classes=2,
    transform=trans_val,
    num_parallel_workers=num_workers,
)


#模型微调

from mindcv.models import create_model

network = create_model(model_name='densenet121', num_classes=2, pretrained=True)


#训练
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindspore import Model, LossMonitor, TimeMonitor

# 定义优化器和损失函数
opt = create_optimizer(network.trainable_params(), opt='adam', lr=1e-4)
loss = create_loss(name='CE')

# 实例化模型
model = Model(network, loss_fn=loss, optimizer=opt, metrics={'accuracy'})
model.train(10, loader_train, callbacks=[LossMonitor(5), TimeMonitor(5)], dataset_sink_mode=False)

res = model.eval(loader_val)
print(res)

import matplotlib.pyplot as plt
import mindspore as ms
import numpy as np

def visualize_model(model, val_dl, num_classes=2):
    # 加载验证集的数据进行验证
    images, labels= next(val_dl.create_tuple_iterator())
    # 预测图像类别
    output = model.predict(images)
    pred = np.argmax(output.asnumpy(), axis=1)
    # 显示图像及图像的预测值
    images = images.asnumpy()
    labels = labels.asnumpy()
    class_name = {0: "dogs", 1: "wolves"}
    plt.figure(figsize=(15, 7))
    for i in range(len(labels)):
        plt.subplot(3, 6, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()
    
visualize_model(model, loader_val)

在这里插入图片描述

在这里插入图片描述

2、模型保存与加载

## 保存模型
import mindspore as ms

from mindcv.models import create_model

network = create_model(model_name='densenet121', num_classes=2, pretrained=True)

ms.save_checkpoint(network, "model1.ckpt")
## 加载模型
from mindspore import load_checkpoint, load_param_into_net
from mindspore import Model

param_dict = load_checkpoint("model1.ckpt")
param_not_load = load_param_into_net(network, param_dict)
print(param_not_load)

model1 = Model(network, loss, metrics={"accuracy"})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/120303.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

移远EC600U-CN开发板 day03

控件探索-按钮(lv.btn) (1) 创建并显示一个按钮 * 核心代码 btn lv.btn(scr) #将按钮与src对象关联 btn.align(lv.ALIGN.CENTER,0,0) #居中显示(第1个0表示x的偏移量,第2个0表示相对于y的偏移量) label lv.l…

小白学爬虫:手机app分享商品短连接获取淘宝商品链接接口|淘宝淘口令接口|淘宝真实商品链接接口|淘宝商品详情接口

通过手机APP分享的商品短链接,我们可以调用相应的接口来获取淘口令真实URL,进而获取到PC端的商品链接及商品ID。具体步骤如下: 1、通过手机APP分享至PC端的短链接,调用“item_password”接口。 2、该接口将返回淘口令真实URL。 3…

从使用的角度看 ByConity 和 ClickHouse 的差异

自 ClickHouse Inc 宣布其重要新功能仅在 ClickHouse Cloud 上开放以来,一些关注 ByConity 开源的社区小伙伴也来询问 ByConity 后续开源规划。为回答社区疑问,我们将之前分享的关于 ByConity 与 ClickHouse 相关功能对比的 webinar 整理为文章&#xff…

Git 基础知识回顾及 SVN 转 Git 自测

背景 项目开发过程中使用的版本控制工具是 SVN,Git 多有耳闻,以前也偶尔玩过几次,但是工作中不用,虽然本地也有环境,总是不熟练。 最近看一本网络开源技术书时,下载源码部署了一下,又温故了一…

助力工业数字化!TDengine 与恩菲 MIM+ 工业互联网平台实现兼容性互认

在云计算、物联网、5G等新兴技术快速发展的当下,制造企业想要运用新兴技术实现数字化转型,工业互联网平台的应用和打造是非常关键的转型要素。在工业互联网平台的发展中,数据处理上存在的问题一直都是令企业所头疼的,越来越多的案…

Java并发编程第11讲——AQS设计思想及核心源码分析

Java并发包(JUC)中提供了很多并发工具,比如前面介绍过的ReentrantLock、ReentrantReadWriteLock、CountDownLatch、Semaphore、FutureTask等锁或者同步部件,它们的实现都用到了一个共同的基类——AbstractQueuedSynchronizer&…

汽车网络安全渗透测试概述

目录 1.汽车网络安全法规概述 1.1 国外标准 1.2 国内标准 2.汽车网络安全威胁分析 2.1 汽车网络安全资产定义 2.2 汽车网络安全影响场景及评级示例 3.汽车网络安全渗透测试描述 3.1 参考法规 3.2 渗透测试内容 4.小结 1.汽车网络安全法规概述 近年来,汽车…

【UE4】UE编辑器乱码问题

环境:UE4.27、vs2019 如何解决 问题原因,UE的编码默认是UTF-8,VS的默认编码是GBK 通过"高级保存选项" 直接修改VS的 .h头文件 的 编码 为 UTF-8 步骤1. 步骤2. 修改编码后,从新编译,然后就可以解决编辑器…

很多个pdf怎么合并在一起?

很多个pdf怎么合并在一起?作为一个办公室的伙伴,对于PDF格式肯定不会陌生。它强大的功能为我们的工作提供了许多便利。由于PDF文件格式的稳定性和安全性较高,我们通常在工作或学习中使用它来传输文件,很多人都喜欢将办公文件都做成…

Java用log4j写日志

日志可以方便追踪和调试问题,以前用log4net写日志,换Java了改用log4j写日志,用法和log4net差不多。 到apache包下载下载log4j的包,解压后把下图两个jar包引入工程 先到网站根下加一个log4j2.xml的配置文件来配置日志的格式和参…

自动驾驶行业观察之2023上海车展-----车企发展趋势(3)

合资\外资发展 宝马:i7、iX1新车亮相,未来将持续发力电动化、数字化(座舱) 宝马在本次车展重点展示了电动化产品,新发车型为i7 M70L、iX1、及i vision Dee概念车等车型。 • 展示重点:电动化数字化&#…

STM32H750之FreeRTOS学习--------(五)临界段代码保护

FreeRTOS 文章目录 FreeRTOS五、临界段代码保护临界段代码保护函数任务级进入临界段任务级退出临界段中断级进入临界段中断级退出临界段 任务调度器的挂起和恢复挂起任务调度器恢复任务调度器挂起任务调度器恢复任务调度器 五、临界段代码保护 临界段代码也叫做临界区&#xf…

Android EditText 实现强制性弹出只能输入英文的键盘

如果 EditText 控件不做任何特殊处理,例如笔者手机默认弹出的是百度输入法的软键盘,可实现中英文切换,并且自带英文单词智能联想功能(与系统安装输入法和设置相关)。但在某些应用场景下,例如在英语APP里练习…

RabbitMQ(高级特性):限流

消费端限流 在rabbitmq中,使用消费端限流必须开启手动签收信息 过MQ可以对请求进行“削峰填谷”,即通过消费端限流的方式限制消息的拉取速度,达到保护消费端的目的。 生产者批量发送消息: Test public void testSendBatch() {…

DBever 连接trino时区问题 The datetime zone id ‘GMT+08:00‘ is not recognised

DBever连接trino 测试连接成功,但是执行sql报时区不对、如果你默认使用的是大于jdk8的版本 会存在这个问题,因为jdk版本 jdk8 和jdk17 版本默认时区是不同的 trino官网明确说明了时区默认跟jdk走 解决方案 可以先行查看JDK本地时区库版本,执…

Django初窥门径-自定义用户模型

前言 自定义用户模型在Django应用中是一个重要的话题,它涉及到如何根据您的项目需求以及特定的用户身份验证和授权需求来调整用户模型。在以下前言中,我将讲述为什么自定义用户模型是如此重要以及其潜在的优势: 随着Web应用的不断发展&…

工业路由器网关的网络协议之NAT技术

在物联网通讯领域,NAT技术能将内网的一个私有IP转换成一个公网IP去接入互联网,解决组建局域网络时私有IP地址无法在公网上进行路由的问题。 NAT(Network Address Translation)的三种方式: 静态NAT 1、一个私有IP对应…

在微信小程序中怎么实现刮刮卡抽奖活动

在当今的数字化时代,微信小程序已经成为一种广泛使用的营销工具。通过各种互动活动,企业可以吸引用户的关注,提升品牌影响力。其中,刮刮卡抽奖活动是一种特别受欢迎的形式。本文将为你详细介绍如何在微信小程序中实现刮刮卡抽奖活…

Python中通过socketserver库创建服务端

socketserver库是Python的标准库,提供了套接字服务端的框架,通过该框架可以简化服务端的创建流程。 1 socketserver库的导入 通过如图1显示的代码导入socketserver库。 图1 导入socketserver库 2 通过socketserver库创建TCP服务端 通过socketserver库…

SpringCloud——服务容错——Hystrix

1.现在的微服务存在哪些问题? 在大型的微服务项目中,肯定少不了服务之间多条链路调用,如果调用中有一个服务出现了问题,如果不做任何的处理,就会造成大量的阻塞,可能会导致整个服务雪崩。 2.要解决的问题 …