深度学习进行数据增强(实战篇)

本文章是我在进行深度学习时做的数据增强,接着我们上期的划分测试集和训练集来做.

文章目录

前言

数据增强有什么好处?

一、构造数据增强函数

二、数据增强

总结



前言

很多人在深度学习的时候在对数据的处理时一般采用先数据增强在进行对训练集和测试集的划分,其实我感觉这样做还是有点不好的,其实这里也是分情况的.

1.如果你的数据集很少每个类别很少,我建议先进行数据增强,后进行训练集和测试集的划分还是可以的,但要注意也是因为你的数据集很小,所以有必要对你的模型进行k-折交叉验证.

2.如果你的数据量还可以的话,我建议先划分测试集和训练集,之后再进行数据增强,再训练的时候再对你的增强后的数据再进行训练和验证集的划分,这是因为你先进行数据增强后划分训练集和测试集,很有可能将你的测试集信息透露给了训练集,这样模型感觉很不错,但最后投入到实际就拉跨了,如果你想模型感觉很好的话当我没说,当然可以先数据增强在划分很多人也是这么干的.

3.如果你的数据量非常的大,那基本上也不用数据增强了,这一点还是要根据实际的情况.


数据增强有什么好处?

正确使用数据增强能够带来如下好处:

  1. 降低数据采集和数据标记的成本
  2. 通过赋予模型更多的多样性和灵活性来改进模型泛化
  3. 提高模型在预测中的准确性,因为它使用更多数据来训练模型
  4. 减少数据的过拟合
  5. 通过增加少数类中的样本来处理数据集中的不平衡

一、构造数据增强函数

我们使用opencv对我们的数据进行数据增强

1.添加椒盐噪声

2.高斯噪声

3.昏暗

4.亮度

5.旋转翻转

# -*- coding: utf-8 -*-

import cv2
import numpy as np
import os.path
import copy

# 椒盐噪声
def SaltAndPepper(src,percetage):
    SP_NoiseImg=src.copy()
    SP_NoiseNum=int(percetage*src.shape[0]*src.shape[1])
    for i in range(SP_NoiseNum):
        randR=np.random.randint(0,src.shape[0]-1)
        randG=np.random.randint(0,src.shape[1]-1)
        randB=np.random.randint(0,3)
        if np.random.randint(0,1)==0:
            SP_NoiseImg[randR,randG,randB]=0
        else:
            SP_NoiseImg[randR,randG,randB]=255
    return SP_NoiseImg

# 高斯噪声
def addGaussianNoise(image,percetage):
    G_Noiseimg = image.copy()
    w = image.shape[1]
    h = image.shape[0]
    G_NoiseNum=int(percetage*image.shape[0]*image.shape[1])
    for i in range(G_NoiseNum):
        temp_x = np.random.randint(0,h)
        temp_y = np.random.randint(0,w)
        G_Noiseimg[temp_x][temp_y][np.random.randint(3)] = np.random.randn(1)[0]
    return G_Noiseimg

# 昏暗
def darker(image,percetage=0.9):
    image_copy = image.copy()
    w = image.shape[1]
    h = image.shape[0]
    #get darker
    for xi in range(0,w):
        for xj in range(0,h):
            image_copy[xj,xi,0] = int(image[xj,xi,0]*percetage)
            image_copy[xj,xi,1] = int(image[xj,xi,1]*percetage)
            image_copy[xj,xi,2] = int(image[xj,xi,2]*percetage)
    return image_copy

# 亮度
def brighter(image, percetage=1.5):
    image_copy = image.copy()
    w = image.shape[1]
    h = image.shape[0]
    #get brighter
    for xi in range(0,w):
        for xj in range(0,h):
            image_copy[xj,xi,0] = np.clip(int(image[xj,xi,0]*percetage),a_max=255,a_min=0)
            image_copy[xj,xi,1] = np.clip(int(image[xj,xi,1]*percetage),a_max=255,a_min=0)
            image_copy[xj,xi,2] = np.clip(int(image[xj,xi,2]*percetage),a_max=255,a_min=0)
    return image_copy

# 旋转
def rotate(image, angle, center=None, scale=1.0):
    (h, w) = image.shape[:2]
    # If no rotation center is specified, the center of the image is set as the rotation center
    if center is None:
        center = (w / 2, h / 2)
    m = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, m, (w, h))
    return rotated

# 翻转
def flip(image):
    flipped_image = np.fliplr(image)
    return flipped_image
    

