J1打卡——鸟类识别

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.检查GPU

import tensorflow as tf
gpus=tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0],True)  
    tf.config.set_visible_devices([gpus[0]],"GPU")

​​​​

2.查看数据

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  
plt.rcParams['axes.unicode_minus'] = False 
import os,PIL,pathlib
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,models

data_dir="D:\\jupyter lab\\训练营\\data\\第8天\\bird_photos"
data_dir=pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

3.划分数据集

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("photo")
for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)  
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

plt.imshow(images[1].numpy().astype("uint8"))

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

​​​​​​

4.创建模型

​​

from keras import layers

from keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):

    filters1, filters2, filters3 = filters

    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size,padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    x = layers.add([x, input_tensor] ,name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224,224,3],classes=1000):

    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x =     conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x =     conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x =     conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x =     conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    
    # 加载预训练模型
    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")

    return model

model = ResNet50()
model.summary()

.....


​​​​​​

5.编译及训练模型
 

model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

epochs = 10

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

6.结果可视化

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("photo")

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

​​​​​7.预测图片

plt.figure(figsize=(10, 5)) 
plt.suptitle("photo")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)  
        
        plt.imshow(images[i].numpy().astype("uint8"))
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")

​​总结:

     这些代码展示了如何使用TensorFlow和Keras构建、训练并评估一个基于ResNet50架构的图像分类模型。

  1. 检查GPU:首先,通过TensorFlow检查系统中是否存在可用的GPU,并设置内存增长以优化性能。

  2. 查看数据:定义了数据集的路径,并使用matplotlib和TensorFlow库来探索数据集的基本信息,包括图片总数和类别名称。此外,还展示了一些样本图片及其标签,帮助直观理解数据集的构成。

  3. 划分数据集:利用tf.keras.preprocessing.image_dataset_from_directory函数将数据集划分为训练集和验证集,以便后续模型训练与验证。

  4. 创建模型:实现了经典的ResNet50模型,包含基本块(identity block)和卷积块(convolutional block),这些块是ResNet的核心组成部分。此步骤还包括加载预训练权重,以增强模型的泛化能力。

  5. 编译及训练模型:配置了Adam优化器、损失函数以及评价指标,并对模型进行了训练。这里采用了10个epoch的训练周期,可以根据实际情况调整这个参数以达到更好的效果。

  6. 结果可视化:通过绘制训练和验证的准确率与损失曲线,可以直观地看到模型在训练过程中的表现。这对于分析模型是否过拟合或欠拟合非常有帮助。

  7. 预测图片:最后,选取验证集中的一些图片进行预测,展示模型的实际识别效果。这一步骤有助于评估模型的最终性能,并提供了一个直观的方式来检验模型的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/958028.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

不建模,无代码,如何构建一个3D虚拟展厅?

在数字化浪潮的推动下,众多企业正积极探索线上3D虚拟展厅这一新型展示平台,旨在以更加生动、直观的方式呈现其产品、环境与综合实力。然而,构建一个既专业又吸引人的3D虚拟展厅并非易事,它不仅需要深厚的技术支持,还需…

翻译:How do I reset my FPGA?

