【深度学习实战】kaggle 自动驾驶的假场景分类

本次分享我在kaggle中参与竞赛的历程,这个版本是我的第一版,使用的是vgg。欢迎大家进行建议和交流。

概述

  • 判断自动驾驶场景是真是假,训练神经网络或使用任何算法来分类驾驶场景的图像是真实的还是虚假的。

  • 图像采用 RGB 格式并以 JPEG 格式压缩。

  • 标签显示 (1) 真实和 (0) 虚假

  • 二元分类

数据集描述

文件
train.csv - 训练集标签
Sample_submission.csv - 正确格式的示例提交文件
Train/- 训练图像
Test/ - 测试图像

模型思路

由于是要进行图像的二分类任务,因此考虑使用迁移学习,将vgg16中的卷积层和卷积层的参数完全迁移过来,不包括顶部的全连接层,自己设计适合该任务的头部结构,然后加以训练,绘制图像查看训练结果。

vgg16简介

VGG16 是由牛津大学视觉几何组(VGG)在2014年提出的卷积神经网络(CNN)。它由16个层组成,其中包含13个卷积层和3个全连接层。其特点是使用3x3的小卷积核和2x2的最大池化层,网络深度较深,有效提取图像特征。VGG16在图像分类任务中表现优异,尤其是在ImageNet挑战中取得了良好成绩。尽管计算量大、参数众多,但它因其简单而高效的结构,仍广泛应用于迁移学习和其他计算机视觉任务中。

源码+解析

  1. 第一步,导入所需的库。
import os
import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.applications.vgg16 import preprocess_input
  1. 加载文件
# 路径和文件
data_file = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/train.csv'
image_test = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Test/'
image_train = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Train/'

# 加载标签数据
df = pd.read_csv(data_file)
df['image_path'] = df['image'].apply(lambda x: os.path.join(image_train, x))

n_classes = df['label'].nunique()

df.head()  # 显示数据的前几行,检查路径和标签

输出

	image	label	image_path
0	1.jpg	editada	/kaggle/input/cidaut-ai-fake-scene-classificat...
1	2.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...
2	3.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...
3	6.jpg	editada	/kaggle/input/cidaut-ai-fake-scene-classificat...
4	8.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...

原始train.csv文件只有前两列,image 和label 列,为了方便读取图像文件,新添加了一列image_path用来记录图像文件的具体路径。

# 初始化空列表 x 用于存储图像
x = []

# 遍历每一行读取图像
for index, row in df.iterrows():
    image_path = row['image_path']  # 获取图像路径
    img = cv2.imread(image_path)  # 使用 cv2 读取图像
    
    if img is not None:
        img_resized = cv2.resize(img, (256, 256))  # 调整图像尺寸为 (256, 256)
        x.append(img_resized)  # 将读取的图像添加到列表 x 中
    else:
        print(f"图像 {row['image_path']} 读取失败")  # 打印失败的路径

# x 列表现在包含了所有读取的图像
print(f"总共有 {len(x)} 张图像被读取")

输出

总共有 720 张图像被读取

通过输出结果,可以看到图像被正确的读取了。并且将图像的大小调整为vgg所能用的256*256的尺寸,存放在变量x中。

  1. 第三步,进行数据处理
# 将图像转换为 NumPy 数组
x = np.array(x)

# 标签映射并进行 one-hot 编码
y = df['label'].map({'real': 1, 'editada': 0})
y = np.array(y)
y = to_categorical(y, num_classes=2)  # 二分类

# 分割训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# 检查转换后的结果
print(f"x_train.shape: {x_train.shape}")
print(f"y_train.shape: {y_train.shape}")
print(f"x_test.shape: {x_test.shape}")
print(f"y_test.shape: {y_test.shape}")

输出

x_train.shape: (576, 256, 256, 3)
y_train.shape: (576, 2)
x_test.shape: (144, 256, 256, 3)
y_test.shape: (144, 2)

这里是为了将原始的图像转换为numpy数组,并且将标签进行独热编码,(对分类的标签一定要进行独热编码,转换为矩阵形式),并且切分数据集。

  1. 第四步,设计模型结构
