CNN实现卫星图像分类(tensorflow)

使用的数据集卫星图像有两类,airplane和lake,每个类别样本量各700张,大小为256*256,RGB三通道彩色卫星影像。搭建深度卷积神经网络,实现卫星影像二分类。
数据链接百度网盘地址,提取码: cq47

1、查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、加载并显示训练数据

从文件夹中获取所有数据路径

import glob
import random

all_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob相比于pathlib更简洁
random.shuffle(all_image_path)

读取并处理图像

def load_and_preprocess_image(path):
    img_raw = tf.io.read_file(path)
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)
    img_tensor = tf.image.resize(img_tensor,[256,256])
    img_tensor = tf.cast(img_tensor,tf.float32)
    img_tensor = img_tensor/255
    return img_tensor

处理标签

label_to_index = {'airplane':0,'lake':1}
index_to_label = dict((v,k) for k,v in label_to_index.items())
labels = [label_to_index.get(img.split('/')[3]) for img in all_image_path]

显示卫星影像

import matplotlib.pyplot as plt

def plot_images_lables(all_image_path,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    images = [load_and_preprocess_image(img_path) for img_path in all_image_path[start_idx:start_idx+5]]
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[i])
        title = 'label=' + index_to_label.get(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

plot_images_lables(all_image_path,labels,0,5)

在这里插入图片描述

4、使用tf.data.Dataset制作训练/测试数据

制作 Dataset

img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_label_ds = tf.data.Dataset.zip((img_ds,label_ds))

训练集、测试集划分

test_count = int(len(labels)*0.2) 
train_count = len(labels) - test_count

train_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)

分批次加载数据

BATCH_SIZE = 16
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.repeat().batch(BATCH_SIZE)

5、CNN模型构建

from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential([
    Input(shape=(256,256,3)),
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),  # 增加filter个数,增加模型拟合能力
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),  # 默认2*2. 池化层扩大视野
    Dropout(0.2),  # 防止过拟合
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    GlobalAvgPool2D(),  # 全局平均池化
    Dense(1024,activation='relu'),
    Dense(256,activation='relu'),
    Dense(1,activation='sigmoid') 
])

model.summary()

在这里插入图片描述

6、模型编译与训练

model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),  # 已经使用sigmoid激活过了
              metrics=['acc'])

steps_per_epoch = train_count//BATCH_SIZE
val_step = test_count//BATCH_SIZE

H = model.fit(train_ds,
             epochs=10,
             steps_per_epoch=steps_per_epoch,
             validation_data=test_ds,
             validation_steps=val_step,
             verbose=1)

在这里插入图片描述

7、模型评估

import matplotlib.pyplot as plt

fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
plt.show()

在这里插入图片描述

8、模型预测

def pred_img(img_path):
    img = load_and_preprocess_image(img_path)
    img = tf.expand_dims(img, axis=0)
    pred = model.predict(img)
    pred = index_to_label.get((pred>0.5).astype('int')[0][0])
    return pred
    
img_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
title = 'label=' + img_path.split('/')[3].strip() + ', pred=' + pred
plt.title(title)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/594180.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

swift微调多模态大语言模型

微调训练数据集指定方式的问题请教 Issue #813 modelscope/swift GitHubQwen1.5微调训练脚本中,我用到了--dataset new_data.jsonl 这个选项, 可以训练成功,但我看文档有提到--custom_train_dataset_path这个选项,这两个有什么…

C语言中字符串输入的3种方式

Ⅰ gets() 函数 gets() 函数的功能是从输入缓冲区中读取一个字符串存储到字符指针变量 str 所指向的内存空间 # include <stdio.h> int main(void) {char a[256] {0};gets(a);printf("%s",a);return 0; }Ⅱ getchar() # include <stdio.h> int mai…

「 网络安全常用术语解读 」通用安全通告框架CSAF详解

1. 简介 通用安全通告框架&#xff08;Common Security Advisory Framework&#xff0c;CSAF&#xff09;通过标准化结构化机器可读安全咨询的创建和分发&#xff0c;支持漏洞管理的自动化。CSAF是OASIS公开的官方标准。开发CSAF的技术委员会包括许多公共和私营部门的技术领导…

VsCode插件 -- Power Mode

一、安装插件 1. 首先在扩展市场里搜索 Power Mode 插件&#xff0c;如下图 二、配置插件 设置 点击小齿轮 打上勾 就可以了 第二种设置方法 1. 安装完成之后&#xff0c;使用快捷键 Ctrl Shift P 打开命令面板&#xff0c;在命令行中输入 settings.json &#xff0c; 选择首…

流畅的python-学习笔记_数据结构

概念 抽象基类&#xff1a;ABC, Abstract Base Class 序列 内置序列类型 分类 可分为容器类型和扁平类型 容器类型有list&#xff0c; tuple&#xff0c; collections.deque等&#xff0c;存储元素类型可不同&#xff0c;存储的元素也是内容的引用而非内容实际占用内存 …

.排序总讲.

在这里赘叙一下我对y总前四节所讲排序的分治思想以及递归的深度理解。 就以788.逆序对 这一题来讲&#xff08;我认为这一题对于分治和递归的思想体现的淋淋尽致&#xff09;。 题目&#xff1a; 给定一个长度为 n&#x1d45b; 的整数数列&#xff0c;请你计算数列中的逆序对…

