【图像分类实用脚本】数据可视化以及高数量类别截断

图像分类时,如果某个类别或者某些类别的数量远大于其他类别的话,模型在计算的时候,更倾向于拟合数量更多的类别;因此,观察类别数量以及对数据量多的类别进行截断是很有必要的。

1.准备数据

数据的格式为图像分类数据集格式,根目录下分为train和val文件夹,每个文件夹下以类别名命名的子文件夹:

.
├── ./datasets
│ ├── ./datasets/train/A
│ │ ├── ./datasets/train/A/1.jpg
│ │ ├── ./datasets/train/A/2.jpg
│ │ ├── ./datasets/train/A/3.jpg
│ │ ├── …
│ ├── ./datasets/train/B
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── …
│ ├── ./datasets/val/A
│ │ ├── ./datasets/val/A/1.jpg
│ │ ├── ./datasets/val/A/2.jpg
│ │ ├── ./datasets/val/A/3.jpg
│ │ ├── …
│ ├── ./datasets/val/B
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── …

2.查看数据分布

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def count_images(directory, image_extensions):
    """
    统计每个子文件夹中的图像数量。

    :param directory: 主目录路径(train或val)
    :param image_extensions: 允许的图像文件扩展名元组
    :return: 一个字典,键为类别名,值为图像数量
    """
    counts = {}
    if not os.path.exists(directory):
        print(f"目录不存在: {directory}")
        return counts

    for class_name in os.listdir(directory):
        class_path = os.path.join(directory, class_name)
        if os.path.isdir(class_path):
            # 统计符合扩展名的文件数量
            image_count = sum(
                1 for file in os.listdir(class_path)
                if file.lower().endswith(image_extensions)
            )
            counts[class_name] = image_count
    return counts

def count_images_in_single_directory(directory, image_extensions):
    """
    统计单个目录下每个类别的图像数量。

    :param directory: 主目录路径
    :param image_extensions: 允许的图像文件扩展名元组
    :return: 一个字典,键为类别名,值为图像数量
    """
    counts = {}
    if not os.path.exists(directory):
        print(f"目录不存在: {directory}")
        return counts

    for class_name in os.listdir(directory):
        class_path = os.path.join(directory, class_name)
        if os.path.isdir(class_path):
            image_count = sum(
                1 for file in os.listdir(class_path)
                if file.lower().endswith(image_extensions)
            )
            counts[class_name] = image_count
    return counts

def autolabel(ax, rects):
    """
    在每个柱状图上方添加数值标签。

    :param ax: Matplotlib 的轴对象
    :param rects: 柱状图对象
    """
    for rect in rects:
        height = rect.get_height()
        ax.annotate(f'{height}',
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')

def plot_distribution(all_classes, train_values, val_values, output_path, has_val=False):
    """
    绘制并保存训练集和验证集中每个类别的图像数量分布柱状图。
    如果没有验证集数据,则只绘制训练集数据。

    :param all_classes: 所有类别名称列表
    :param train_values: 训练集中每个类别的图像数量列表
    :param val_values: 验证集中每个类别的图像数量列表(如果有的话)
    :param output_path: 保存图表的文件路径
    :param has_val: 是否包含验证集数据
    """
    x = np.arange(len(all_classes))  # 类别位置
    width = 0.35  # 柱状图的宽度

    fig, ax = plt.subplots(figsize=(12, 8))

    if has_val:
        rects1 = ax.bar(x - width/2, train_values, width, label='Train')
        rects2 = ax.bar(x + width/2, val_values, width, label='Validation')
    else:
        rects1 = ax.bar(x, train_values, width, label='Count')

    # 添加一些文本标签
    ax.set_xlabel('Category')
    ax.set_ylabel('Number of Images')
    title = 'Number of Images in Each Category for Train and Validation' if has_val else 'Number of Images in Each Category'
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(all_classes, rotation=45, ha='right')
    ax.legend() if has_val else ax.legend(['Count'])

    # 自动标注柱状图上的数值
    autolabel(ax, rects1)
    if has_val:
        autolabel(ax, rects2)

    fig.tight_layout()

    # 保存图表为图片文件
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"图表已保存到 {output_path}")

