TensorFlow2实战-系列教程13:Resnet实战1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

Resnet实战1
Resnet实战2
Resnet实战3

1、残差连接

深度学习中出现了随着网络的堆叠效果下降的现象,Resnet使用残差连接的方法解决了这个问题,让深度学习从此变得深了起来。
残差连接的做法可以表示为:
y = f ( x ) + x ( s h o r t c u t ) y = f(x)+x(shortcut) y=f(x)+x(shortcut)
其中y表示网络最后的输出,x为输入,f(x)则表示输入经过几层网络的输出结果,比如三次卷积+三次批归一化+三次relu,一般情况下f(x)就是网络的最终输出,这里再加上x就是一个残差连接的操作。

残差连接的操作保证了,x经过网络后得到的y一定比f(x)更好的结果,最差是同等效果,也就是保证了不会出现效果降低的情况。

这个shortcut是什么意思呢?因为x经过几次卷积后,可能会出现多个特征图,也就是f(x)和x的通道数不一样了,这个时候就需要对x的通道数进行调整再与f(x)相加得到y,如果通道数一样就不需要调整了
在这里插入图片描述
上图就是通道数没有发生变化的情况, y = f ( x ) + x y = f(x)+x y=f(x)+x,x经过两次(卷积+批归一化+ReLU)和一次(卷积+批归一化)后得到f(x),再加上x后经过ReLU就得到了最终的y
在这里插入图片描述
上图就是通道数发生变化的情况, y = f ( x ) + C o n v 2 d ( x ) y = f(x)+Conv2d(x) y=f(x)+Conv2d(x),x经过两次(卷积+批归一化+ReLU)和一次(卷积+批归一化)后得到f(x),x再经过一次(二维卷积+批归一化),这个二维卷积的卷积核是1x1的,经过这个二维卷积的x再加上f(x)后经过ReLU就得到了最终的y

2、项目介绍

在这里插入图片描述

  1. dataset文件夹,将原始数据分割成训练、验证、测试三个数据集
  2. models构建模型的代码,包含resnet31、resnet50、resnet101、resnet152的构建代码,以及残差模块实现的代码
  3. original_dataset,原始数据,包含猫、狗、熊猫3个类别的数据,每个类别1000张图像
  4. save_model,训练模型保存的路径
  5. config.py 设置配置参数的代码
  6. evaluate.py 使用测试集对模型进行测试的代码
  7. prepare_data.py 数据预处理的辅助函数代码
  8. split_dataset.py 将原始数据集分割成训练集、验证集、测试集的代码
  9. train.py 训练验证的代码

3、训练脚本train.py解读------数据预处理

from __future__ import absolute_import, division, print_function
import tensorflow as tf
from models import resnet50, resnet101, resnet152, resnet34
import config
from prepare_data import generate_datasets
import math

if __name__ == '__main__':
    # GPU settings
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

导入项目工具包和辅助函数
配置 TensorFlow 中的 GPU 内存,:

  1. tf.config.experimental.list_physical_devices('GPU'):这个函数调用列出了 TensorFlow 在你的机器上可用的所有 GPU
  2. if gpus: 这个检查用来确认是否有可用的 GPU。如果有,它将继续对每一个 GPU 进行配置
  3. 在循环内部,对每一个 GPU 调用 tf.config.experimental.set_memory_growth(gpu, True),这使得 GPU 上的内存增长被启用
# get the original_dataset
    train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count = generate_datasets()
def generate_datasets():
    train_dataset, train_count = get_dataset(dataset_root_dir=config.train_dir)
    valid_dataset, valid_count = get_dataset(dataset_root_dir=config.valid_dir)
    test_dataset, test_count = get_dataset(dataset_root_dir=config.test_dir)

    # read the original_dataset in the form of batch
    train_dataset = train_dataset.shuffle(buffer_size=train_count).batch(batch_size=config.BATCH_SIZE)
    valid_dataset = valid_dataset.batch(batch_size=config.BATCH_SIZE)
    test_dataset = test_dataset.batch(batch_size=config.BATCH_SIZE)

    return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
def get_dataset(dataset_root_dir):
    all_image_path, all_image_label = get_images_and_labels(data_root_dir=dataset_root_dir)
    # print("image_path: {}".format(all_image_path[:]))
    # print("image_label: {}".format(all_image_label[:]))
    # load the dataset and preprocess images
    image_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(load_and_preprocess_image)
    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
    image_count = len(all_image_path)

    return dataset, image_count