二、数据增强

from PIL import Image, ImageEnhance
import os
import random
import shutil
def augment_image(image_path, save_path):
    img = cv2.imread(image_path)
    image_name = os.path.basename(image_path)  # 获取图片名称
    split_result = image_name.split('.')
    name = split_result[:-1]
    extension = split_result[-1]
    # cv2.imshow("1",img)
    # cv2.waitKey(5000)
    # 旋转
    rotated_90 = rotate(img, 90)
    cv2.imwrite(save_path +  "".join(name) + '_r90.'+ extension, rotated_90)
    rotated_180 = rotate(img, 180)
    cv2.imwrite(save_path +  "".join(name) + '_r180.'+ extension, rotated_180)
    flipped_img = flip(img)
    cv2.imwrite(save_path +  "".join(name) + '_fli.'+ extension, flipped_img)

    # 增加噪声
    # img_salt = SaltAndPepper(img, 0.3)
    # cv2.imwrite(save_path + img_name[0:7] + '_salt.jpg', img_salt)
    img_gauss = addGaussianNoise(img, 0.3)
    cv2.imwrite(save_path +  "".join(name) + '_noise.'+ extension,img_gauss)

    #变亮、变暗
    img_darker = darker(img)
    cv2.imwrite(save_path +  "".join(name) + '_darker.'+ extension, img_darker)
    img_brighter = brighter(img)
    cv2.imwrite(save_path +  "".join(name) + '_brighter.'+ extension, img_brighter)

    blur = cv2.GaussianBlur(img, (7, 7), 1.5)
    #      cv2.GaussianBlur(图像,卷积核,标准差)
    cv2.imwrite(save_path +  "".join(name) + '_blur.'+ extension,blur)

target_num = 2000  # 目标增强图片数量
image_folder = 'D:/plantsdata/data/train/'  # 图片文件夹路径
save_folder = 'D:/plantsdata/data/train_with_augmentation/'  # 保存增强后的图片的文件夹路径
# 获取所有类别的文件夹路径
class_folders = os.listdir(image_folder)

# 遍历类别文件夹
for class_folder in class_folders:
    if not os.path.isdir(os.path.join(image_folder, class_folder)):
         continue
    target_subfolder = os.path.join(save_folder,class_folder)
    os.makedirs(target_subfolder, exist_ok=True)
    image_list = os.listdir(os.path.join(image_folder, class_folder))
    # 获取当前文件夹中所有图片的路径
    images = []
    for file_name in image_list:
        images.append(os.path.join(image_folder, class_folder, file_name))
    num_images = len(images)
    print(num_images)
    print(target_num)
    if num_images < target_num:
        for image_path in images:
            with Image.open(image_path) as img:
                name = os.path.basename(image_path)
                target_path = os.path.join(target_subfolder, name)
                shutil.copy(image_path, target_path)
        i = num_images
        j = 0
        random_image = random.sample(image_list,k=num_images)
        while i<target_num and j<=num_images-1:
            image_path = os.path.join(image_folder, class_folder, random_image[j])
            target_path = target_subfolder + '/'
            augment_image(image_path, target_path)
            i+=7
            j+=1
            print(i)
    else:
        # 随机选择2000张图片
        selected_images = random.sample(images,k=2000)
        # 将选中的图片复制到目标文件夹
        for image_path in selected_images:
            with Image.open(image_path) as img:
                name = os.path.basename(image_path)
                target_path = os.path.join(target_subfolder, name)
                shutil.copy(image_path, target_path)

数据增强将文件夹中的每个类别的文件夹中的图片数据首先复制到目标文件夹,如果大于2000张随机挑选2000张图片复制,不够的画在进行数据增强,目标每个类别2000张,如果类别文件夹的图片数量太小,那就缩小目标数目,或者在找些图片.

三、查看数据分布情况

import os
import matplotlib.pyplot as plt

def count_photos_in_categories(folder_path):
    count_dict = {}

    for root, dirs, files in os.walk(folder_path):
        category = os.path.basename(root)
        count_dict.setdefault(category, 0)
        for file in files:
                count_dict[category] += 1

    return count_dict

folder1 = "D:/plantsdata/data/train/"
folder2 = "D:/plantsdata/data/train_with_augmentation/"

count1 = count_photos_in_categories(folder1)
count2 = count_photos_in_categories(folder2)

categories = sorted(set(count1.keys()).union(set(count2.keys())))

x = range(-1, len(categories)-1)  # 从1开始编号
width = 0.35