易语言IDE界面美化支持库

易语言IDE界面美化支持库 下载下来可以看到&#xff0c;是一个压缩包。 那么&#xff0c;怎么安装到易语言中呢&#xff1f; 解压之后&#xff0c;得到这两个文件。 直接将clr和lib丢到易语言安装目录中&#xff0c;这样子就安装完成了。 打开易语言&#xff0c;在支持库配置…

C#-快速剖析文件和流,并使用(持续更新)

目录 一、概述 二、文件系统 1、检查驱动器信息 2、Path 3、文件和文件夹 三、流 1、FileStream 2、StreamWriter与StreamReader 3、BinaryWriter与BinaryReader 一、概述 文件&#xff0c;具有永久存储及特定顺序的字节组成的一个有序、具有名称的集合&#xff1b; …

【系统架构师】-选择题(十三)

1、在某企业的营销管理系统设计阶段&#xff0c;属性"员工"在考勤管理子系统中被称为"员工"&#xff0c;而在档案管理子系统中被称为"职工"&#xff0c;这类冲突称为&#xff08; 命名冲突&#xff09;。 同一个实体在同系统中存在不同的命名&am…

【4087】基于小程序实现的电影票订票小程序软件

作者主页&#xff1a;Java码库 主营内容&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app等设计与开发。 收藏点赞不迷路 关注作者有好处 文末获取源码 技术选型 【后端】&#xff1a;Java 【框架】&#xff1a;ssm 【…

局部性原理和磁盘预读

局部性原理 磁盘预读 \

Linux 基础命令、性能监控

一、Linux 基础命令 grep&#xff1a;在文件中执行关键词搜索&#xff0c;并显示匹配的结果。 -c 仅显示找到的行数 -i 忽略大小写 -n 显示行号 -v 反向选择: 仅列出没有关键词的行 (invert) -r 递归搜索文件目录 -C n 打印匹配行的前后 n 行grep login user.cpp # 在…

编译官方原版的openwrt并加入第三方软件包

最近又重新编译了最新的官方原版openwrt-2305&#xff08;2024.3.22&#xff09;&#xff0c;此处记录一下以待日后参考。 目录 1.源码下载 1.1 通过官网直接下载 1.2 映射github加速下载 1.2.1 使用github账号fork源码 1.2.2 创建gitee账号映射github openwrt 2.编译准…

cordova build android 下载gradle太慢

一、 在使用cordova run android / cordova build android 的时候 gradle在线下载 对于国内的链接地址下载太慢。 等待了很长时间之后还会报错。 默认第一次编译在线下载 gradle-7.6.1-all.zip 然后解压缩到 C:\Users\Administrator\.gradle 文件夹中,下载慢导致失败。 二…

(论文阅读-优化器)A Cost Model for SPARK SQL

目录 Abstract 1 Introduction 2 Related Work 3 Background and Spark Basics 4 Cost Model Basic Bricks 4.1 Cluster Abastraction and Cost Model Parameters 4.2 Read 4.3 Write 4.4 Shuffle Read 4.5 Broadcast 5 Modeling GPSJ Queries 5.1 Statistics and S…

【论文阅读笔记】Order Matters(AAAI 20)

个人博客地址 注&#xff1a;部分内容参考自GPT生成的内容 论文笔记&#xff1a;Order Matters&#xff08;AAAI 20&#xff09; 用于二进制代码相似性检测的语义感知神经网络 论文:《Order Matters: Semantic-Aware Neural Networks for Binary Code Similarity Detection》…

Spring Boot 整合 socket 实现简单聊天

来看一下实现的界面效果 pom.xml的maven依赖 <!-- 引入 socket --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency><!-- 引入 Fastjson &#x…

【CV-CUDA实战】使用Python+TensorRT+CVCUDA优化YOLOv8

目录 什么是CV-CUDA环境准备准备CV-CUDA静态库解压添加至变量将PyBind静态库复制到env下算子设计前处理算子 TensorRT模型加载后处理函数 完整代码输出演示为什么重新写了&#xff1f;结语 什么是CV-CUDA NVIDIA CV-CUDA™ 是一个开源项目&#xff0c;用于构建云规模人工智能 (…

【数据结构(邓俊辉)学习笔记】列表02——无序列表

文章目录 0.概述1.插入与构造1.1 插入1.1.1 前插入1.1.2后插入1.1.3 复杂度 1.2 基于复制构造1.2.1 copyNodes()1.2.2 基于复制构造1.2.3 复杂度 2.删除与析构2.1 删除2.1.1 实现2.1.2 复杂度 2.2 析构2.2.1 释放资源及清除节点2.2.2 复杂度 3.查找3.1 实现3.2 复杂度 4.唯一化…

每天五分钟深度学习:数学中常见函数中的导数

本文重点 导数是微积分学中的一个核心概念,它描述了函数在某一点附近的变化率。在物理学、工程学、经济学等众多领域中,导数都发挥着极其重要的作用。本文旨在详细介绍数学中常见函数的导数,以期为读者提供一个全面而深入的理解。 数学中常见的导数 常数函数的导数 对于常数…