def get_images_and_labels(data_root_dir):
    # 得到所有图像路径
    data_root = pathlib.Path(data_root_dir)
    all_image_path = [str(path) for path in list(data_root.glob('*/*'))]
    # 得到标签名字
    label_names = sorted(item.name for item in data_root.glob('*/'))
    # 例如:{'cats': 0, 'dogs': 1, 'panda': 2}
    label_to_index = dict((index, label) for label, index in enumerate(label_names))
    # 每一个图像对应的标签
    all_image_label = [label_to_index[pathlib.Path(single_image_path).parent.name] for single_image_path in all_image_path]

    return all_image_path, all_image_label

def load_and_preprocess_image(img_path):
    # read pictures
    img_raw = tf.io.read_file(img_path)
    # decode pictures
    img_tensor = tf.image.decode_jpeg(img_raw, channels=channels)
    # resize
    img_tensor = tf.image.resize(img_tensor, [image_height, image_width])
    img_tensor = tf.cast(img_tensor, tf.float32)
    # normalization
    img = img_tensor / 255.0
    return img

load_and_preprocess_image()函数:

  1. 通过读取一个图像的路径
  2. 返回Tensor
  3. 进去进行归一化

get_images_and_labels()函数:

  1. 通过数据集的地址,获取当前目录下的所有图像的名称
  2. 在加上前缀路径和文件后缀,得到当前所有图像的对应的地址
  3. 返回地址和标签

get_dataset()函数:

  1. 通过调用get_images_and_labels()函数,得到当前目录下的图像的对应的地址和标签
  2. 使用from_tensor_slices方法和load_and_preprocess_image()函数读取地址和标签转换为Tensor
  3. 返回标签和数据组成的Tensor以及数据量

generate_datasets()函数:

  1. 训练、验证、测试数据路径分别通过调用get_dataset()函数得到训练、验证、测试数据Tensor和数据量
  2. 对训练、验证、测试数据加上batch_size和shuffle参数
  3. 返回训练、验证、测试数据Tensor和数据量

Resnet实战1
Resnet实战2
Resnet实战3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/362281.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SpringCloud Gateway(4.1.0) 返回503:原因分析与解决方案

文章目录 一、环境版本二、原因分析三、解决方案 一、环境版本 Versionspring-cloud-dependencies2023.0.0spring-cloud-starter-gateway4.1.0Nacosv2.3.0 二、原因分析 在 Spring Cloud Gateway 的早期版本中,Ribbon 被用作默认的负载均衡器。随着Spring Cloud的…

jsonpath相关---JSONPath - 用于 JSON 的 XPath

一.简介 XML 的一个经常强调的优点是提供了大量工具来分析、转换和有选择地从 XML 文档中提取数据。XPath 就是这些强大的工具之一。 现在是时候想知道,是否需要像 XPath4JSON 这样的东西,以及它可以解决哪些问题。 无需特殊脚本,即可以交…

眼未来,萨科微半导体将持续发挥自身在技术研发和产品创新方面的优势

金航标kinghelm萨科微slkor宋仕强说,着眼未来,萨科微半导体将持续发挥自身在技术研发和产品创新方面的优势,以优质高效的半导体解决方案满足全球各地市场的需求。目前,萨科微的产品线已经囊括了二极管、三极管、功率器件、电源管理芯片等多个系列,并在霍尔传感器、A…