def compute_and_display_statistics(counts_dict, dataset_name, save_csv=False):
    """
    计算并展示统计数据,包括总图像数量、类别数量、平均每个类别的图像数量和类别占比。

    :param counts_dict: 类别名称与图像数量的字典
    :param dataset_name: 数据集名称(例如 'Train', 'Validation', 'Dataset')
    :param save_csv: 是否保存统计结果为 CSV 文件
    """
    total_images = sum(counts_dict.values())
    num_classes = len(counts_dict)
    avg_per_class = total_images / num_classes if num_classes > 0 else 0

    # 计算每个类别的占比
    category_proportions = {cls: (count / total_images * 100) if total_images > 0 else 0 
                            for cls, count in counts_dict.items()}

    # 创建 DataFrame
    df = pd.DataFrame({
        '类别名称': list(counts_dict.keys()),
        '图像数量': list(counts_dict.values()),
        '占比 (%)': [f"{prop:.2f}" for prop in category_proportions.values()]
    })

    # 排序 DataFrame 按图像数量降序
    df = df.sort_values(by='图像数量', ascending=False)

    print(f"\n===== {dataset_name} 数据统计 =====")
    print(df.to_string(index=False))

    print(f"总图像数量: {total_images}")
    print(f"类别数量: {num_classes}")
    print(f"平均每个类别的图像数量: {avg_per_class:.2f}")

    # 根据 save_csv 参数决定是否保存为 CSV 文件
    if save_csv:
        # 将数据集名称转换为小写并去除空格,以作为文件名的一部分
        sanitized_name = dataset_name.lower().replace(" ", "_").replace("(", "").replace(")", "")
        csv_filename = f"{sanitized_name}_statistics.csv"
        df.to_csv(csv_filename, index=False, encoding='utf-8-sig')
        print(f"统计表已保存为 {csv_filename}\n")

def main():
    # ================== 配置参数 ==================
    # 设置数据集的根目录路径
    dataset_root = 'datasets/device_cls_merge_manual_with_21w_1218'  # 替换为你的数据集路径

    # 定义train和val目录
    train_dir = os.path.join(dataset_root, 'train')
    val_dir = os.path.join(dataset_root, 'val')

    # 定义允许的图像文件扩展名
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

    # 输出图表的路径
    output_path = 'dataset_distribution.png'  # 你可以更改为你想要的文件名和路径

    # 是否保存统计结果为 CSV 文件(默认不保存)
    SAVE_CSV = False  # 设置为 True 以启用保存 CSV

    # ================== 统计图像数量 ==================
    has_train = os.path.exists(train_dir) and os.path.isdir(train_dir)
    has_val = os.path.exists(val_dir) and os.path.isdir(val_dir)

    if has_train and has_val:
        print("检测到 'train' 和 'val' 目录。统计训练集和验证集中的图像数量...")
        train_counts = count_images(train_dir, image_extensions)
        val_counts = count_images(val_dir, image_extensions)

        # 获取所有类别的名称(确保train和val中的类别一致)
        all_classes = sorted(list(set(train_counts.keys()) | set(val_counts.keys())))

        # 准备绘图数据
        train_values = [train_counts.get(cls, 0) for cls in all_classes]
        val_values = [val_counts.get(cls, 0) for cls in all_classes]

        # ================== 计算并展示统计数据 ==================
        compute_and_display_statistics(train_counts, '训练集 (Train)', save_csv=SAVE_CSV)
        compute_and_display_statistics(val_counts, '验证集 (Validation)', save_csv=SAVE_CSV)

        # ================== 绘制并保存图表 ==================
        print("绘制并保存训练集和验证集的图表...")
        plot_distribution(all_classes, train_values, val_values, output_path, has_val=True)

    else:
        print("未检测到 'train' 和 'val' 目录。将统计主目录下的图像数量...")
        # 如果没有train和val目录,则统计主目录下的图像分布
        main_counts = count_images_in_single_directory(dataset_root, image_extensions)

        # 获取所有类别的名称
        all_classes = sorted(main_counts.keys())

        # 准备绘图数据
        main_values = [main_counts.get(cls, 0) for cls in all_classes]

        # 定义输出图表路径(可以区分不同的输出文件名)
        output_path_single = 'dataset_distribution_single.png'  # 或者使用与train_val相同的output_path

        # ================== 计算并展示统计数据 ==================
        compute_and_display_statistics(main_counts, '数据集 (Dataset)', save_csv=SAVE_CSV)

        # ================== 绘制并保存图表 ==================
        print("绘制并保存主目录的图表...")
        plot_distribution(all_classes, main_values, [], output_path_single, has_val=False)

