RK3568笔记十八:MobileNetv2部署测试

若该文为原创文章,转载请注明原文出处。

记录MobileNetv2训练测试

一、环境

1、平台:rk3568

2、开发板: ATK-RK3568正点原子板子

3、环境:buildroot

4、虚拟机:正点原子提供的ubuntu 20

二、MobileNetv2简介

       MobileNet ,它是谷歌研究人员于 2017 年开发的一种 CNN 架构,用于将计算机视觉有效地融入 手机和机器人等小型便携式设备中,而不会显著降低准确性。后续进一步为了解决实际应用中的
一些问题,推出了 v2,v3 版本。
MobileNet 提出了一种深度可分离卷积(Depthwise Separable Convolutions),该卷积不同于标准卷
积,可以大幅度减小模型规模的同时保证模型性能下降很小。
深度可分离卷积分为两个操作:深度卷积 (DW) 和逐点卷积 (PW)。
• 深度卷积 (DW) 和标准卷积的不同之处在于,对于标准卷积,其卷积核是应用于所有的输
入通道,而 DW 卷积针对每个输入通道采用不同的卷积核,也就是说,一个卷积核对应一
个输入通道。
• 逐点卷积 (PW) 实际上就是普通的卷积,只不过其采用 1x1 的卷积核。
MobileNet 设计了两个控制网络大小全局超参数(宽度乘系数和分辨率乘系数),通过这两个超参
数来进行速度和准确率的权衡,使用者可以根据设备的限制调整网络。 

具体参考该论文

三、环境搭建

1、创建环境

 conda create -n MobileNetv2_env python=3.8

2、激活环境

 conda activate MobileNetv2_env

3、安装pytorch

pip install torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

3、下载数据

https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

下载需要APN,需要数据评论留言

数据集下载后解压到同组目录

4、train

直接上代码

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader,Dataset

from model import MobileNetV2

# 自定义数据集FlowerData
# 读取的数据目录结构:
"""
            directory/
            ├── class_x
            │   ├── xxx.jpg
            │   ├── yyy.jpg
            │   └── ...   
            └── class_y
                ├── 123.jpg
                ├── 456.jpg
                └── ...
"""
class FlowerData(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        classes = sorted(entry.name for entry in os.scandir(self.root_dir) if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        self.classes = classes
        self.class_to_idx = class_to_idx

        self.images = self.get_images(self.root_dir, self.class_to_idx)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self,index):
        path, target = self.images[index]
        with open(path, "rb") as f:
            img = Image.open(f)
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)   #对样本进行变换

        return image,target

    def get_images(self, directory, class_to_idx):
        images = []
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    item = path, class_index
                    images.append(item)

        return images

# 训练和评估
def fit(epochs, model, loss_function, optimizer, train_loader, validate_loader, device):
    t0 = time.time()
    best_acc = 0.0
    save_path = './MobileNetV2.pth'
    train_steps = len(train_loader)
    model.to(device)
    for epoch in range(epochs):
        # 训练
        model.train()
        running_loss = 0.0
        train_acc = 0.0
        train_bar = tqdm(train_loader, total=train_steps) # 进度条
        for step, (images, labels) in enumerate(train_bar):
            optimizer.zero_grad() # grad zero 
            logits = model(images.to(device)) # Forward
            loss = loss_function(logits, labels.to(device)) # loss
            loss.backward() # Backward
            optimizer.step() # optimizer.step

            _, predict = torch.max(logits, 1)
            train_acc += torch.sum(predict == labels.to(device))
            
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)

        train_accurate = train_acc / len(train_loader.dataset)

        # 验证
        model.eval()
        val_acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, total=len(validate_loader)) # 进度条
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))

                _, val_predict = torch.max(outputs, 1)
                val_acc += torch.sum(val_predict == val_labels.to(device))

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
        val_accurate = val_acc / len(validate_loader.dataset)

        print('[epoch %d] train_loss: %.3f - train_accuracy: %.3f - val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, train_accurate, val_accurate))

        # 保存最好的模型
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_path)

    print("\n{} epochs completed in {:.0f}m {:.0f}s.".format(epochs,(time.time() - t0) // 60, (time.time() - t0) % 60))


def main():
    # 有GPU,就使用GPU训练
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # 超参数
    batch_size = 32
    epochs = 10
    learning_rate = 0.0001

    data_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # 初始化自定义FlowerData类,设置数据集所在路径以及变换
    flower_data = FlowerData('./flower_photos',transform=data_transform)
    print("Dataset class: {}".format(flower_data.class_to_idx))

    # 数据集随机划分训练集(80%)和验证集(20%)
    train_size = int(len(flower_data) * 0.8)
    validate_size = len(flower_data) - train_size
    train_dataset, validate_dataset = torch.utils.data.random_split(flower_data, [train_size, validate_size])
    print("using {} images for training, {} images for validation.".format(len(train_dataset),len(validate_dataset)))

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process \n'.format(nw))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)
    validate_loader = DataLoader(validate_dataset, batch_size=1, shuffle=True, num_workers=nw)

    # 实例化模型,设置类别个数num_classes
    net = MobileNetV2(num_classes=5).to(device)

    # 使用预训练权重 https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    model_weight_path = "./mobilenet_v2-b0353104.pth"
    assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)

    pre_weights = torch.load(model_weight_path, map_location=device)
    # print("The type is:".format(type(pre_weights)))

    pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

    # 通过requires_grad == False的方式来冻结特征提取层权重,仅训练后面的池化和classifier层
    for param in net.features.parameters():
        param.requires_grad = False

    # 使用交叉熵损失函数
    loss_function = nn.CrossEntropyLoss()

    # 使用adam优化器, 仅仅对最后池化和classifier层进行优化
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=learning_rate)

    # 输出网络结构
    #print(summary(net, (3, 224, 224)))

    # 训练和验证模型
    fit(epochs, net, loss_function, optimizer, train_loader, validate_loader, device)

