TensorFlow中slim包的具体用法

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim
    |-- BUILD
    |-- README.md
    |-- WORKSPACE
    |-- __init__.py
    |-- datasets
    |   |-- __init__.py
    |   |-- __pycache__
    |   |   |-- __init__.cpython-37.pyc
    |   |   |-- dataset_utils.cpython-37.pyc
    |   |   |-- download_and_convert_cifar10.cpython-37.pyc
    |   |   |-- download_and_convert_flowers.cpython-37.pyc
    |   |   `-- download_and_convert_mnist.cpython-37.pyc
    |   |-- build_imagenet_data.py
    |   |-- cifar10.py
    |   |-- dataset_factory.py
    |   |-- dataset_utils.py
    |   |-- download_and_convert_cifar10.py
    |   |-- download_and_convert_flowers.py
    |   |-- download_and_convert_imagenet.sh
    |   |-- download_and_convert_mnist.py
    |   |-- download_imagenet.sh
    |   |-- flowers.py
    |   |-- imagenet.py
    |   |-- imagenet_2012_validation_synset_labels.txt
    |   |-- imagenet_lsvrc_2015_synsets.txt
    |   |-- imagenet_metadata.txt
    |   |-- mnist.py
    |   |-- preprocess_imagenet_validation_data.py
    |   `-- process_bounding_boxes.py
    |-- deployment
    |   |-- __init__.py
    |   |-- model_deploy.py
    |   `-- model_deploy_test.py
    |-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式
    |-- eval_image_classifier.py        # 测试模型分类效果
    |-- export_inference_graph.py
    |-- export_inference_graph_test.py
    |-- nets
    |   |-- __init__.py
    |   |-- alexnet.py
    |   |-- alexnet_test.py
    |   |-- cifarnet.py
    |   |-- cyclegan.py
    |   |-- cyclegan_test.py
    |   |-- dcgan.py
    |   |-- dcgan_test.py
    |   |-- i3d.py
    |   |-- i3d_test.py
    |   |-- i3d_utils.py
    |   |-- inception.py
    |   |-- inception_resnet_v2.py
    |   |-- inception_resnet_v2_test.py
    |   |-- inception_utils.py
    |   |-- inception_v1.py
    |   |-- inception_v1_test.py
    |   |-- inception_v2.py
    |   |-- inception_v2_test.py
    |   |-- inception_v3.py
    |   |-- inception_v3_test.py
    |   |-- inception_v4.py
    |   |-- inception_v4_test.py
    |   |-- lenet.py
    |   |-- mobilenet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- conv_blocks.py
    |   |   |-- madds_top1_accuracy.png
    |   |   |-- mnet_v1_vs_v2_pixel1_latency.png
    |   |   |-- mobilenet.py
    |   |   |-- mobilenet_example.ipynb
    |   |   |-- mobilenet_v2.py
    |   |   `-- mobilenet_v2_test.py
    |   |-- mobilenet_v1.md
    |   |-- mobilenet_v1.png
    |   |-- mobilenet_v1.py
    |   |-- mobilenet_v1_eval.py
    |   |-- mobilenet_v1_test.py
    |   |-- mobilenet_v1_train.py
    |   |-- nasnet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- nasnet.py
    |   |   |-- nasnet_test.py
    |   |   |-- nasnet_utils.py
    |   |   |-- nasnet_utils_test.py
    |   |   |-- pnasnet.py
    |   |   `-- pnasnet_test.py
    |   |-- nets_factory.py
    |   |-- nets_factory_test.py
    |   |-- overfeat.py
    |   |-- overfeat_test.py
    |   |-- pix2pix.py
    |   |-- pix2pix_test.py
    |   |-- resnet_utils.py
    |   |-- resnet_v1.py
    |   |-- resnet_v1_test.py
    |   |-- resnet_v2.py
    |   |-- resnet_v2_test.py
    |   |-- s3dg.py
    |   |-- s3dg_test.py
    |   |-- vgg.py
    |   `-- vgg_test.py
    |-- preprocessing
    |   |-- __init__.py
    |   |-- cifarnet_preprocessing.py
    |   |-- inception_preprocessing.py
    |   |-- lenet_preprocessing.py
    |   |-- preprocessing_factory.py
    |   `-- vgg_preprocessing.py
    |-- scripts                     # gqr:存储的是相关的模型训练脚本                
    |   |-- export_mobilenet.sh
    |   |-- finetune_inception_resnet_v2_on_flowers.sh
    |   |-- finetune_inception_v1_on_flowers.sh
    |   |-- finetune_inception_v3_on_flowers.sh
    |   |-- finetune_resnet_v1_50_on_flowers.sh
    |   |-- train_cifarnet_on_cifar10.sh
    |   `-- train_lenet_on_mnist.sh
    |-- setup.py
    |-- slim_walkthrough.ipynb
    `-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e

# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径

# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50

# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径

# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
  mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
  tar -xvf resnet_v1_50_2016_08_28.tar.gz
  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
  rm resnet_v1_50_2016_08_28.tar.gz
fi

# Download the dataset
python download_and_convert_data.py \
  --dataset_name=flowers \
  --dataset_dir=${DATASET_DIR}

# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50 \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --checkpoint_path=${TRAIN_DIR} \
  --model_name=resnet_v1_50 \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR}/all \
  --eval_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \
  --train_dir=./data/flowers-models/resnet_v1_50\
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=./data/flowers \
  --model_name=resnet_v1_50 \
  --checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \ 
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/91065.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode 2235.两整数相加

一、两整数相加 给你两个整数 num1 和 num2,返回这两个整数的和。 示例 1: 输入:num1 12, num2 5 输出:17 解释:num1 是 12,num2 是 5 ,它们的和是 12 5 17 ,因此返回 17 。示例…

【OCR识别】tess4j图片识别文字

什么是OCR? OCR (Optical Character Recognition,光学字符识别)是指电子设备(例如扫描仪或数码相机)检查纸上打印的字符,通过检测暗、亮的模式确定其形状,然后用字符识别方法将形状翻译成计算机…

ServiceManager接收APP的跨进程Binder通信流程分析

现在一起来分析Server端接收(来自APP端)Binder数据的整个过程,还是以ServiceManager这个Server为例进行分析,这是一个至下而上的分析过程。 在分析之前先思考ServiceManager是什么?它其实是一个独立的进程,由init解析i…

【人脸考勤项目】

本项目主要是基于Opencv完成的人脸识别的考勤系统 人脸检测器的5种实现方法 方法一:haar方法进行实现(以下是基于notebook进行编码) # 步骤 # 1、读取包含人脸的图片 # 2.使用haar模型识别人脸 # 3.将识别结果用矩形框画出来# 导入相关包 …

ARM开发,stm32mp157a-A7核中断实验(实现按键中断功能)

1.实验目的:实现KEY1/LEY2/KE3三个按键,中断触发打印一句话,并且灯的状态取反; key1 ----> LED3灯状态取反; key2 ----> LED2灯状态取反; key3 ----> LED1灯状态取反; 2.分析框图: …

ACL2023 Prompt 相关文章速通 Part 1

Accepted Papers link: ACL2023 main conference accepted papers 文章目录 Accepted PapersPrompter: Zero-shot Adaptive Prefixes for Dialogue State Tracking Domain AdaptationQuery Refinement Prompts for Closed-Book Long-Form QAPrompting Language Models for Lin…

爬虫逆向实战(二十一)-- 某某点集登录与获取数据

登录 一、数据接口分析 主页地址:某某点集 1、抓包 通过抓包可以发现登录接口是phonePwdLogin 2、判断是否有加密参数 请求参数是否加密? 通过查看“载荷”模块可以发现有pwd和sig两个加密参数 请求头是否加密? 无响应是否加密&#x…

卡尔曼滤波

第一章知识点回顾 表1变量符号对照表 1.1数学期望 数学期望表示为每次可能的结果乘上结果概率的总和。 1.1.1 数学期望的性质 假设常数为 C ,随机变量 X 和 Y ,则 1.2 方差(variance) 概率论中和统计中的方差反映单个&…

Java进阶篇--进程和线程的区别

进程和线程 进程 在一个操作系统中,每个独立执行的程序都可称之为一个进程,也就是“正在运行的程序”。目前大部分计算机上安装的都是多任务操作系统,即能够同时执行多个应用程序,最常见的有Windows、Linux、Unix等。比如在Wind…

S波形及鱼眼扭曲源码

三角波形扭曲&#xff1a; void sinwave(cv::Mat& src,cv::Mat& dst) {dst.create(src.rows, src.cols, CV_8UC3);dst.setTo(0);src.copyTo(dst);int PI 3.1415;int RANGE dst.cols/2;for (int i 0; i < dst.rows; i) {double temp (dst.cols - RANGE) / 2 (d…

Git,分布式版本控制工具

1.为常用指令配置别名&#xff08;可选&#xff09; 打开用户目录&#xff0c;创建.bashrc文件 &#xff08;touch ~/.bashrc&#xff09; 2.往其输入内容 #用于输出git提交日志 alias git-loggit log --prettyoneline --all --graph --abbrev-commit #用于输出当前目录所有文…

意外发现Cortex-M内核带的64bit时间戳,比32bit的DWT时钟周期计数器更方便,再也不用担心溢出问题了

视频&#xff1a; https://www.bilibili.com/video/BV1Bw411D7F5 意外发现Cortex-M内核带的64bit时间戳&#xff0c;比32bit的DWT时钟周期计数器更方便&#xff0c;再也不用担心溢出问题了 介绍&#xff1a; 看参数手册的Debug章节&#xff0c;System ROM Table里面带Timestam…

首轮征稿 | 2024年第二届先进无人飞行系统国际会议(ICAUAS 2024)

会议简介 Brief Introduction 2024年第二届先进无人飞行系统国际会议(ICAUAS 2024) 会议时间&#xff1a;2024年4月5日-7日 召开地点&#xff1a;中国武汉 大会官网&#xff1a;ICAUAS 2024-2024 2nd International Conference on Advanced Unmanned Aerial Systems 由华中科技…

软件测试框架实战:Python+Slenium搭建Web自动化测试框架全教程

PythonSelenium是一种流行的Web自动化测试框架&#xff0c;可以模拟真实的用户操作&#xff0c;对网页进行功能和样式的验证。要通过selenium测试网页&#xff0c;需要以下几个步骤&#xff1a; 安装selenium库和浏览器驱动 。 使用selenium提供的方法来控制浏览器窗口大小、后…

图像检索,目标检测map的实现

一、图像检索指标Rank1,map 参考&#xff1a;https://blog.csdn.net/weixin_41427758/article/details/81188164?spm1001.2014.3001.5506 1.Rank1: rank-k&#xff1a;算法返回的排序列表中&#xff0c;前k位为存在检索目标则称为rank-k命中。 常用的为rank1&#xff1a;首…

老人摔倒智能识别检测算法

老人摔倒智能识别检测算法通过yolov8深度学习算法模型架构&#xff0c;老人摔倒智能识别检测算法能够实时监测老人的活动状态及时发现摔倒事件&#xff0c;系统会立即触发告警&#xff0c;向相关人员发送求助信号&#xff0c;减少延误救援的时间。YOLOv8 算法的核心特性和改动可…

对于pycharm 运行的时候不在cmd中运行,而是在python控制台运行的情况,如何处理?

对于pycharm 运行的时候不在cmd中运行&#xff0c;而是在python控制台运行的情况&#xff0c;如何处理&#xff1f; 比如&#xff0c;你在运行你的代码的时候 它总在python控制台运行&#xff0c;十分难受 解决方法 在pycharm中设置下即可&#xff0c;很简单 选择运行点击…

XSS攻击与防御

目录 一、环境配置 kali安装beef contos7安装dvwa 二、XSS攻击简介 三、XSS攻击的危害 四、xSS攻击的分类 五、XSS产生的原因 六、构造XSS攻击脚本 (一)基础知识 常用的html标签 常用的js脚本 (二)构造脚本的方式弹窗警告 七、自动XSS攻击 (一)BeEF简介 (二)BeEF…

亲测influxdb安装为window后台服务

InfluxDB 安装 64bit&#xff1a;https://dl.influxdata.com/influxdb/releases/influxdb-1.7.4_windows_amd64.zip 解压安装包 修改配置文件 [meta]# Where the metadata/raft database is storeddir "D:/influxdb/meta"...[data]# The directory where the TSM…

容器内执行命令

上篇文章向读者介绍了一个Nginx的例子&#xff0c;对于Nginx这样一个容器而言&#xff0c;当它启动成功后&#xff0c;我们不可避免的需要对Nginx进行的配置进行修改&#xff0c;那么这个修改要如何完成呢&#xff1f;且看下文。 依附容器 docker attach 依附容器这个主要是…