if __name__ == "__main__":
    main()

下图为原始数据集运行结果,可以看到数据存在严重不均衡问题
在这里插入图片描述

3.数据截断

import os
import shutil
import random


def count_images(directory, image_extensions):
    """
    统计每个子文件夹中的图像文件路径列表。

    :param directory: 主目录路径(train或val)
    :param image_extensions: 允许的图像文件扩展名列表
    :return: 一个字典,键为类别名,值为图像文件路径列表
    """
    counts = {}
    if not os.path.exists(directory):
        print(f"目录不存在: {directory}")
        return counts

    for class_name in os.listdir(directory):
        class_path = os.path.join(directory, class_name)
        if os.path.isdir(class_path):
            # 获取符合扩展名的文件列表
            images = [
                file for file in os.listdir(class_path)
                if file.lower().endswith(tuple(image_extensions))
            ]
            image_paths = [os.path.join(class_path, img) for img in images]
            counts[class_name] = image_paths
    return counts


def truncate_dataset(class_images, threshold, seed=42):
    """
    对每个类别的图像进行截断,如果超过阈值则随机选择一定数量的图像。

    :param class_images: 一个字典,键为类别名,值为图像文件路径列表
    :param threshold: 每个类别的图像数量阈值
    :param seed: 随机种子
    :return: 截断后的类别图像字典
    """
    truncated = {}
    random.seed(seed)
    for class_name, images in class_images.items():
        if len(images) > threshold:
            truncated_images = random.sample(images, threshold)
            truncated[class_name] = truncated_images
            print(f"类别 '{class_name}' 超过阈值 {threshold},已随机选择 {threshold} 张图像。")
        else:
            truncated[class_name] = images
            print(f"类别 '{class_name}' 不超过阈值 {threshold},保留所有 {len(images)} 张图像。")
    return truncated


def copy_images(truncated_data, subset, output_root):
    """
    将截断后的图像复制到输出目录,保持原有的目录结构。

    :param truncated_data: 截断后的类别图像字典
    :param subset: 'train' 或 'val'
    :param output_root: 输出根目录路径
    """
    for class_name, images in truncated_data.items():
        dest_dir = os.path.join(output_root, subset, class_name)
        os.makedirs(dest_dir, exist_ok=True)
        for img_path in images:
            img_name = os.path.basename(img_path)
            dest_path = os.path.join(dest_dir, img_name)
            shutil.copy2(img_path, dest_path)
    print(f"'{subset}' 子集已复制到 {output_root}")


def main():
    """
    主函数,执行数据集截断和复制操作。
    """
    # ================== 配置参数 ==================

    # 原始数据集根目录路径
    input_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224'  # 替换为你的原始数据集路径

    # 截断后数据集的输出根目录路径
    output_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate'  # 替换为你希望保存截断后数据集的路径

    # 训练集每个类别的图像数量阈值
    train_threshold = 2000  # 设置为你需要的训练集阈值

    # 验证集每个类别的图像数量阈值
    val_threshold = 400  # 设置为你需要的验证集阈值

    # 允许的图像文件扩展名
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']

    # 随机种子以确保可重复性
    random_seed = 42

    # ================== 脚本实现 ==================

    # 设置随机种子
    random.seed(random_seed)

    # 定义train和val目录路径
    train_input_dir = os.path.join(input_dir, 'train')
    val_input_dir = os.path.join(input_dir, 'val')

    # 统计train和val中的图像
    print("统计训练集中的图像数量...")
    train_counts = count_images(train_input_dir, image_extensions)
    print("统计验证集中的图像数量...")
    val_counts = count_images(val_input_dir, image_extensions)

    # 截断train和val中的图像
    print("\n截断训练集中的图像...")
    truncated_train = truncate_dataset(train_counts, train_threshold, random_seed)
    print("\n截断验证集中的图像...")
    truncated_val = truncate_dataset(val_counts, val_threshold, random_seed)

    # 复制截断后的图像到输出目录
    print("\n复制截断后的训练集图像...")
    copy_images(truncated_train, 'train', output_dir)
    print("复制截断后的验证集图像...")
    copy_images(truncated_val, 'val', output_dir)

    print("\n数据集截断完成。")