if __name__ == '__main__':
    main()

开始训练,执行命令

python train.py

电脑是CPU版本,大概等待1小时,训练完成。会在当前目录下生成MobileNetV2.pth模型

四、pt模型转换

训练后保存了 MobileNetV2.pth 模型权重文件,部署需要导出 torchscript 的模型。
export.py
import torch
import os
from model import MobileNetV2


if __name__ == '__main__':

    # 模型
    model = MobileNetV2(num_classes=5)

    # 加载权重
    model.load_state_dict(torch.load("./MobileNetV2.pth"))

    model.eval()
    # 保存模型
    trace_model = torch.jit.trace(model, torch.Tensor(1, 3, 224, 224))
    trace_model.save('./MobileNetV2.pt')

执行上面程序会导出MobileNetV2.pt模型

五、部署

1、RKNN模型转换

使用 RKNN Toolkit2 工具,将导出的模型转换出 rknn 模型,并进行简单模型测试。

RKNN Toolkit2 工具环境安装,参考正点原子手册。

pt2rknn.py

import numpy as np
import cv2
from rknn.api import RKNN

class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

def show_outputs(output):
    output_sorted = sorted(output, reverse=True)
    top5_str = '\n class    prob\n'
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        topi = '{}:    {:.3}% \n'.format(class_names[(index[0][0])], value*100)
        top5_str += topi
    print(top5_str)

def show_perfs(perfs):
    perfs = 'perfs: {}\n'.format(perfs)
    print(perfs)

def softmax(x):
    return np.exp(x)/sum(np.exp(x))