fig, ax = plt.subplots(figsize=(12, 6),dpi = 400)
rects1 = ax.bar(x, [count1.get(category, 0) for category in categories], width, label='train_without_augmentation')
rects2 = ax.bar([i + width for i in x], [count2.get(category, 0) for category in categories], width, label='train_with_augmentation')

ax.set_ylabel('Photo Count')
ax.set_xlabel('Category')
ax.set_title('Comparison of Photo Counts in Different Categories')

ax.set_xticks([i + width/2 for i in x])
ax.set_xticklabels(x)
# ax.set_xticklabels(categories)
ax.set_xlim(-0.5, len(categories)-1)
ax.legend()

plt.show()

这样我们就分好啦!


总结

对数据增强的代码,帮助大家从文件夹中对图片进行处理,本文的图片增强的代码就举例几个,大家可以在搜寻图像增强的方法加入函数即可.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/333734.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ORM Bee设计思想与功能思维导图

ORM Bee设计思想与功能思维导图 Bee&#xff0c;互联网新时代的Java ORM框架&#xff0c;支持Sharding&#xff1b;JDBC&#xff0c;Android&#xff0c;HarmonyOS&#xff1b;支持多种关系型数据库&#xff0c;还支持NoSQL的Cassandra&#xff0c;Mongodb等&#xff1b;更快、…

NVIDIA 大模型 RAG 分享笔记

文章目录 大语言模型在垂直领域落地的三个挑战&#xff1a;什么是 RAG以及为什么能解决大预言模型所带来的的这三个问题RAG 不是一项技术而是整体的 Pipeline非参数化 &#xff1a;数据库部分加载到数据库中检索阶段 提升检索效率的技术检索前&#xff1a;对query做处理use que…

redis缓存和本地缓存的应用设计

数据查询顺序 一级缓存&#xff1a;本地缓存 -》二级缓存&#xff1a;redis缓存 -》数据库 本地缓存和分布式缓存 本地缓存&#xff1a;基于jvm, 意思是程序放在哪&#xff0c;数据就存储在哪&#xff0c;不需要网络请求&#xff0c;特别快&#xff0c;但是需要占用jvm的内存…

Redis--Zset使用场景举例(滑动窗口实现限流)

文章目录 前言什么是滑动窗口zset实现滑动窗口小结附录 前言 在Redis–Zset的语法和使用场景举例&#xff08;朋友圈点赞&#xff0c;排行榜&#xff09;一文中&#xff0c;提及了redis数据结构zset的指令语法和一些使用场景&#xff0c;今天我们使用zset来实现滑动窗口限流&a…

Docker 仓库管理

Docker 仓库管理 仓库&#xff08;Repository&#xff09;是集中存放镜像的地方。以下介绍一下 Docker Hub。当然不止 docker hub&#xff0c;只是远程的服务商不一样&#xff0c;操作都是一样的。 Docker Hub 目前 Docker 官方维护了一个公共仓库 Docker Hub。 大部分需求…

Oracle命令大全

文章目录 1. SQL*Plus命令&#xff08;用于连接与管理Oracle数据库&#xff09;2. SQL数据定义语言&#xff08;DDL&#xff09;命令3. SQL数据操作语言&#xff08;DML&#xff09;命令4. PL/SQL程序块5. 系统用户管理6. 数据备份与恢复相关命令1. SQL*Plus命令&#xff08;用…

java-log4j日志冲突解决

一、概述 java日志框架较多&#xff0c;其中主流的slf4j和commons-logging是日志接口&#xff0c;log4j、log4j2和logback是真正的日志实现库。 二、具体库单独使用 2.1 log4j <dependency><groupId>log4j</groupId><artifactId>log4j</artifa…

CentOS stream 9配置网卡

CentOS stream9的网卡和centos 7的配置路径&#xff1a;/etc/sysconfig/network-scripts/ifcfg-ens32不一样。 CentOS stream 9的网卡路径&#xff1a; /etc/NetworkManager/system-connections/ens32.nmconnection 方法一&#xff1a; [connection] idens32 uuid426b60a4-4…

【鸿蒙4.0】详解harmonyos开发语言ArkTS

文章目录 一.什么是ArkTS&#xff1f;1.ArkTS的背景2.了解js&#xff0c;ts&#xff0c;ArkTS的演变js(Javascript)Javascript的简介Javascript的特点 ts(Typescript)ArkTS 二. ArkTS的特点 一.什么是ArkTS&#xff1f; 1.ArkTS的背景 如官方文档所描述&#xff0c;ArkTS是基…