if __name__ == "__main__":
    main()

再次查看已经符合截断后的数据分布了
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/940235.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

飞牛os使用ddns-go配合华为云实现内网穿透

DDNS-Go 是一个开源的动态域名解析工具,它支持多种操作系统,包括 Windows、Mac 和 Linux,并且支持 ARM 和 x86 架构。以下是使用 DDNS-Go 的基本步骤: 1. 下载和安装: 访问 DDNS-Go 的 GitHub 仓库(&#x…

易语言OCR证件照文字识别

一.引言 文字识别,也称为光学字符识别(Optical Character Recognition, OCR),是一种将不同形式的文档(如扫描的纸质文档、PDF文件或数字相机拍摄的图片)中的文字转换成可编辑和可搜索的数据的技术。随着技…

二八(vue2-04)、scoped、data函数、父子通信、props校验、非父子通信(EventBus、provideinject)、v-model进阶

1. 组件的三大组成部分(结构/样式/逻辑) 1.1 scoped 样式冲突 App.vue <template><!-- template 只能有一个根元素 --><div id"app"><BaseOne></BaseOne><BaseTwo></BaseTwo></div> </template><script…

3D工具显微镜的测量范围

一、测量尺寸范围 样品尺寸&#xff1a; 3D工具显微镜通常能够测量各种尺寸和形状的样品&#xff0c;从小至微米级别的微小结构到大至几厘米甚至更大的物体。具体的测量尺寸范围取决于显微镜的载物台大小、镜头焦距以及软件处理能力。测量精度&#xff1a; 3D工具显微镜的测量…

C#—扩展方法

扩展方法 扩展方法是C#中一种特殊的静态方法&#xff0c;它定义在一个静态类中&#xff0c;但是可以像实例方法一样被调用&#xff0c;使得代码看起来更为直观和易于阅读。扩展方法允许你在不修改原始类的情况下&#xff0c;添加新的方法到现有的类型中。 有↓箭头的是扩展方…

vertx idea快速使用

目录 1.官网下载项目 2.修改代码 2.1拷贝代码方式 为了能够快速使用&#xff0c;我另外创建一个新的maven项目&#xff0c;将下载项目的src文件和pom文件拷贝到新建的maven项目。 2.2删除.mvn方式 3.更新配置 4.配置application 5.idea启动项目 1.官网下载项目 从vert…

分布式全文检索引擎ElasticSearch-数据的写入存储底层原理

一、数据写入的核心流程 当向 ES 索引写入数据时&#xff0c;整体流程如下&#xff1a; 1、客户端发送写入请求 客户端向 ES 集群的任意节点&#xff08;称为协调节点&#xff0c;Coordinating Node&#xff09;发送一个写入请求&#xff0c;比如 index&#xff08;插入或更…

android EditText密码自动填充适配

android上的密码&#xff08;其实不仅仅是密码&#xff0c;可以是用户名也可以是邮箱&#xff09;自动填充&#xff0c;是需要考虑适配的。 官方文档&#xff1a;https://developer.android.com/identity/autofill/autofill-optimize?hlzh-cn 什么是自动填充 手机厂商一般会…

【MySQL】非聚簇索引和聚簇索引,索引的创建、查询、删除

目录 存储引擎是MyISAM 非聚簇索引 主键索引&#xff1a; 普通(辅助)索引&#xff1a; 存储引擎是InnoDB 聚簇索引 主键索引&#xff1a; 普通(辅助)索引&#xff1a; 回表查询 创建索引 创建主键索引 主键索引的特点&#xff1a; 创建唯一索引 唯一索引的特点&am…

list的常用操作

list的介绍 list是序列容器&#xff0c;它允许在常数范围O&#xff08;1&#xff09;进行插入和删除在这段序列的任意位置&#xff0c;并且可以双向遍历 它是弥补vector容器的缺点&#xff0c;与vector有互补的韵味&#xff0c; 这里我们可以将其进行与vector进行对比 vect…

C# opencvsharp 流程化-脚本化-(2)ROI

ROI ROI也是经常需要使用的方法。特别是在图像编辑中。ROI又称感兴趣的区域&#xff0c;但是图像是矩阵是矩形的&#xff0c;感兴趣的是乱八七糟的&#xff0c;所以还有一个Mask需要了解一下的。 public class RoiStep : IImageProcessingStep{public ImageProcessingStepType…

极狐GitLab 正式发布安全补丁版本 17.6.2、17.5.4、 17.4.6

本分分享极狐GitLab 补丁版本 17.6.2, 17.5.4, 17.4.6 的详细内容。这几个版本包含重要的缺陷和安全修复代码&#xff0c;我们强烈建议所有私有化部署用户应该立即升级到上述的某一个版本。对于极狐GitLab SaaS&#xff0c;技术团队已经进行了升级&#xff0c;无需用户采取任何…

在window环境下安装openssl生成钥私、证书和签名,nodejs利用express实现ssl的https访问和测试

在生成我们自己的 SSL 证书之前&#xff0c;让我们创建一个简单的 Express应用程序。 要创建一个新的 Express 项目&#xff0c;让我们创建一个名为node-ssl -server 的目录&#xff0c;用终端cmd中进入node-ssl-server目录。 cd node-ssl-server 然后初始化一个新的 npm 项目…

CTFHub ssrf

第一关&#xff08;内网访问&#xff09; 尝试访问位于127.0.0.1的flag.php吧 第二关(伪协议读取文件) 尝试去读取一下Web目录下的flag.php吧 1.首先尝试http://127.0.0.1/flag.php 2.查看页面源代码 3.根据提示输入file:///var/www/html/flag.php 4.查看页面源代码 第三关&…

vue3+vite一个IP对站点名称的前端curd更新-会议系统优化

vue3-tailwind-todo https://github.com/kgrg/vue3-tailwind-todo 基于这个项目,把ip到sta的映射做了前端管理. 核心代码是存储和获得的接口,需要flask提供. def redis2ipdic():global ipdicipdic.clear()tmdiccl.hgetall(IPDIC_KEY)for k in tmdic.keys():ipdic[k.decode() …

idea无法识别文件,如何把floder文件恢复成model

前景&#xff1a; 昨天&#xff0c;我在之前的A1214模块包下新增了一个demo类&#xff0c;然后又新建了一个A1216模块&#xff0c;写了算法题&#xff0c;后面打算用git提交&#xff0c;发现之前的A1214模块下的demo类和新建的模块源文件都已经被追踪了&#xff0c;都是绿色的&…

如何调大unity软件的字体

一、解决的问题&#xff1a; unity软件的字体太小&#xff0c;怎么调大点&#xff1f;二、解决方法&#xff1a; 1.操作步骤&#xff1a; 打开Unity编辑器> Edit>preferences> UI Scaling>Use custom scaling value&#xff08;取消勾选“使用默认桌面设置”&…

冯诺依曼架构与哈佛架构的对比与应用

冯诺依曼架构&#xff08;Von Neumann Architecture&#xff09;&#xff0c;也称为 冯诺依曼模型&#xff0c;是由著名数学家和计算机科学家约翰冯诺依曼&#xff08;John von Neumann&#xff09;在1945年提出的。冯诺依曼架构为现代计算机奠定了基础&#xff0c;几乎所有现代…

实操给桌面机器人加上超拟人音色

前面我们讲了怎么用CSK6大模型开发板做一个桌面机器人充当AI语音助理&#xff0c;近期上线超拟人方案&#xff0c;不仅大模型语音最快可以1秒内回复&#xff0c;还可以让我们的桌面机器人使用超拟人音色、具备声纹识别等能力&#xff0c;本文以csk6大模型开发板为例实操怎么把超…

Fiddle突然抓不到虚拟机的地址

Fiddle不抓虚拟机的地址了 查看是否更换了ip地址,我是因为换了网络 更换正确的ip地址