if __name__ == '__main__':

    model = './MobileNetV2.pt'

    input_size_list = [[1, 3, 224, 224]]

    # Create RKNN object
    rknn = RKNN()

    # Pre-process config, 默认设置rk3588
    print('--> Config model')
    rknn.config(mean_values=[[128, 128, 128]], std_values=[[128, 128, 128]], target_platform='rk3568')
    print('done')

    # Load model
    print('--> Loading model')
    ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    # Build model
    print('--> Building model')
    # ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    # Export rknn model
    print('--> Export rknn model')
    ret = rknn.export_rknn('./MobileNetV2.rknn')
    if ret != 0:
        print('Export rknn model failed!')
        exit(ret)
    print('done')

    #Set inputs
    img = cv2.imread('./sun.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224,224))
    img = np.expand_dims(img, 0)

    # Init runtime environment
    print('--> Init runtime environment')
    ret = rknn.init_runtime()
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')

    # Inference
    print('--> Running model')
    outputs = rknn.inference(inputs=[img])
    # np.save('./MobileNetV2.npy', outputs[0])
    print(outputs[0][0])
    show_outputs(softmax(np.array(outputs[0][0])))
    print('done')

    rknn.release()

在虚拟机下执行

python pt2rknn.py

模型转换成功,并测试正常。有点要注意,平台是RK3568.

2、部署测试

MobileNetV2.rknn  test.py  tulips.jpg这三个文件通过adb上传到开发板,打开开发板终端

执行测试程序。

test.py

import cv2
import numpy as np
from rknnlite.api import RKNNLite

INPUT_SIZE = 224

RK3566_RK3568_RKNN_MODEL = 'MobileNetV2.rknn'
RK3588_RKNN_MODEL = 'MobileNetV2.rknn'

class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

def softmax(x):
    return np.exp(x)/sum(np.exp(x))

def show_outputs(output):
    output_sorted = sorted(output, reverse=True)
    top5_str = '\n Class    Prob\n'
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        topi = '{}:    {:.3}% \n'.format(class_names[(index[0][0])], value*100)
        top5_str += topi
    print(top5_str)

if __name__ == '__main__':

    rknn_lite = RKNNLite()

    # load RKNN model
    print('--> Load RKNN model')
    ret = rknn_lite.load_rknn(RK3566_RK3568_RKNN_MODEL)
    if ret != 0:
        print('Load RKNN model failed')
        exit(ret)
    print('done')

    ori_img = cv2.imread('./tulips.jpg')
    img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224,224))
    
    # init runtime environment
    print('--> Init runtime environment')
    # run on RK356x/RK3588 with Debian OS, do not need specify target.
    #ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    ret = rknn_lite.init_runtime()
    if ret != 0:
        print('Init runtime environment failed')
        exit(ret)
    print('done')

    # Inference
    print('--> Running model')
    outputs = rknn_lite.inference(inputs=[img])
    print(outputs[0][0])
    show_outputs(softmax(np.array(outputs[0][0])))
    print('done')

    rknn_lite.release()

测试结果正常,部署成功

六、参考链接

https://pytorch.org
https://arxiv.org/abs/1801.04381
https://arxiv.org/pdf/1704.04861
https://github.com/rockchip-linux/rknn-toolkit2

如有侵权,或需要完整代码,请及时联系博主。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/424681.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前端面试练习24.3.2-3.3

HTMLCSS部分 一.说一说HTML的语义化 在我看来,它的语义化其实是为了便于机器来看的,当然,程序员在使用语义化标签时也可以使得代码更加易读,对于用户来说,这样有利于构建良好的网页结构,可以在优化用户体…

Python【初识】

一、Python简介 Python是一种高级的解释型编程语言,以其简洁、易学和强大的库支持而闻名。它最初由荷兰国家数学与计算机科学研究中心的吉多范罗苏姆于1990年代初设计,作为一门叫做ABC语言的替代品。Python的设计理念强调优雅、明确和简单,旨…

Google 地图 API 教程--干货(1/2)

Google Maps API 教程 在本教程中我们将学习如何使用谷歌地图API V3创建交互式地图。 什么是 API? API = 应用程序编程接口(Application programming interface)。 API(Application Programming Interface,应用编程接口)其实就是操作系统留给应用程序的一个调用接口,…

vb.net获取Windows主题颜色、深色模式窗体,实时响应

先上效果图 可直接跳到完整代码 目录 先上效果图 开始教学 响应用户的更改 API讲解 读取深浅模式、主题颜色、十六进制颜色转换 完整代码 如果大家留意资源管理器的“文件”菜单的话就会发现它的底色就是你设置的主题色,在更改Windows颜色模式时,…

《OpenScene: 3D Scene Understanding with Open Vocabularies》阅读笔记1

传统的3D场景理解方法依赖于带标签的3D数据集,用于训练一个模型以进行单一任务的监督学习。我们提出了OpenScene,一种替代方法,其中模型在CLIP特征空间中预测与文本和图像像素共同嵌入的3D场景点的密集特征。这种零样本方法实现了与任务无关的训练和开放词汇查询。例如,为了…

开源项目热榜 - 华为OD统一考试(C卷)