from tensorflow.keras.regularizers import l2
# 加载预训练的VGG16卷积基(不包括顶部的全连接层)
vgg16_model = VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3))

# 冻结VGG16的卷积层
for layer in vgg16_model.layers:
    layer.trainable = False

# 创建一个新的模型
model_fine_tuning = Sequential()

# 将VGG16的卷积基添加到新模型中
model_fine_tuning.add(vgg16_model)  # 添加VGG16卷积基
model_fine_tuning.add(Flatten())  # 将卷积特征图展平

# 添加新的全连接层并进行正则化
model_fine_tuning.add(Dense(512, activation='relu', kernel_regularizer=l2(0.01)))  # L2正则化
model_fine_tuning.add(Dropout(0.3))  # Dropout层,减少过拟合
model_fine_tuning.add(Dense(256, activation='relu', kernel_regularizer=l2(0.01)))  # 较小的全连接层
model_fine_tuning.add(Dropout(0.3) ) # 再次使用Dropout层

# 输出层
model_fine_tuning.add(Dense(2, activation='softmax'))  # 对于二分类问题,使用softmax

# 查看模型架构
model_fine_tuning.summary()

输出:

Layer (type)Output ShapeParam #
vgg16 (Functional)(None, 8, 8, 512)14,714,688
flatten (Flatten)(None, 32768)0
dense (Dense)(None, 512)16,777,728
dropout (Dropout)(None, 512)0
dense_1 (Dense)(None, 256)131,328
dropout_1 (Dropout)(None, 256)0
dense_2 (Dense)(None, 2)514

这里实现了一个基于预训练VGG16模型的迁移学习框架,用于图像分类任务。首先,加载了预训练的VGG16卷积基(不包括全连接层),并通过设置include_top=False来只使用卷积部分,从而利用其在ImageNet数据集上学到的特征。接着,冻结VGG16的卷积层,即通过将trainable属性设为False,使得这些层在训练过程中不进行更新。接下来,创建了一个新的Sequential模型,并将VGG16的卷积基添加进去,随后使用Flatten层将卷积特征图展平,为全连接层准备输入。为了增加模型的表达能力,添加了两个全连接层,每个层都应用了ReLU激活函数,并使用L2正则化来防止过拟合。为了进一步减少过拟合,模型还在每个全连接层后添加了Dropout层,丢弃30%的神经元。最后,输出层是一个具有两个神经元的全连接层,采用softmax激活函数,用于处理二分类问题。model_fine_tuning.summary()方法输出模型架构,帮助查看各层的结构和参数。通过这种方式,模型能够利用VGG16的预训练卷积基进行特征提取,并通过新添加的全连接层进行分类。

  1. 第五步,编译并训练模型
# 编译模型
model_fine_tuning.compile(loss='binary_crossentropy', 
                          optimizer=Adam(), 
                          metrics=['accuracy'])

datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    preprocessing_function=preprocess_input)  # 使用VGG16的预处理函数

# 对原始图像进行增强,并进行训练
history = model_fine_tuning.fit(datagen.flow(x_train, y_train, batch_size=32),
                                epochs=20,
                                validation_data=(x_test, y_test),
                                callbacks=[ModelCheckpoint('best_model.keras', save_best_only=True),
                                           EarlyStopping(patience=5)])

这里主要完成了对已经构建的模型(model_fine_tuning)的编译与训练过程。

  • 首先,使用compile()方法对模型进行编译,指定损失函数为binary_crossentropy,适用于二分类问题,同时选择Adam优化器,这是一种自适应学习率的优化算法,能够有效提升训练性能。在编译时,还通过metrics=['accuracy']设置了准确率作为评估指标。
  • 接着,创建了一个ImageDataGenerator对象用于数据增强,它包含多种图像变换方式,如旋转、平移、剪切、缩放、水平翻转等,这些操作可以增加数据多样性,减少过拟合,提升模型的泛化能力。
  • 此外,preprocessing_function=preprocess_input使用了VGG16预训练模型的标准预处理函数,确保输入图像的像素范围符合VGG16的训练要求。
  • 随后,通过fit()方法开始训练模型,训练数据通过datagen.flow()进行增强和批量生成,训练将在20个周期(epochs)内进行。在训练过程中,还设置了两个回调函数:ModelCheckpoint,用于保存最好的模型权重文件(best_model.keras),并且只保存验证集上表现最好的模型;
  • EarlyStopping,用于在验证集准确率不再提升时提前停止训练,patience=5表示如果5个周期内没有改进,则停止训练。这样,通过数据增强和回调函数的配合,能够有效提高训练的效果和模型的稳定性。

