基于PaddlePaddle平台训练物体分类——猫狗分类

学习目标:

在百度的PaddlePaddle平台训练自己需要的模型,以训练一个猫狗分类模型为例

PaddlePaddle平台:

  • 飞桨(PaddlePaddle)是百度开发的深度学习平台,具有动静统一框架、端到端开发套件等特性,支持大规模分布式训练和高性能推理
  • 作为中国首个自主研发的产业级平台,飞桨在市场份额和应用规模上均居中国第一,服务了800万开发者和22万家企事业单位,广泛应用于金融、能源、制造、交通等领域

学习概述:

  • 基于百度的PaddlePaddle平台训练猫狗分类模型

  • 学习使用PaddlePaddle平台的使用方法,其中包括寻找数据集、运行环境配置、数据预处理、训练、计算预估准确率、使用ncc工具将模型转换为kmodel模型文件等


训练方法:

1、寻找数据集:我们可以在搜索框搜索猫和狗,选择一个合适大小的猫与狗的数据集,便于后面训练模型,数据集样本的数量直接影响训练模型的正确率、迭代次数、训练时间等(点击跳转)

在这里插入图片描述
在这里插入图片描述


2、运行环境配置:首先创建一个Notebook项目,然后填写项目名称、数据集配置,此处使用AI Studio经典版、PaddlePaddle 2.4.0框架,接下来选择运行环境,我们选择免费的两核CPU就可以,然后运行创建配置好的项目

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述


3、数据预处理:可以看到项目中有两个文件夹work和data,work目录下的变更会持久保存,但data目录下的变更重启环境后会自动还原。在右侧Notebook编译区可以新建代码脚本等

  1. 将data目录下的猫狗.zip重命名为catanddog.zip,新建Code并运行解压数据集
# 解压猫狗数据集
!cd data/data17036 && unzip -q catanddog.zip
  1. 上传预训练参数文件下载链接,新建Code并运行解预训练参数
# 解压预训练参数 
!cd data && unzip -q Pts.zip
# 解压预训练参数 pretrained
!cd data/Pts && unzip -q pretrained.zip
  1. 预处理数据,同时将数据拆分成两份以便训练和计算预估准确率,将其转化为标准格式。
# 预处理数据,将其转化为标准格式。同时将数据拆分成两份,以便训练和计算预估准确率
import codecs
import os
import random
import shutil
from PIL import Image

train_ratio = 4 / 5

all_file_dir = 'data/data17036/catanddog'
class_list = [c for c in os.listdir(all_file_dir) if os.path.isdir(os.path.join(all_file_dir, c)) and not c.endswith('Set') and not c.startswith('.')]
class_list.sort()
print(class_list)
train_image_dir = os.path.join(all_file_dir, "trainImageSet")
if not os.path.exists(train_image_dir):
    os.makedirs(train_image_dir)
    
eval_image_dir = os.path.join(all_file_dir, "evalImageSet")
if not os.path.exists(eval_image_dir):
    os.makedirs(eval_image_dir)

train_file = codecs.open(os.path.join(all_file_dir, "train.txt"), 'w')
eval_file = codecs.open(os.path.join(all_file_dir, "eval.txt"), 'w')

with codecs.open(os.path.join(all_file_dir, "label_list.txt"), "w") as label_list:
    label_id = 0
    for class_dir in class_list:
        label_list.write("{0}\t{1}\n".format(label_id, class_dir))
        image_path_pre = os.path.join(all_file_dir, class_dir)
        for file in os.listdir(image_path_pre):
            try:
                img = Image.open(os.path.join(image_path_pre, file))
                if random.uniform(0, 1) <= train_ratio:
                    shutil.copyfile(os.path.join(image_path_pre, file), os.path.join(train_image_dir, file))
                    train_file.write("{0}\t{1}\n".format(os.path.join(train_image_dir, file), label_id))
                else:
                    shutil.copyfile(os.path.join(image_path_pre, file), os.path.join(eval_image_dir, file))
                    eval_file.write("{0}\t{1}\n".format(os.path.join(eval_image_dir, file), label_id))
            except Exception as e:
                pass
                # 存在一些文件打不开,此处需要稍作清洗
        label_id += 1
            
train_file.close()
eval_file.close()

4、训练模型:训练常用视觉基础网络进行猫狗分类