OD统一考试(C卷) 分值: 100分 题解: Java / Python / C 题目描述 某个开源社区希望将最近热度比较高的开源项目出一个榜单,推荐给社区里面的开发者。 对于每个开源项目,开发者可以进行关注(watch)、收藏(…

如何在 Mac 上成功轻松地恢复 Excel 文件

Microsoft Excel 的 Mac 版本始终略落后于 Windows 版本,这也许可以解释为什么如此多的用户渴望学习如何在 Mac 上恢复 Excel 文件。 但导致重要电子表格不可用的不仅仅是 Mac 版 Excel 的不完全稳定性。用户有时会失去注意力并删除错误的文件,存储设备…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的体育赛事目标检测系统(Python+PySide6界面+训练代码)

摘要:开发和研究体育赛事目标检测系统对于增强体育分析和观赏体验至关重要。本篇博客详细讲述了如何运用深度学习技术构建一个体育赛事目标检测系统,并提供了完整的实现代码。系统基于先进的YOLOv8算法,对比了YOLOv7、YOLOv6、YOLOv5的性能&a…

nginx笔记整理

目录 一.Nginx基础介绍 二.nginx安装配置 三.Nginx配置文件 3.1nginx主配置文件(/etc/nginx/nginx.conf) 3.2默认的网站配置文件(/etc/nginx/conf.d/default.conf) 四.创建新的虚拟主机 五.Nginx日志 5.1nginx日志格式 5.2查看日志 5.3日志缓存(了解) 5.4日志轮转(/…

总结 HashTable, HashMap, ConcurrentHashMap 之间的区别

1.多线程环境使用哈希表 HashMap 不行,线程不安全 更靠谱的,Hashtable,在关键方法上加了synchronized 后来标准库又引入了一个更好的解决方案;ConcurrentHashMap 2.HashMap 首先HashMap本身线程不安全其次HashMap的key值可以为空(当key为空时,哈希会…

Jenkins的Pipeline概念

文章目录 Pipeline什么是Jenkins Pipeline声明式和脚本式Pipeline语法为何使用PipelinePipeline概念PipelineNodeStageStep Pipeline语法概述声明式Pipeline脚本式Pipeline Pipeline示例 参考 Pipeline 什么是Jenkins Pipeline Jenkins Pipeline是一套插件,它支持…

【精华】麻省理工学院MIT技术双月刊(Bimonthly MIT Technology Review)2024年3/4月刊荐书 Book reviews

本期内容概览见博客:2024年3/4月刊内容概览 Book Reviews 1. Read Write Own: Building the Next Era of the Internet By Chris Dixon (Random House, 2024) With the demise of Twitter, many have advocated for a decentralized alternative for social medi…

浅析this指针

浅析this指针 文章目录 浅析this指针前言this指针作用this指针使用注意事项总结 前言 ​ 在面向对象的编程语言中,this指针是一个自引用指针,通常用于指向对象自身。通过这篇文章,我们将探讨this指针的核心特性、应用场景和相关案例。 this指…

高维中介数据:基于交替方向乘子法(ADMM)的高维度单模态中介模型的参数估计(入门+实操)

全文摘要 用于高维度单模态中介模型的参数估计,采用交替方向乘子法(ADMM)进行计算。该包提供了确切独立筛选(SIS)功能来提高中介效应的敏感性和特异性,并支持Lasso、弹性网络、路径Lasso和网络约束惩罚等不…

flynn发布服务小结

背景 flynn是一个基于容器的paas平台,可以快速的发布运行新的应用,用户只需要提交代码到git上,flynn就会基于提交的代码进行发布和部署,本文就简单看下flynn发布部署的流程 flynn发布服务 1.首先flynn会基于用户的web代码构建一…

远程服务器Ubuntu 18.04安装VNC远程桌面

一、安装vnc 1.安装图形化界面工具 # 安装过程中会弹窗让选择配置,选lightdm sudo apt install ubuntu-desktop sudo apt-get install gnome-panel gnome-settings-daemon metacity nautilus gnome-terminal 2.安装vnc sudo apt-get install x11vnc3.安装LightD…

(面试题)数据结构:链表相交

问题:有两个链表,如何判断是否相交,若相交,找出相交的起始节点 一、介绍 链表相交: 若两个链表相交,则两个链表有共同的节点,那从这个节点之后,后面的节点都会重叠,知道…

推荐五本程序员必看书籍!

昨天推送的是视频,今天给大家推荐基本入门渗透测试的好书,以结合昨天文章一起学习,忘记了的可以回复“学习之路”会自动跳出文章的,好的话不多说,直接上主菜了! 第一本当然是我们网络基础的书,…

SpringMVC了解

1.springMVC概述 Spring MVC(Model-View-Controller)是基于 Java 的 Web 应用程序框架,用于开发 Web 应用程序。它通过将应用程序分为模型(Model)、视图(View)和控制器(Controller&a…

快递平台独立版小程序源码|带cps推广营销流量主+前端

源码介绍: 快递代发快递代寄寄件小程序可以对接易达云洋一级总代 快递小程序,接入云洋/易达物流接口,支持选择快递公司,三通一达,极兔,德邦等,功能成熟 如何收益: 1.对接第三方平台成本大约4元…