到这里,整个部分就基本完成了。

  1. 绘制损失和准确率图像
import matplotlib.pyplot as plt

# 获取训练过程中的损失和准确率数据
history_dict = history.history
loss = history_dict['loss']
accuracy = history_dict['accuracy']
val_loss = history_dict['val_loss']
val_accuracy = history_dict['val_accuracy']

# 绘制损失图
plt.figure(figsize=(12, 6))

# 损失图
plt.subplot(1, 2, 1)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# 准确率图
plt.subplot(1, 2, 2)
plt.plot(accuracy, label='Training Accuracy')
plt.plot(val_accuracy, label='Validation Accuracy')
plt.title('Accuracy over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

# 展示图像
plt.tight_layout()
plt.show()

在这里插入图片描述
数据文件已经上传,感兴趣的小伙伴可以下载后自己尝试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/954502.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在Linux上如何让ollama在GPU上运行模型

之前一直在 Mac 上使用 ollama 所以没注意,最近在 Ubuntu 上运行发现一直在 CPU 上跑。我一开始以为是超显存了,因为 Mac 上如果超内存的话,那么就只用 CPU,但是我发现 Llama3.2 3B 只占用 3GB,这远没有超。看了一下命…

学习 Git 的工作原理,而不仅仅是命令

Git 是常用的去中心化源代码存储库。它是由 Linux 创建者 Linus Torvalds 创建的,用于管理 Linux 内核源代码。像 GitHub 这样的整个服务都是基于它的。因此,如果您想在 Linux 世界中进行编程或将 IBM 的 DevOps Services 与 Git 结合使用,那…

【MySQL实战】mysql_exporter+Prometheus+Grafana

要在Prometheus和Grafana中监控MySQL数据库,如下图: 可以使用mysql_exporter。 以下是一些步骤来设置和配置这个监控环境: 1. 安装和配置Prometheus: - 下载和安装Prometheus。 - 在prometheus.yml中配置MySQL通过添加以下内…

适配器模式案例

如果在这样的结构中 我们在Controller中注入,但我们后续需要修改Oss时,比如从minioService改成AliyunService时,需要改动的代码很多。于是我们抽象出一个FileService,让controller只跟fileservice耦合,这样我没只需要在…

AI大模型语音交互方案,ESP32-S3联网通信,设备智能化响应联动

在科技日新月异的当下,人工智能与物联网技术正以前所未有的速度重塑着我们的生活,玩具和潮玩领域也迎来了翻天覆地的变化。 AI语音交互在玩具和潮玩产品中的应用越来越广泛,ESP32-S3凭借其高性能、低功耗、丰富的外设接口和强大的AI能力&…

Android DataBinding 结合 ViewModel的使用

Android DataBinding 结合 ViewModel的使用 一、build.gradle引入对应的依赖 在build.gradle(app模块)里引入依赖,然后Sync Now一下: android {​viewBinding {enabled true}dataBinding {enabled true}} 完整的build.gradle代…

掌握Golang strings包:高效字符串处理指南

掌握Golang strings包:高效字符串处理指南 引言为什么要学习和掌握strings包本教程的目标 基本用法strings包概述导入strings包常用函数列表及简要介绍 字符串创建与基本操作创建字符串字符串连接:Join重复字符串:Repeat修改字符串&#xff1…

论文阅读:Searching for Fast Demosaicking Algorithms

今天介绍一篇有关去马赛克的工作,去马赛克是 ISP 流程里面非常重要的一个模块,可以说是将多姿多彩的大千世界进行色彩还原的重要一步。这篇工作探索的是如何从各种各样的去马赛克算法中,选择最佳的一种。 Abstract 本文提出了一种方法&…

自建RustDesk服务器

RustDesk服务端 下面的截图是我本地的一个服务器做为演示用,你自行的搭建服务需要该服务器有固定的ip地址 1、通过宝塔面板快速安装 2、点击【安装】后会有一个配置信息,默认即可 3、点击【确认】后会自动安装等待安装完成 4、安装完成后点击【打开…

JavaSE学习心得(反射篇)

反射 前言 获取class对象的三种方式 利用反射获取构造方法 利用反射获取成员变量 利用反射获取成员方法 练习 保存信息 跟配置文件结合动态创建 前言 接上期文章:JavaSE学习心得(多线程与网络编程篇) 教程链接:黑马…

工业视觉2-相机选型

工业视觉2-相机选型 一、按芯片类型二、按传感器结构特征三、按扫描方式四、按分辨率大小五、按输出信号六、按输出色彩接口类型 这张图片对工业相机的分类方式进行了总结,具体如下: 一、按芯片类型 CCD相机:采用电荷耦合器件(CC…

信凯科技业绩波动明显:毛利率远弱行业,资产负债率偏高

《港湾商业观察》施子夫 1月8日,深交所官网显示,浙江信凯科技集团股份有限公司(以下简称“信凯科技”)主板IPO提交注册。 自2022年递交上市申请,信凯科技的IPO之路已走过两年光景,尽管提交注册&#xff0…

1.15学习

web ctfhub-网站源码 打开环境,查看源代码无任何作用,但是其提醒就在表面暗示我们用dirsearch进行目录扫描,登录kali的root端,利用终端输入dirsearch -u 网址的命令扫描该网址目录,扫描成功后获得信息,在…

Windows部署NVM并下载多版本Node.js的方法(含删除原有Node的方法)

本文介绍在Windows电脑中,下载、部署NVM(node.js version management)环境,并基于其安装不同版本的Node.js的方法。 在之前的文章Windows系统下载、部署Node.js与npm环境的方法(https://blog.csdn.net/zhebushibiaoshi…

Android Studio历史版本包加载不出来,怎么办?

为什么需要下载历史版本呢? 虽然官网推荐使用最新版本,但是最新版本如果自己碰到问题,根本找不到答案,所以博主这里推荐使用历史版本!!! Android Studio历史版本包加载不出来? 下…

一招解决word嵌入图片显示不全问题

大家在word中插入图片的时候有没有遇到过这个问题,明明已经将图片的格式选为“嵌入式”了,但是图片仍然无法完全显示,这个时候直接拖动图片可能会使文字也乱掉,很难精准定位位置。 这个问题是由于行距设置导致的,行距…

C# (图文教学)在C#的编译工具Visual Studio中使用SQLServer并对数据库中的表进行简单的增删改查--14

目录 一.安装SQLServer 二.在SQLServer中创建一个数据库 1.打开SQL Server Manager Studio(SSMS)连接服务器 2.创建新的数据库 3.创建表 三.Visual Studio 配置 1.创建一个简单的VS项目(本文创建为一个简单的控制台项目) 2.添加数据库连接 四.简单连通代码示例 简单连…

CentOS 7 下 MySQL 5.7 的详细安装与配置

1、安装准备 下载mysql5.7的安装包 https://dev.mysql.com/get/mysql-5.7.29-1.el7.x86_64.rpm-bundle.tar 下载后上传至/home目录下 2、mysql5.7安装 2.1、更新yum并安装依赖 yum update -y sudo yum install -y wget sudo yum install libaio sudo yum install perl su…

HunyuanVideo 文生视频模型实践

HunyuanVideo 文生视频模型实践 flyfish 运行 HunyuanVideo 模型使用文本生成视频的推荐配置(batch size 1): 模型分辨率(height/width/frame)峰值显存HunyuanVideo720px1280px129f60GHunyuanVideo544px960px129f45G 本项目适用于使用 N…

TY1801 反激变换器PWM GaN功率开关

TY1801 是一款针对离线式反激变换器的多模式 PWM GaN 功率开关。TY1801 内置 GaN 功率管,它具备超宽 的 VCC 工作范围,非常适用于 PD 快充等要求宽输出电压的应用场合,系统不需要使用额外的绕组或外围降压电路,节省系统 BOM 成本。TY1801 支持 Burst&…