TensorFlow-slim包进行图像数据集分类---具体流程

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim
    |-- BUILD
    |-- README.md
    |-- WORKSPACE
    |-- __init__.py
    |-- datasets
    |   |-- __init__.py
    |   |-- __pycache__
    |   |   |-- __init__.cpython-37.pyc
    |   |   |-- dataset_utils.cpython-37.pyc
    |   |   |-- download_and_convert_cifar10.cpython-37.pyc
    |   |   |-- download_and_convert_flowers.cpython-37.pyc
    |   |   `-- download_and_convert_mnist.cpython-37.pyc
    |   |-- build_imagenet_data.py
    |   |-- cifar10.py
    |   |-- dataset_factory.py
    |   |-- dataset_utils.py
    |   |-- download_and_convert_cifar10.py
    |   |-- download_and_convert_flowers.py
    |   |-- download_and_convert_imagenet.sh
    |   |-- download_and_convert_mnist.py
    |   |-- download_imagenet.sh
    |   |-- flowers.py
    |   |-- imagenet.py
    |   |-- imagenet_2012_validation_synset_labels.txt
    |   |-- imagenet_lsvrc_2015_synsets.txt
    |   |-- imagenet_metadata.txt
    |   |-- mnist.py
    |   |-- preprocess_imagenet_validation_data.py
    |   `-- process_bounding_boxes.py
    |-- deployment
    |   |-- __init__.py
    |   |-- model_deploy.py
    |   `-- model_deploy_test.py
    |-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式
    |-- eval_image_classifier.py        # 测试模型分类效果
    |-- export_inference_graph.py
    |-- export_inference_graph_test.py
    |-- nets
    |   |-- __init__.py
    |   |-- alexnet.py
    |   |-- alexnet_test.py
    |   |-- cifarnet.py
    |   |-- cyclegan.py
    |   |-- cyclegan_test.py
    |   |-- dcgan.py
    |   |-- dcgan_test.py
    |   |-- i3d.py
    |   |-- i3d_test.py
    |   |-- i3d_utils.py
    |   |-- inception.py
    |   |-- inception_resnet_v2.py
    |   |-- inception_resnet_v2_test.py
    |   |-- inception_utils.py
    |   |-- inception_v1.py
    |   |-- inception_v1_test.py
    |   |-- inception_v2.py
    |   |-- inception_v2_test.py
    |   |-- inception_v3.py
    |   |-- inception_v3_test.py
    |   |-- inception_v4.py
    |   |-- inception_v4_test.py
    |   |-- lenet.py
    |   |-- mobilenet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- conv_blocks.py
    |   |   |-- madds_top1_accuracy.png
    |   |   |-- mnet_v1_vs_v2_pixel1_latency.png
    |   |   |-- mobilenet.py
    |   |   |-- mobilenet_example.ipynb
    |   |   |-- mobilenet_v2.py
    |   |   `-- mobilenet_v2_test.py
    |   |-- mobilenet_v1.md
    |   |-- mobilenet_v1.png
    |   |-- mobilenet_v1.py
    |   |-- mobilenet_v1_eval.py
    |   |-- mobilenet_v1_test.py
    |   |-- mobilenet_v1_train.py
    |   |-- nasnet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- nasnet.py
    |   |   |-- nasnet_test.py
    |   |   |-- nasnet_utils.py
    |   |   |-- nasnet_utils_test.py
    |   |   |-- pnasnet.py
    |   |   `-- pnasnet_test.py
    |   |-- nets_factory.py
    |   |-- nets_factory_test.py
    |   |-- overfeat.py
    |   |-- overfeat_test.py
    |   |-- pix2pix.py
    |   |-- pix2pix_test.py
    |   |-- resnet_utils.py
    |   |-- resnet_v1.py
    |   |-- resnet_v1_test.py
    |   |-- resnet_v2.py
    |   |-- resnet_v2_test.py
    |   |-- s3dg.py
    |   |-- s3dg_test.py
    |   |-- vgg.py
    |   `-- vgg_test.py
    |-- preprocessing
    |   |-- __init__.py
    |   |-- cifarnet_preprocessing.py
    |   |-- inception_preprocessing.py
    |   |-- lenet_preprocessing.py
    |   |-- preprocessing_factory.py
    |   `-- vgg_preprocessing.py
    |-- scripts                     # gqr:存储的是相关的模型训练脚本                
    |   |-- export_mobilenet.sh
    |   |-- finetune_inception_resnet_v2_on_flowers.sh
    |   |-- finetune_inception_v1_on_flowers.sh
    |   |-- finetune_inception_v3_on_flowers.sh
    |   |-- finetune_resnet_v1_50_on_flowers.sh
    |   |-- train_cifarnet_on_cifar10.sh
    |   `-- train_lenet_on_mnist.sh
    |-- setup.py
    |-- slim_walkthrough.ipynb
    `-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e

# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径

# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50

# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径

# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
  mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
  tar -xvf resnet_v1_50_2016_08_28.tar.gz
  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
  rm resnet_v1_50_2016_08_28.tar.gz
fi

# Download the dataset
python download_and_convert_data.py \
  --dataset_name=flowers \
  --dataset_dir=${DATASET_DIR}

# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50 \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --checkpoint_path=${TRAIN_DIR} \
  --model_name=resnet_v1_50 \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR}/all \
  --eval_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \
  --train_dir=./data/flowers-models/resnet_v1_50\
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=./data/flowers \
  --model_name=resnet_v1_50 \
  --checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \ 
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–checkpoint_path:预训练模型存放地址
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

注意:在模型训练前,需要下载预训练模型,
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

解压后存放在相应目录

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/96073.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

web自动化框架:selenium学习使用操作大全(Python版)

目录 一、浏览器驱动下载二、selenium-python安装(打开网站、操作元素)三、网页解析(HTML、xpath)四、selenium基本操作1、元素定位八种方法2、元素动态定位3、iframe切换4、填充表单_填充文本框5、填充表单_单选按钮6、填充表单_…

RT-Thread在STM32硬件I2C的踩坑记录

RT-Thread在STM32硬件I2C的踩坑记录 0.前言一、软硬件I2C区别二、RT Thread中的I2C驱动三、尝试适配硬件I2C四、i2c-bit-ops操作函数替换五、Attention Please!六、总结 参考文章: 1.将硬件I2C巧妙地将“嫁接”到RTT原生的模拟I2C驱动框架 2.基于STM32F4平台的硬件I…

java八股文面试[多线程]——Synchronized的底层实现原理

笔试:画出Synchronized 线程状态流转实现原理图 synchronized关键字解决的是多个线程之间访问资源的同步性,synchronized 翻译为中文的意思是同步,也称之为”同步锁“。 synchronized的作用是保证在同一时刻, 被修饰的代码块或方…

16.CSS菜单悬停特效

效果 源码 <!DOCTYPE html> <html> <head> <title>Creative Menu Item Hover Effects</title> <link rel="stylesheet" type="text/css" href="style.css"> </head> <body><section><…

(详解)数据结构-----------栈与队列 c语言实现

本章将会详细讲解以下知识点&#xff1a; 目录 一&#xff1a;栈 1&#xff1a;栈的定义&#xff0c;栈的特点 2&#xff1a;用什么结构来实现栈与原因的分析? 3: (超详解)栈的常用接口并且附上测试用例 二:队列 1:队列的定义&#xff0c;队列的特点 2&#xff1a;用什么结…

【ArcGIS微课1000例】0073:ArcGIS探索性回归分析案例

一、探索性回归工具简介 “探索性回归”工具会对输入的候选解释变量的所有可能组合进行评估,以便根据用户所指定的指标来查找能够最好地对因变量做出解释的 OLS 模型。 给定一组候选解释变量,找出正确指定的 OLS 模型: 用法: 工具还会生成一个可选表,该表包括所有满足…

Mybatis1.4 多条件查询

1.4 多条件查询 1.4.1 编写接口方法1.4.2 编写SQL语句1.4.3 编写测试方法1.4.4 动态SQL 我们经常会遇到如上图所示的多条件查询&#xff0c;将多条件查询的结果展示在下方的数据列表中。而我们做这个功能需要分析最终的SQL语句应该是什么样&#xff0c;思考两个问题 条件表达式…

SpringBoot整合JUnit、MyBatis、SSM

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaEE 操作系统 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 SpringBoot整合 一、SpringBoot整合JUnit二、Spri…

无人机甚高频无线电中继通讯U-ATC118

简介 甚高频无线电中继通讯系统使用经过适航认证的机载电台连接数字网络传输模块&#xff0c;通过网络远程控制无缝实现无人机操作员与塔台直接语音通话。无人机操作员可以从地面控制站远程操作机载电台进行频率切换、静噪开关、PTT按钮&#xff0c;电台虚拟面板与真实面板布局…

小程序数据导出文件

小程序josn数据生成excel文件 先从下载传送门将xlsx.mini.min.js拷贝下来&#xff0c;新建xlsx.js文件放入小程序项目文件夹下。 const XLSX require(./xlsx)//在需要用的页面中引入// 定义导出 Excel 报表的方法exportData() {const that thislet newData [{time:2021,val…

【Seata】00 - Seata Server 部署(Windows、Docker 基于 Jpom)

文章目录 前言参考目录版本说明Windows 部署 seata-server1&#xff1a;下载压缩包2&#xff1a;文件存储模式3&#xff1a;db 存储模式3.1&#xff1a;建表3.2&#xff1a;修改配置文件3.3&#xff1a;启动脚本4&#xff1a;源码部署 Docker 部署 seata-server &#xff08;基…

程序员必须掌握哪些算法?

一个程序员一生中可能会邂逅各种各样的算法&#xff0c;但总有那么几种&#xff0c;是作为一个程序员一定会遇见且大概率需要掌握的算法。今天就来聊聊这些十分重要的“必抓&#xff01;”算法吧~ 文章目录 一、程序员必须掌握哪些算法&#xff1f;二&#xff1a;常见算法介绍…

flutter高德地图大头针

1、效果图 2、pub get #地图定位 amap_flutter_map: ^3.0.0 amap_flutter_location: ^3.0.0 3、上代码 import dart:async; import dart:io;import package:amap_flutter_location/amap_flutter_location.dart; import package:amap_flutter_location/amap_location_option…

R语言APRIORI关联规则、K-MEANS均值聚类分析中药专利复方治疗用药规律网络可视化...

全文链接&#xff1a;http://tecdat.cn/?p30605 应用关联规则、聚类方法等数据挖掘技术分析治疗的中药专利复方组方配伍规律&#xff08;点击文末“阅读原文”获取完整代码数据&#xff09;。 方法检索治疗中药专利复方&#xff0c;排除外用中药及中西药物合用的复方。最近我们…

每天 26,315 美元罚款?交通安全局要求特斯拉提供 Autopilot数据

根据美国国家公路交通安全管理局&#xff08;NHTSA&#xff09;最近的特别命令&#xff0c;特斯拉公司被要求提供关于其自动驾驶功能Autopilot的相关信息。这一命令是继NHTSA于2021年8月启动初步评估后&#xff0c;在2022年6月升级为正式调查的一部分&#xff0c;NHTSA近期对特…

网络安全法+网络安全等级保护

网络安全法 2014年2月&#xff0c;中央网络安全和信息化领导小组成立&#xff0c;习主席当组长 2017年6月1日&#xff0c;网络安全法正式成立 网络安全是国家安全的重要组成部分没有网络安全就没有国家安全&#xff0c;没有信息化就没有现代化 网络安全法21条 网络安全法31条 …

angular抛出 ExpressionChangedAfterItHasBeenCheckedError错误分析

当变更检测完成后又更改了表达式值时&#xff0c;Angular 就会抛出 ExpressionChangedAfterItHasBeenCheckedError 错误。Angular 只会在开发模式下抛出此错误。 在开发模式下&#xff0c;Angular 在每次变更检测运行后都会执行一次附加检查&#xff0c;以确保绑定没有更改。这…

Nginx到底是什么,他能干什么?

目录 Ngnix是什么&#xff0c;它是用来做什么的呢&#xff1f; 一。Nginx简介 二&#xff0c;为什么要用Nginx呢&#xff1f; 二。Nginx应用 1.HTTP代理和反向代理 2.负载均衡 Ngnix是什么&#xff0c;它是用来做什么的呢&#xff1f; 一。Nginx简介 Nginx是enginex的简写&…

Redis——如何解决redis穿透、雪崩、击穿问题

目录 一、查询商品信息的常规代码示例二、缓存击穿2.1、缓存击穿的理解2.2、缓存击穿的解决方案2.3、解决缓存击穿的代码示例 三、缓存雪崩3.1、缓存雪崩的理解3.2、缓存雪崩的解决方案3.2.1、缓存集中过期的情况3.2.2、缓存服务器宕机的情况3.2.3、缓存服务器断电的情况 3.3、…

springBoot打印精美logo

文章目录 &#x1f412;个人主页&#x1f3c5;JavaEE系列专栏&#x1f4d6;前言&#xff1a;&#x1f380;文本logo &#x1f412;个人主页 &#x1f3c5;JavaEE系列专栏 &#x1f4d6;前言&#xff1a; 本篇博客主要以提供springBoot打印精美logo &#x1f380;文本logo ??…