文章目录 背景翻译:How do I reset my FPGA?1、Understanding the flip-flop reset behavior2、Reset methodology3、Use appropriate resets to maximize utilization4、Many options5、About the author 背景 在写博客《复位信号的同步与释放(同步复…

开源物业管理系统助力高效管理与服务提升

内容概要 开源物业管理系统为物业行业带来了诸多革新,能有效提升管理质量和服务水平。这个系统不仅灵活,适应性强,还能通过智能化功能满足各种复杂的物业管理需求。我们知道,对于物业公司来说,合理的资源配置至关重要…

UE5 开启“Python Remote Execution“

demo 代码 remote_execution.py 远程调用UE5 python代码-CSDN博客 在启用 Unreal Engine 5(UE5)的“Python 远程执行”功能后,UE5 会启动一个 UDP 组播套接字服务,以监听来自外部应用程序的 Python 命令。 具体行为如下&#xf…

web应用引入cookie机制的用途和cookie技术主要包括的内容

web应用引入cookie机制,用于用户跟踪。 (1)HTTP响应报文中的Cookie头行:set-Cookie (2)用户浏览器在本地存储、维护和管理的Cookie文件 (3)HTTP请求报文中的Cookie头行:…

win32汇编环境,按字节、双字等复制字符的操作

;运行效果 ;win32汇编环境,按字节、双字等复制字符的操作 ;这是汇编的优点之一。我们可以按字节、双字、四字、八字节等复制或挨个检查字符。 ;有时候,在接收到的一串信息中,比如访问网站时,返回的字串里,有很多0值存在&#xff0…

Node.js接收文件分片数据并进行合并处理

前言:上一篇文章讲了如何进行文件的分片:Vue3使用多线程处理文件分片任务,那么本篇文章主要看一下后端怎么接收前端上传来的分片并进行合并处理。 目录: 一、文件结构二、主要依赖1. express2. multer3. fs (文件系统模块)4. pat…

《offer 来了:Java 面试核心知识点精讲 -- 框架篇》(附资源)

继上篇文章介绍了《offer 来了:Java 面试核心知识点精讲 -- 原理篇》书后,本文章再给大家推荐兄弟篇 《offer来了:Java面试核心知识点精讲--框架篇》, 简直就是为Java开发者量身定制的面试神器。 本书是对Java程序员面试中常见的…

iOS 权限管理:同时请求相机和麦克风权限的最佳实践

引言 在开发视频类应用时,我们常常会遇到需要同时请求相机和麦克风权限的场景。比如,在用户发布视频动态时,相机用于捕捉画面,麦克风用于录制声音;又或者在直播功能中,只有获得这两项权限,用户…

docker Ubuntu实战

目录 Ubuntu系统环境说明 一、如何安装docker 二、发布.netcore应用到docker中 Ubuntu系统环境说明 cat /etc/os-release PRETTY_NAME"Ubuntu 22.04.5 LTS" NAME"Ubuntu" VERSION_ID"22.04" VERSION"22.04.5 LTS (Jammy Jellyfish)&quo…

线上突发:MySQL 自增 ID 用完,怎么办?

线上突发:MySQL 自增 ID 用完,怎么办? 1. 问题背景2. 场景复现3. 自增id用完怎么办?4. 总结 1. 问题背景 最近,我们在数据库巡检的时候发现了一个问题:线上的地址表自增主键用的是int类型。随着业务越做越…

Golang之Context详解

引言 之前对context的了解比较浅薄,只知道它是用来传递上下文信息的对象; 对于Context本身的存储、类型认识比较少。 最近又正好在业务代码中发现一种用法:在每个协程中都会复制一份新的局部context对象,想探究下这种写法在性能…

从桌面到前端:效率与渲染优化的技术进化20250122

从桌面到前端:效率与渲染优化的技术进化 在应用开发的广袤天地中,我们见证了从传统桌面开发(如 MFC、PyQt)向现代 Web 前端框架(如 React、Vue)的华丽转变。这一变革犹如一场技术革命,带来了开…

web服务器 网站部署的架构

WEB服务器工作原理 Web web是WWW(World Wide Web)的简称,基本原理是:请求(客户端)与响应(服务器端)原理,由遍布在互联网中的Web服务器和安装了Web浏览器的计算机组成 客户端发出请求的方式:地址栏请求、超链接请求、表单请求 …

快速构建springboot+vue后台管理系统

项目介绍 1.需求定义:外包项目如雨后春笋,开发工期被迫压缩,为了开发人员专注开发项目业务,早点下班能陪老婆、孩子。 2.产品定位: 简约后台管理系统 3.项目特点:此项目代码清晰、界面简洁、springboot layuiadmin 构…

C语言--数据在内存中的存储

数据在内存中的存储 主要研究整型和浮点型在内存中的存储。 1. 整数在内存中的存储 在学习操作符的时候,就了解过了下面的内容: 整数的2进制表示方法有三种,即原码、反码和补码。 有符号的整数,三种表示方法均有符号位和数值…

HTB:Sauna[WriteUP]

目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机…

wireshark工具简介

目录 1 wireshark介绍 2 wireshark抓包流程 2.1 选择网卡 2.2 停止抓包 2.3 保存数据 3 wireshark过滤器设置 3.1 显示过滤器的设置 3.2 抓包过滤器 4 wireshark的封包列表与封包详情 4.1 封包列表 4.2 封包详情 参考文献 1 wireshark介绍 wireshark是非常流行的网络…

2025.1.20——一、[RCTF2015]EasySQL1 二次注入|报错注入|代码审计

题目来源:buuctf [RCTF2015]EasySQL1 目录 一、打开靶机,整理信息 二、解题思路 step 1:初步思路为二次注入,在页面进行操作 step 2:尝试二次注入 step 3:已知双引号类型的字符型注入,构造…

kong 网关和spring cloud gateway网关性能测试对比

该测试只是简单在同一台机器设备对spring cloud gateway网关和kong网关进行对比,受限于笔者所拥有的资源,此处仅做简单评测。 一、使用spring boot 的auth-service作为服务提供者 该服务提供了一个/health接口,接口返回"OK"&…