【MySQL】学习并使用聚合函数和DQL进行分组查询

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-t8K8tl6eNwqdFmcD {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

【开源】SpringBoot框架开发天然气工程运维系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统角色分类2.2 核心功能2.2.1 流程 12.2.2 流程 22.3 各角色功能2.3.1 系统管理员功能2.3.2 用户服务部功能2.3.3 分公司(施工单位)功能2.3.3.1 技术员角色功能2.3.3.2 材料员角色功能 2.3.4 安…

Python算法题集_轮转数组

本文为Python算法题集之一的代码示例 题目189 题目:轮转数组 说明:给定一个整数数组 nums,将数组中的元素向右轮转 k 个位置,其中 k 是非负数。 示例 1: 输入: nums [1,2,3,4,5,6,7], k 3 输出: [5,6,7,1,2,3,4] 解释: 向右…

Hadoop3.x基础(2)- HDFS

来源:B站尚硅谷 目录 HDFS概述HDFS产出背景及定义HDFS优缺点HDFS组成架构HDFS文件块大小(面试重点) HDFS的Shell操作(开发重点)基本语法命令大全常用命令实操准备工作上传下载HDFS直接操作 HDFS的API操作HDFS的API案例…

微信小程序(二十八)网络请求数据进行列表渲染

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.GET请求的规范 2.数据赋值的方法 源码&#xff1a; index.wxml <!-- 列表渲染基础写法&#xff0c;不明白的看上一篇 --> <view class"students"><view class"item">&…

yolov8训练自己的关键点检测模型

参考&#xff1a; https://blog.csdn.net/weixin_38807927/article/details/135036450 标注数据集 安装labelme pip install labelme -i https://pypi.tuna.tsinghua.edu.cn/simple如果报错 $ labelme 2024-01-31 03:16:20,636 [INFO ] __init__:get_config:67- Loading …

MIMIC-IV数据库, 如何提取哪些肺栓塞病人进行了溶栓手术治疗?

溶栓手术是通过药物或者手术的方式&#xff0c;使闭塞的血管再通的一种手术。 溶栓手术主要是通过药物或者手术的方式&#xff0c;使闭塞的血管再通的一种手术。常用的药物有尿激酶、链激酶等&#xff0c;这些药物可以激活纤溶酶原&#xff0c;使纤溶酶原转化为纤溶酶&#xff…

Shell的字符处理和expect

一、Here Document免交互 1.1Here Document概述 使用I/O重定向的方式将命令列表提供给交互式程序&#xff0c;标准输入的一种替代品 格式: 命令 <<标记 输入内容 标记 1.2Here Document使用注意事项 标记可以使用任意合法字符结尾的标记一定要顶格写&#xff0c;前面…

DEV-C++ ege.h库 绘图教程(九)

一、Getting Start 前情回顾&#xff1a; DEV-C ege.h库 绘图教程 今天我们将来讲一讲一些关于杂项的函数。 二、控制台函数 1.initconsole 初始化并显示控制台窗口。 &#xff08;但因为Dev C默认就是显示窗口的&#xff0c;所以这个函数一点也没用&#xff09; 但如果想…

基于C++的面向对象程序设计:类与对象的深入剖析

面向对象程序设计的基本特点 面向对象程序设计的基本特点包括&#xff1a;抽象、封装、继承、多态。 抽象 抽象是指对具体问题或对象进行概括&#xff0c;抽出其公共性质并加以描述的过程。一般情况抽象分为数据抽象和行为抽象&#xff0c;其中数据抽象是指一个对象区别于另…

二叉树顺序结构堆实现

目录 Test.c测试代码 test1 test2 test3 &#x1f387;Test.c总代码 Heap.h头文件&函数声明 头文件 函数声明 &#x1f387;Heap.h总代码 Heap.c函数实现 ☁HeapInit初始化 ☁HeapDestroy销毁 ☁HeapPush插入数据 【1】插入数据 【2】向上调整Adjustup❗ …

关于谷歌新版调试用具(Chrome Dev Tool ),网络选项(chrome-network)默认开启下拉模式的设置

今天在使用谷歌浏览器进行调试的时候&#xff0c;打开调试工具网络选项发现过滤不同模式的选项卡不见了&#xff0c;转而变成一个下拉式选项&#xff0c;如下图 这样一来使得切换不同类型查看的时候变得非常不方便&#xff0c;然后网上查了一下发现这个功能谷歌在很早版本就已…

Mysql 主从库的重新配置

1.从库和主库的数据差异实在太大&#xff0c;反复处理数据耗时耗力&#xff0c;不如重做。 2.备份主数据库(命令备份的) usr/local/mysql/bin/mysqldump -h 100.1.4.42 -P 5566 -u root -p 备份数据库 > /mysql/db/备份的名称.sql 3.停止从库复制 登录到MySQL从库&#x…

腾讯云邀请你参与【腾讯2024技术答人挑战赛】 赢取丰厚的礼品

腾讯云邀请你参与【腾讯2024技术答人挑战赛】 赢取丰厚的礼品 2024年 腾讯礼品大派送 保持技术好奇心是程序员构建护城河的重要一环&#xff0c;快来测测你现在的技术知识面在中国程序员中排第几&#xff1f; 参与答题更有iPad、Pico VR游戏机、Switch等、腾讯云官方认证证书好…

Prometheus+grafana配置监控系统

使用docker compose安装 方便拓展, 配置信息都放在在 /docker/prometheus 目录下 1.目录结构如下 . ├── conf │ └── prometheus.yml ├── grafana_data ├── prometheus_data └── prometheus_grafana.yaml2.创建目录文件 mkdir /docker/prometheus &&am…

Leetcode—2670. 找出不同元素数目差数组【简单】

2024每日刷题&#xff08;一零七&#xff09; Leetcode—2670. 找出不同元素数目差数组 实现代码 class Solution { public:vector<int> distinctDifferenceArray(vector<int>& nums) {unordered_set<int> s;int n nums.size();vector<int> dif…

CapCut - 剪映国际版11.0.0

【应用名称】&#xff1a;CapCut - 剪映国际版 【适用平台】&#xff1a;#Android 【软件标签】&#xff1a;#CapCut #剪映国际版 【应用版本】&#xff1a;11.0.0 【应用大小】&#xff1a;231MB 【软件说明】&#xff1a;软件升级更新。目前大家广泛使用的最令人惊叹、最专业…