《Linux C编程实战》笔记:Linux信号介绍

信号是一种软件中断&#xff0c;它提供了处理一种异步事件的方法&#xff0c;也是进程惟一的异步通信方式。在Linux系统中&#xff0c;根据POSIX标准扩展的信号机制&#xff0c;不仅可以用来通知某进程发生了什么事&#xff0c;还可以给进程传递数据。 信号的来源 信号的来源…

广东金牌电缆:法大大电子合同助力业务风险管控

广东金牌电缆集团股份有限公司&#xff08;以下简称“广东金牌电缆”&#xff09;成立于2013年&#xff0c;现为广东省电线电缆重点生产企业、广东省守合同重信用单位、国家专精特新小巨人企业、国家高新技术企业&#xff0c;拥有自主商标“夺冠”&#xff0c;“夺冠”商标被评…

一文读懂「Fine-tuning」微调

一、什么是微调&#xff1f; 1. 什么是微调&#xff1f; 微调是指在预训练模型&#xff08;Pre-trained model&#xff09;的基础上&#xff0c;针对特定任务或数据领域&#xff0c;对部分或全部模型参数进行进一步的训练和调整&#xff08;Fine Tune&#xff09;。预训练模型…

File 类的用法和 InputStream, OutputStream 的用法

1.File类的用法 下面就用几个简单的代码案例来熟悉File类里面函数的用法&#xff1a; public class IODemo1 {public static void main(String[] args) throws IOException {File f new File("./test2.txt");//File f new File("C:/User/1/test.txt");S…

redis数据安全(二)数据持久化 RDB

目录 一、RDB快照持久化 原理 二、RDB快照持久化配置&#xff08;redis.conf&#xff09;&#xff1a; 三、触发RDB备份&#xff1a; 1、自动备份&#xff0c;需配置备份规则&#xff1a; 2、手动执行命令备份&#xff08;save | bgsave&#xff09;&#xff1a; 3、flus…

排序:非递归的归并排序

目录 递归与非递归的思想对比&#xff1a; 递归&#xff1a; 非递归&#xff1a; 代码解析&#xff1a; 完整代码&#xff1a; 递归与非递归的思想对比&#xff1a; 递归&#xff1a; 在之前的归并排序&#xff0c;它的核心思想是通过不断的分割&#xff0c;从一个数组变…

JAVA实现向Word模板中插入Base64图片和数据信息

目录 需求一、准备模板文件二、引入Poi-tl、Apache POI依赖三、创建实体类&#xff08;用于保存向Word中写入的数据&#xff09;四、实现Service接口五、Controller层实现 需求 在服务端提前准备好Word模板文件&#xff0c;并在用户请求接口时服务端动态获取图片。数据等信息插…

数据结构图算法

算法就要多练,我在国庆节放假的时间编写了图的算法题,写完让我受益匪浅,希望可以帮助到大家. 文章目录 前言 一、图的数据结构 1.图的邻接表数据结构定义 2.图的邻接矩阵的存储形式 二、邻接表建立图代码 三、邻接表删除边(基本操作考试不考) 四、邻接表删除顶点及销毁整…

【LLM】Prompt微调

Prompt 在机器学习中&#xff0c;Prompt通常指的是一种生成模型的输入方式。生成模型可以接收一个Prompt作为输入&#xff0c;并生成与该输入相对应的输出。Prompt可以是一段文本、一个问题或者一个片段&#xff0c;用于指导生成模型生成相应的响应、续写文本等。 Prompt优化…

.pings勒索病毒的无声侵袭:保护你的数据财产免受.pings的侵害

尊敬的读者&#xff1a; 在数字时代&#xff0c;网络犯罪者不断推陈出新&#xff0c;而.pings勒索病毒则是一种极富威胁的加密型恶意软件。本文将深入探讨.pings勒索病毒的攻击方式&#xff0c;为您提供从数据恢复到全面预防的完整指南&#xff0c;帮助您有效对抗这一威胁。如…

安全帽识别:智能监控新趋势

在现代工业安全领域&#xff0c;安全帽识别技术已成为一项关键的创新。这项技术通过智能监控系统确保工作人员在危险环境中佩戴安全帽&#xff0c;显著提升了工作场所的安全标准。本文将探讨这一技术的工作原理、应用前景及其在现代工业中的重要性。 安全帽识别的工作机制 安全…