# -*- coding: UTF-8 -*-
"""
训练常用视觉基础网络,用于分类任务
需要将训练图片,类别文件 label_list.txt 放置在同一个文件夹下
程序会先读取 train.txt 文件获取类别数和图片数量
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import time
import math
import paddle
import paddle.fluid as fluid
import codecs
import logging

from paddle.fluid.initializer import MSRA
from paddle.fluid.initializer import Uniform
from paddle.fluid.param_attr import ParamAttr
from PIL import Image
from PIL import ImageEnhance

......#代码较长其余可fork项目,参考本文结尾

5、计算预估准确率:测试集模型评估,测试模型的正确率

from __future__ import absolute_import    
from __future__ import division    
from __future__ import print_function    
    
import os    
import numpy as np    
import random    
import time    
import codecs    
import sys    
import functools    
import math    
import paddle    
import paddle.fluid as fluid    
from paddle.fluid import core    
from paddle.fluid.param_attr import ParamAttr    
from PIL import Image, ImageEnhance    
    
target_size = [3, 224, 224]    
mean_rgb = [127.5, 127.5, 127.5]    
data_dir = "data/data17036/catanddog"    
eval_file = "eval.txt"    
use_gpu = train_parameters["use_gpu"]    
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()    
exe = fluid.Executor(place)    
save_freeze_dir = "./freeze-model"    
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname=save_freeze_dir, executor=exe)    
# print(fetch_targets)    
    
    
def crop_image(img, target_size):    
    width, height = img.size    
    w_start = (width - target_size[2]) / 2    
    h_start = (height - target_size[1]) / 2    
    w_end = w_start + target_size[2]    
    h_end = h_start + target_size[1]    
    img = img.crop((w_start, h_start, w_end, h_end))    
    return img    
    
    
def resize_img(img, target_size):    
    ret = img.resize((target_size[1], target_size[2]), Image.BILINEAR)    
    return ret    
    
    
def read_image(img_path):    
    img = Image.open(img_path)    
    if img.mode != 'RGB':    
        img = img.convert('RGB')    
    img = crop_image(img, target_size)    
    img = np.array(img).astype('float32')    
    img -= mean_rgb    
    img = img.transpose((2, 0, 1))  # HWC to CHW    
    img *= 0.007843    
    img = img[np.newaxis,:]    
    return img    
    
    
def infer(image_path):    
    tensor_img = read_image(image_path)    
    label = exe.run(inference_program, feed={feed_target_names[0]: tensor_img}, fetch_list=fetch_targets)    
    return np.argmax(label)    
    
    
def eval_all():    
    eval_file_path = os.path.join(data_dir, eval_file)    
    total_count = 0    
    right_count = 0    
    with codecs.open(eval_file_path, encoding='utf-8') as flist:     
        lines = [line.strip() for line in flist]    
        t1 = time.time()    
        for line in lines:    
            total_count += 1    
            parts = line.strip().split()    
            result = infer(parts[0])    
            # print("infer result:{0} answer:{1}".format(result, parts[1]))    
            if str(result) == parts[1]:    
                right_count += 1    
        period = time.time() - t1    
        print("total eval count:{0} cost time:{1} predict accuracy:{2}".format(total_count, "%2.2f sec" % period, right_count / total_count))    
    
    
if __name__ == '__main__':    
    eval_all()  
#print:total eval count:17 cost time:1.00 sec predict accuracy:0.8235294117647058

6、模型转换:

  1. 下载ncc工具,准备转换模型。关于ncc工具可参考K210学习记录(3)——kmodel生成与使用
!mkdir /home/aistudio/work/ncc
!wget "https://platform.bj.bcebos.com/sdk%2Fncc-linux-x86_64.tar.gz" -O ncc-linux-x86_64.tar.gz 
!tar -zxvf ncc-linux-x86_64.tar.gz -C /home/aistudio/work/ncc 
  1. 在模型转换前,需要进行模型压缩,进行量化。为了保证量化后的精度, 需要使用训练图片调整模型。拷贝评估图片到/home/aistudio/work/images
import os
import shutil
!mkdir /home/aistudio/work/images
filenames = os.listdir("/home/aistudio/data/data17036/catanddog/evalImageSet/")

#下面方法是图片太多的时候随机选择图片  
# index = 0
# for i in range(1, len(filenames), 7):
#     srcFile = os.path.join("/home/aistudio/data/data17036/catanddog/evalImageSet/", filenames[index])
#     targetFile = os.path.join("/home/aistudio/work/images",filenames[index])
#     shutil.copyfile(srcFile,targetFile)
#     index += 7

index = 0
for i in range(0, len(filenames), 1):
    srcFile = os.path.join("/home/aistudio/data/data17036/catanddog/evalImageSet/", filenames[index])
    targetFile = os.path.join("/home/aistudio/work/images",filenames[index])
    shutil.copyfile(srcFile,targetFile)
    index += 1
  1. 转换为.kmodel模型
!chmod 777 /home/aistudio/work/ncc
!/home/aistudio/work/ncc/ncc -i paddle -o k210model --postprocess n1to1 --dataset work/images/ freeze-model catanddog.kmodel

小结:

  • PaddlePaddle平台的学习有六个关键步骤:配置运行环境以安装PaddlePaddle、选择并获取适合的数据集、对数据进行预处理,如清洗和标准化、利用PaddlePaddle框架进行深度学习模型的训练、训练完成后,使用验证集对模型性能进行评估、最后通过ncc工具将模型转换成kmodel文件,为模型部署做准备。这一系列步骤构成了机器学习从数据准备到模型部署的完整流程
  • 为提升PaddlePaddle实验效率和模型性能,可自动化实验流程,进行超参数调优,使用可视化监控训练,并通过交叉验证等方法增强模型泛化能力。
  • 本实验的项目地址,为大家学习使用带来方便,大家可以fork学习,点击进入项目地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/572014.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

tailwindcss在使用cdn引入静态html的时候,vscode默认不会提示问题

1.首先确保vscode下载tailwind插件&#xff1a;Tailwind CSS IntelliSense 2.需要在根目录文件夹创建一个tailwind.config.js文件 export default {theme: {extend: {// 可根据需要自行配置&#xff0c;空配置项可以正常使用},}, }3.在html文件的标签中引入配置文件&#xf…

程序员到架构师,除了代码,还有文档和图

文章目录 前言一、书面设计文档文档应该作为代码和口头交流的补充文档应该注意鲜活 二、图——架构讨论的直观语言总结 前言 作为人类&#xff0c;我们天生就被视觉所吸引。在这个信息爆炸的时代&#xff0c;从精炼的代码到清晰的文档&#xff0c;再到直观的图&#xff0c;我们…

【数据结构】串(String)

文章目录 基本概念顺序存储结构比较当前串与串s的大小取子串插入删除其他构造函数拷贝构造函数扩大数组空间。重载重载重载重载[]重载>>重载<< 链式存储结构链式存储结构链块存储结构 模式匹配朴素的模式匹配算法(BF算法)KMP算法字符串的前缀、后缀和部分匹配值nex…

Parade Series - CoreAudio Reformating

// 获得音频播放设备格式信息CComHeapPtr<WAVEFORMATEX> pDeviceFormat;pAudioClient->GetMixFormat(&pDeviceFormat);constexpr int REFTIMES_PER_SEC 10000000; // 1 reference_time 100nsconstexpr int REFTIMES_PER_MILLISEC 10000;// Microsoftif (p…

Golang | Leetcode Golang题解之第49题字母异位词分组

题目&#xff1a; 题解&#xff1a; func groupAnagrams(strs []string) [][]string {mp : map[[26]int][]string{}for _, str : range strs {cnt : [26]int{}for _, b : range str {cnt[b-a]}mp[cnt] append(mp[cnt], str)}ans : make([][]string, 0, len(mp))for _, v : ra…

Alibaba 的fastjson源码详解

一、概述 Fastjson 是阿里巴巴开源的一个 Java 工具库&#xff0c;它常常被用来完成 Java 的对象与 JSON 格式的字符串的相互转化。 Fastjson 可以操作任何 Java 对象&#xff0c;即使是一些预先存在的没有源码的对象。 二、源码分析 1.首先以fastjson-1.2.70为例&#xff0c;…

nodejs

334 先下载zip文件&#xff0c;然后加上.zip,可以看到两个文件 在user中可以看到 输入即可得到flag。 335. 这里提到eval函数&#xff0c;eval中可以执行js代码&#xff0c;可以尝试使用这个函数进行测试 payload&#xff08;显示当前目录下的文件和文件夹列表&#xff09; …

基于emp的mysql查询

SQL命令 结构化查询语句&#xff1a;Structured Query Language 结构化查询语言是高级的非过程化变成语言&#xff0c;允许用户在高层数据结构上工作。是一种特殊目的的变成语言&#xff0c;是一种数据库查询和程序设计语言&#xff0c;用于存取数据以及查询、更新和管理关系数…

Python 网络与并发编程(四)

文章目录 协程Coroutines协程的核心(控制流的让出和恢复)协程和多线程比较协程的优点协程的缺点 asyncio实现协程(重点) 协程Coroutines 协程&#xff0c;全称是“协同程序”&#xff0c;用来实现任务协作。是一种在线程中&#xff0c;比线程更加轻量级的存在&#xff0c;由程…

android脱壳第二发:grpc-dumpdex加修复

上一篇我写的dex脱壳&#xff0c;写到银行类型的app的dex修复问题&#xff0c;因为dex中被抽取出来的函数的code_item_off 的偏移所在的内存&#xff0c;不在dex文件范围内&#xff0c;所以需要进行一定的修复&#xff0c;然后就停止了。本来不打算接着搞得&#xff0c;但是写了…

基础SQL DCL语句

DCL是数据控制语言&#xff0c;用来管理数据库用户&#xff0c;还有控制用户的访问权限 1.用户的查询 MySQL的用户信息存储在mysql数据库中&#xff0c;查询用户时&#xff0c;我们需要使用这个数据库。 后面&#xff0c;还有很多数据&#xff0c;因为篇幅的问题&#xff0c;就…

【FFmpeg】音视频录制 ② ( 使用 Screen Capturer Recorder 软件生成 ffmpeg 可录制的音视频设备 )

文章目录 一、使用 Screen Capturer Recorder 软件生成音视频设备1、设备查找问题 - 引入 Screen Capturer Recorder 软件2、下载安装 Screen Capturer Recorder 软件3、验证 Screen Capturer Recorder 生成的设备 一、使用 Screen Capturer Recorder 软件生成音视频设备 1、设…

【PyTorch】torch.gather() 用法

gather常被用于image做mask的操作中&#xff0c;对哪些地方进行赋值0/1 API&#xff1a; torch.gather — PyTorch 2.2 documentation torch.gather(input, dim, index, outNone) → Tensor gather()的意义&#xff1a; 顾名思义&#xff0c;聚集、集合&#xff1a;gather…

在mac上安装node.js及使用npm,yarn相关命令教程

1、安装node.js 官网&#xff1a;Node.js — Download Node.js 选择需要的版本&#xff0c;点击DownLoad 2、点击继续&#xff0c;直到安装成功。 2.1打开终端输入命令node -v 显示版本号则说明已安装成功 3、全局安装yarn命令 1、sudo npm install --global yarn &#xf…

Python构建学生信息管理系统:构建RESTful API - 学生信息管理系统的后端逻辑

在之前的博客里&#xff0c;我们已经完成了项目初始化&#xff0c;在本篇博客中&#xff0c;我们将深入探讨如何使用Flask框架实现学生信息管理系统的后端逻辑&#xff0c;特别是通过RESTful API来实现学生信息的增删改查&#xff08;CRUD&#xff09;操作。 Flask RESTful AP…

计网笔记:第1章 计算机网络概论

计网笔记&#xff1a;第1章 计算机网络概论 第1章 计算机网络概论1.1 计算机网络发展与分类1.2 OSI和TCP/IP参考模型OSI与TCP/IP参考模型图 1.3 数据封装与解封过程借助OSI模型理解数据传输过程(封装)借助OSI模型理解数据传输过程(解封) 1.4 本章例题 第1章 计算机网络概论 1.…

详解Al作画算法原理

ChatGPT AI作画算法&#xff0c;又称为AI图像生成算法&#xff0c;是一种人工智能技术&#xff0c;它可以根据给定的输入自动生成图像。这类算法近年来变得非常流行&#xff0c;尤其是随着深度学习技术的发展。这里我将聚焦于目前最先进的一类AI作画算法&#xff0c;即生成对抗…

PHP定期给自己网站目录做个特征镜像供快速对比

效果图 上代码&#xff1a; <style> h1{font-size:24px;line-height:180%;font-weight:600;margin:1px 2px;color:#0180cf;} h2{font-size:20px;line-height:140%;font-weight:600;margin:2px 4px;color:green;} h3{font-size:16px;line-height:140%;font-weight:600;m…

Hive——DML(Data Manipulation Language)数据操作语句用法详解

DML 1.Load Load语句可将文件导入到Hive表中。 hive> LOAD DATA [LOCAL] INPATH filepath [OVERWRITE] INTO TABLE tablename [PARTITION (partcol1val1, partcol2val2 ...)];关键字说明&#xff1a; local&#xff1a;表示从本地加载数据到Hive表&#xff1b;否则从HD…