dataset.py
ultralytics\data\dataset.py
目录
dataset.py
1.所需的库和模块
2.class YOLODataset(BaseDataset):
3.class ClassificationDataset(torchvision.datasets.ImageFolder):
4.def load_dataset_cache_file(path):
5.def save_dataset_cache_file(prefix, path, x):
6.class SemanticDataset(BaseDataset):
1.所需的库和模块
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
import cv2
import numpy as np
import torch
import torchvision
from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
from ultralytics.utils.ops import resample_segments
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
2.class YOLODataset(BaseDataset):
# 这段代码定义了一个名为 YOLODataset 的类,继承自 BaseDataset ,用于处理YOLO模型的数据集。
# 定义了 YOLODataset 类,继承自 BaseDataset 。
class YOLODataset(BaseDataset):
# 用于以 YOLO 格式加载对象检测和/或分割标签的数据集类。
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
# 这段代码是 YOLODataset 类的初始化方法 __init__ ,用于创建类的实例时进行一些基本的设置和检查。
# 定义了 YOLODataset 类的初始化方法。这个方法接受以下参数 :
# 1.*args :任意数量的位置参数,这些参数会被传递给父类的初始化方法。
# 2.data :一个可选的关键字参数,用于传入数据集的相关信息,默认值为 None 。
# 3.task :一个可选的关键字参数,用于指定任务类型,默认值为 "detect" ,表示目标检测任务。
# 4.**kwargs :任意数量的关键字参数,这些参数也会被传递给父类的初始化方法。
def __init__(self, *args, data=None, task="detect", **kwargs):
# 使用可选的片段和关键点配置初始化 YOLODataset。
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
# 检查任务类型是否为 "segment" 。如果是,则将 self.use_segments 设置为 True ,表示 这个数据集将用于分割任务 。
self.use_segments = task == "segment"
# 检查任务类型是否为 "pose" 。如果是,则将 self.use_keypoints 设置为 True ,表示 这个数据集将用于姿态估计任务 。
self.use_keypoints = task == "pose"
# 检查任务类型是否为 "obb" 。如果是,则将 self.use_obb 设置为 True ,表示 这个数据集将用于定向边界框(Oriented Bounding Box)任务 。
self.use_obb = task == "obb"
# 将传入的 data 参数赋值给实例变量 self.data ,这个变量用于 存储数据集的相关信息 。
self.data = data
# 这行代码是一个断言,用于确保不会同时启用分割和关键点任务。如果 self.use_segments 和 self.use_keypoints 同时为 True ,则会抛出一个异常,提示不能同时使用分割和关键点。
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." # 不能同时使用段和关键点。
# 调用父类 BaseDataset 的初始化方法,并将位置参数 *args 和关键字参数 **kwargs 传递给它。这是继承机制中的一个常见做法,用于确保父类的初始化逻辑也被执行。
super().__init__(*args, **kwargs)
# 这个初始化方法的主要作用是根据传入的任务类型设置相应的标志变量,这些变量用于后续的数据处理和模型训练过程中,以确定如何处理数据(例如,是否需要处理分割、关键点或定向边界框)。同时,它还确保了数据集的配置不会同时启用不兼容的任务类型(如分割和关键点)。最后,通过调用父类的初始化方法,确保了 YOLODataset 类能够继承 BaseDataset 类的属性和方法。
# 这段代码定义了 YOLODataset 类中的 cache_labels 方法,其主要功能是缓存数据集的标签信息,以便后续快速加载和使用。
# 定义了 cache_labels 方法,该方法接受一个可选参数。
# path :默认值为当前目录下的 labels.cache 文件,用于指定缓存文件的路径。
def cache_labels(self, path=Path("./labels.cache")):
# 缓存数据集标签,检查图像并读取形状。
"""
Cache dataset labels, check images and read shapes.
Args:
path (Path): Path where to save the cache file. Default is Path('./labels.cache').
Returns:
(dict): labels.
"""
# 初始化一个字典 x ,其中包含一个空列表 "labels" ,用于 存储每个图像的标签信息 。
x = {"labels": []}
# 初始化计数器和消息列表。
# nm :缺失的图像数量。
# nf :找到的图像数量。
# ne :为空的图像数量。
# nc :损坏的图像数量。
# msgs :存储警告和错误消息的列表。
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
# 生成描述字符串,用于进度条显示,说明正在扫描的目录。
desc = f"{self.prefix}Scanning {path.parent / path.stem}..." # {self.prefix} 正在扫描 {path.parent / path.stem}...
# 获取 图像文件的总数 。
total = len(self.im_files)
# 从数据集配置中获取 关键点的形状信息 ,如果配置中没有 "kpt_shape" ,则默认为 (0, 0) 。
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
# 如果使用关键点且关键点形状信息不正确(关键点数量必须大于0,维度必须为2或3),则抛出 ValueError 异常。
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " # data.yaml 中的“kpt_shape”缺失或不正确。
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" # 应为包含 [关键点数量、维度数量(x、y 为 2,x、y、visible 为 3)] 的列表,即“kpt_shape:[17, 3]”。
)
# 使用线程池,线程数为 NUM_THREADS ,用于 并行处理图像和标签的验证 。
with ThreadPool(NUM_THREADS) as pool:
# 使用线程池的 imap 方法,将 verify_image_label 函数应用于每个图像文件和标签文件的组合。 verify_image_label 函数用于验证图像和标签的有效性,并返回相关信息。
results = pool.imap(
# def verify_image_label(args):
# -> 验证一个图像-标签对的有效性。返回验证结果,包括 图像文件路径 、 标签数组 、 图像尺寸 、 分割信息 、 关键点信息 、 缺失 、 找到 、 为空 、 损坏的计数 以及 消息字符串 。
# -> return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg / return [None, None, None, None, None, nm, nf, ne, nc, msg]
func=verify_image_label,
iterable=zip(
self.im_files,
self.label_files,
repeat(self.prefix),
repeat(self.use_keypoints),
repeat(len(self.data["names"])),
repeat(nkpt),
repeat(ndim),
),
)
# 创建一个进度条,用于显示处理进度。
pbar = TQDM(results, desc=desc, total=total)
# 遍历进度条的结果,每个结果包含 图像文件路径 、 标签 、 图像形状 、 分割信息 、 关键点信息 、 缺失 、 找到 、 为空 、 损坏的计数 以及 消息 。
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
# 累加 缺失 、 找到 、 为空 、 损坏的图像数量。
nm += nm_f
nf += nf_f
ne += ne_f
nc += nc_f
# 如果图像文件存在,则将标签信息添加到 x["labels"] 列表中。标签信息包括图像文件路径、图像形状、类别、边界框、分割信息、关键点信息等。
if im_file:
x["labels"].append(
dict(
im_file=im_file,
shape=shape,
cls=lb[:, 0:1], # n, 1
bboxes=lb[:, 1:], # n, 4
segments=segments,
keypoints=keypoint,
normalized=True,
bbox_format="xywh",
)
)
# 如果存在消息,则将其添加到消息列表中。
if msg:
msgs.append(msg)
# 更新进度条的描述信息,显示 找到的图像数量 、 背景图像数量 和 损坏的图像数量 。
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" # {desc} {nf} 图像,{nm + ne} 背景,{nc} 损坏。
# 关闭进度条。
pbar.close()
# 如果消息列表不为空,则记录这些消息。
if msgs:
LOGGER.info("\n".join(msgs))
# 如果未找到任何标签,则记录警告信息。
if nf == 0:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}") # {self.prefix}警告 ⚠️ 在 {path} 中未找到标签。{HELP_URL}。
# 计算并存储标签文件和图像文件的哈希值,用于后续验证缓存的有效性。
# def get_hash(paths): -> 计算一个路径列表(可以是文件或目录)的单个哈希值。返回哈希对象的十六进制摘要,即路径列表的哈希值。该哈希值可以用于验证路径列表的内容是否发生变化。 -> return h.hexdigest() # return hash
x["hash"] = get_hash(self.label_files + self.im_files)
# 存储处理结果,包括 找到的图像数量 、 缺失的图像数量 、 为空的图像数量 、 损坏的图像数量 和 图像文件总数 。
x["results"] = nf, nm, ne, nc, len(self.im_files)
# 存储消息列表。
x["msgs"] = msgs # warnings
# 将缓存信息保存到指定路径的文件中。
# def save_dataset_cache_file(prefix, path, x): -> 将一个Ultralytics数据集的缓存字典 x 保存到指定路径 path 。
save_dataset_cache_file(self.prefix, path, x)
# 返回 缓存信息字典 x 。
return x
# cache_labels 方法的主要目的是并行验证数据集中的图像和标签文件,收集标签信息,并将这些信息缓存到文件中,以便后续快速加载。该方法还处理了关键点形状信息的验证,确保数据集配置的正确性。通过进度条和日志记录,提供了处理过程的可视化和错误信息的记录。
# 这段代码定义了 YOLODataset 类中的 get_labels 方法,其主要功能是获取并处理数据集的标签信息。
# 定义了 get_labels 方法,该方法没有接受额外的参数,用于获取和处理数据集的标签信息。
def get_labels(self):
# 返回 YOLO 训练的标签字典。
"""Returns dictionary of labels for YOLO training."""
# 将图像文件路径转换为 对应的标签文件路径 ,存储在 self.label_files 中。 img2label_paths 函数是一个自定义函数,用于根据图像文件路径生成对应的标签文件路径。
# def img2label_paths(img_paths): -> 将图像文件路径转换为对应的标签文件路径。使用列表推导式,将每个 图像文件路径 转换为 对应的标签文件路径 。 -> return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
self.label_files = img2label_paths(self.im_files)
# 生成 缓存文件的路径 ,缓存文件存储在第一个标签文件的父目录下,文件扩展名为 .cache 。
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
# 这段代码是 get_labels 方法中的一部分,主要功能是尝试加载缓存文件并验证其有效性和一致性。
# 开始一个 try 块,用于捕获可能发生的异常。
try:
# 尝试加载指定路径 cache_path 的 缓存文件 ,并将加载结果赋值给变量 cache 。同时,将 exists 变量设置为 True ,表示 缓存文件存在 。
# def load_dataset_cache_file(path): -> 加载一个缓存文件,该文件通常是一个包含数据集信息的字典。返回加载的缓存字典。 -> return cache
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
# 断言缓存文件中的版本号 cache["version"] 与当前代码中定义的版本号 DATASET_CACHE_VERSION 相匹配。如果版本号不匹配,将抛出 AssertionError 异常。
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
# 断言缓存文件中的哈希值 cache["hash"] 与当前图像文件和标签文件的哈希值相匹配。 get_hash 函数计算当前文件列表的哈希值。如果哈希值不匹配,将抛出 AssertionError 异常。
# def get_hash(paths): -> 计算一个路径列表(可以是文件或目录)的单个哈希值。返回哈希对象的十六进制摘要,即路径列表的哈希值。该哈希值可以用于验证路径列表的内容是否发生变化。 -> return h.hexdigest() # return hash
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
# 捕获以下三种异常。 FileNotFoundError :缓存文件不存在。 AssertionError :版本号或哈希值不匹配。 AttributeError :缓存文件中缺少必要的键(如 "version" 或 "hash" )。
except (FileNotFoundError, AssertionError, AttributeError):
# 如果捕获到上述任何异常,调用 cache_labels 方法生成新的缓存文件,并将 exists 变量设置为 False ,表示缓存文件不存在或需要重新生成。
cache, exists = self.cache_labels(cache_path), False # run cache ops
# 这段代码的目的是确保缓存文件存在且有效。如果缓存文件不存在或其内容与当前数据集不匹配,则重新生成缓存文件。通过版本号和哈希值的验证,确保缓存文件与当前数据集的一致性,避免因缓存文件过时或损坏而导致的问题。使用异常处理机制,使得代码在缓存文件不存在或验证失败时能够优雅地处理,而不是直接崩溃。
# 这段代码用于显示缓存文件中的统计信息和相关消息。
# Display cache
# 从缓存字典 cache 中提取并移除 "results" 键对应的值,该值是一个包含五个统计信息的元组。
# nf :找到的图像数量。
# nm :缺失的图像数量。
# ne :为空的图像数量。
# nc :损坏的图像数量。
# n :总图像数量。
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
# 检查缓存文件是否存在,并且当前进程的排名( LOCAL_RANK )为 -1 或 0 。在分布式训练中, LOCAL_RANK 为 0 通常表示主进程, -1 通常表示单进程环境。这行代码确保只有主进程或单进程环境会执行后续的显示操作,避免多个进程重复显示相同的信息。
if exists and LOCAL_RANK in (-1, 0):
# 生成描述字符串 d ,用于显示缓存文件的扫描结果,包括缓存文件路径、找到的图像数量、背景图像数量(缺失和为空的图像总数)以及损坏的图像数量。
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" # 扫描 {cache_path}...{nf} 幅图像、{nm + ne} 幅背景、{nc} 幅损坏图像。
# 使用 TQDM 进度条库显示扫描结果。这里 TQDM 的用法稍有不同,通常 TQDM 用于显示进度,但在这里它被用来显示一个静态的消息,因为 total 和 initial 参数相等,表示进度条已经完成。 self.prefix 是类实例的一个属性,用于添加前缀信息,如数据集名称等。
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
# 如果缓存字典 cache 中存在 "msgs" 键,并且该键对应的值(消息列表)不为空,则将这些消息合并为一个字符串,并使用日志记录器 LOGGER 记录这些信息。这些消息通常是警告或错误信息,用于提示用户数据集中的潜在问题。
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# 这段代码的主要目的是从缓存文件中提取统计信息,并在主进程或单进程环境中显示这些信息和相关消息。通过 TQDM 进度条库,以一种用户友好的方式显示扫描结果。使用日志记录器记录缓存文件中的消息,帮助用户了解数据集的状况。
# 这段代码用于从缓存中读取标签信息,并进行一些基本的验证和更新操作。
# Read cache
# 从缓存字典 cache 中移除 "hash" 、 "version" 和 "msgs" 这三个键及其对应的值。这些键在之前的步骤中已经使用过,移除它们是为了简化后续的处理。
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
# 从缓存字典 cache 中提取 "labels" 键对应的值,该值是一个 包含所有图像标签信息的列表 ,存储在变量 labels 中。
labels = cache["labels"]
# 检查 labels 列表是否为空。
if not labels:
# 如果为空,表示没有找到任何有效的图像标签信息,记录一个警告信息。警告信息中包含缓存文件的路径 cache_path ,提示用户训练可能无法正常进行,并提供一个帮助链接 HELP_URL 。
LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}") # 警告 ⚠️ 在 {cache_path} 中未找到图像,训练可能无法正常工作。{HELP_URL}。
# 更新 self.im_files 列表,使其包含所有有效标签的图像文件路径。这是通过列表推导式实现的,遍历 labels 列表,提取每个标签字典中的 "im_file" 键对应的值。
self.im_files = [lb["im_file"] for lb in labels] # update im_files
# 这段代码的主要目的是从缓存中读取标签信息,并进行基本的验证。如果没有找到任何有效的图像标签信息,会记录一个警告信息,提示用户可能存在的问题。更新 self.im_files 列表,确保它只包含有有效标签的图像文件路径,为后续的数据加载和处理提供准备。
# 这段代码用于检查数据集中的标签信息,确保数据集是纯边界框(boxes)或纯分割(segments)数据集。
# Check if the dataset is all boxes or all segments
# 生成一个生成器表达式 lengths ,遍历 labels 列表中的 每个标签字典 lb ,计算 每个标签中的类别数量 len(lb["cls"]) 、 边界框数量 len(lb["bboxes"]) 和 分割数量 len(lb["segments"]) 。
lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
# 使用 zip(*lengths) 将生成器表达式中的元组解包,然后计算每个位置上的总和,分别得到 类别总数 len_cls 、 边界框总数 len_boxes 和 分割总数 len_segments 。
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
# 检查 分割总数 len_segments 是否为非零且与边界框总数 len_boxes 不相等。
if len_segments and len_boxes != len_segments:
# 如果条件成立,记录一个警告信息,提示用户边界框和分割的数量应该相等,但实际不相等。
LOGGER.warning(
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, " # 警告 ⚠️ 框和段数应该相等,但得到的 len(segments) = {len_segments},
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " # len(boxes) = {len_boxes}。为了解决这个问题,将只使用边界框,所有段都将被删除。
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset." # 为了避免这种情况,请提供检测或片段数据集,而不是检测-片段混合数据集。
)
# 为了处理这种情况,将所有标签中的分割信息清空(即设置为一个空列表)。
for lb in labels:
lb["segments"] = []
# 检查 类别总数 len_cls 是否为零。
if len_cls == 0:
# 如果为零,表示没有找到任何有效的标签信息,记录一个警告信息,提示用户缓存文件中没有找到任何标签,训练可能无法正常进行,并提供一个帮助链接 HELP_URL 。
LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}") # 警告 ⚠️ 在 {cache_path} 中未找到标签,训练可能无法正常工作。{HELP_URL}。
# 返回 处理后的标签列表 labels 。
return labels
# 这段代码的主要目的是确保数据集中的标签信息是一致的,即数据集应该是纯边界框或纯分割数据集。如果发现数据集中同时存在边界框和分割信息,且数量不相等,会记录警告信息并清空所有分割信息。如果没有找到任何有效的标签信息,会记录警告信息,提示用户可能存在的问题。最后,返回处理后的标签列表,为后续的数据加载和处理提供准备。
# get_labels 方法的主要目的是加载或生成缓存文件,读取缓存中的标签信息,并进行一些基本的验证和处理。该方法确保了数据集的标签信息是完整和一致的,为后续的数据加载和模型训练提供了准备。通过日志记录和进度条显示,提供了处理过程的可视化和错误信息的记录。
# 这段代码定义了 YOLODataset 类中的 build_transforms 方法,其主要功能是构建数据增强和格式化转换的流程。
# 定义了 build_transforms 方法,该方法接受一个可选参数。
# 1.hyp :通常是一个包含超参数的字典,用于控制数据增强的行为。
def build_transforms(self, hyp=None):
# 构建转换并将其附加到列表。
"""Builds and appends transforms to the list."""
# 检查 是否启用数据增强 。如果 self.augment 为 True ,则进行数据增强相关的设置。
if self.augment:
# 如果 启用数据增强 且 不是矩形训练 ( self.rect 为 False ),则使用 hyp 中的 mosaic 和 mixup 值,否则将这些值设置为 0.0 。 mosaic 和 mixup 是两种常见的数据增强技术,用于提高模型的泛化能力。
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
# 调用 v8_transforms 函数,传入当前实例、图像尺寸 self.imgsz 和超参数 hyp ,生成 数据增强转换流程 。 v8_transforms是一个自定义函数,用于根据提供的参数生成一系列数据增强操作。
# def v8_transforms(dataset, imgsz, hyp, stretch=False):
# -> 用于为YOLOv8训练准备图像数据增强流程。返回最终的数据增强流程,使用 Compose 将多个变换组合在一起。
# -> return Compose([pre_transform, MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), Albumentations(p=1.0), RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), RandomFlip(direction="vertical", p=hyp.flipud), RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx), ]) # transforms
transforms = v8_transforms(self, self.imgsz, hyp)
# 如果未启用数据增强,则使用基本的转换流程。
else:
# 使用 Compose 函数组合一个 LetterBox 转换,将图像调整为指定的尺寸 self.imgsz ,且不进行缩放放大。 LetterBox 转换用于保持图像的宽高比,通过在图像周围添加填充(padding)来调整尺寸。
# class Compose:
# -> 用于组合多个图像变换操作。
# -> def __init__(self, transforms):
# class LetterBox:
# -> 用于将图像调整为指定的尺寸,同时保持图像的纵横比。这种变换常用于数据预处理,特别是在图像分类和目标检测任务中。
# -> def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
# 向转换流程中添加一个 Format 转换,用于格式化数据。 Format 转换的参数包括。
transforms.append(
# class Format:
# -> 用于对图像及其标签进行格式化处理,以便于后续的训练和推理。这个类提供了多种选项,可以处理边界框、分割掩码、关键点和定向边界框(OBB)。
# -> def __init__(self, bbox_format="xywh", normalize=True, return_mask=False, return_keypoint=False, return_obb=False, mask_ratio=4, mask_overlap=True, batch_idx=True, bgr=0.0,):
Format(
# 边界框格式为 (x, y, width, height) 。
bbox_format="xywh",
# 对图像进行归一化处理。
normalize=True,
# 如果启用分割,则返回掩码。
return_mask=self.use_segments,
# 如果启用关键点,则返回关键点信息。
return_keypoint=self.use_keypoints,
# 如果启用定向边界框,则返回定向边界框信息。
return_obb=self.use_obb,
# 返回批次索引。
batch_idx=True,
# 掩码比例。
mask_ratio=hyp.mask_ratio,
# 掩码重叠。
mask_overlap=hyp.overlap_mask,
# 如果启用数据增强,则使用 hyp.bgr 值,否则设置为 0.0 。这通常用于控制图像的色彩空间转换。
bgr=hyp.bgr if self.augment else 0.0, # only affect training.
)
)
# 返回构建好的 转换流程 。
return transforms
# build_transforms 方法的主要目的是根据是否启用数据增强以及超参数 hyp 的设置,构建一个数据处理流程,包括数据增强和格式化转换。该方法确保了数据在训练前被正确处理,以适应模型的输入要求。通过条件判断和参数设置,方法灵活地支持不同的训练配置,如启用/禁用数据增强、矩形训练等。
# 这段代码定义了 YOLODataset 类中的 close_mosaic 方法,其主要功能是关闭mosaic数据增强,并更新数据增强流程。
# 定义了 close_mosaic 方法,该方法接受一个参数。
# 1.hyp :通常是一个包含超参数的字典,用于控制数据增强的行为。
def close_mosaic(self, hyp):
# 将马赛克、复制粘贴和混合选项设置为 0.0 并构建转换。
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
# 将 hyp 字典中的 mosaic 值设置为 0.0 ,表示关闭mosaic数据增强。Mosaic是一种数据增强技术,通过将多张图像拼接在一起形成一张新的图像,增加模型对不同图像组合的泛化能力。
hyp.mosaic = 0.0 # set mosaic ratio=0.0
# 将 hyp 字典中的 copy_paste 值设置为 0.0 ,表示关闭copy-paste数据增强。Copy-paste是一种数据增强技术,通过将一个图像中的对象复制并粘贴到另一个图像中,增加模型对不同对象组合的泛化能力。这里设置为 0.0 是为了保持与之前版本关闭mosaic时相同的行为。
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
# 将 hyp 字典中的 mixup 值设置为 0.0 ,表示关闭mixup数据增强。Mixup是一种数据增强技术,通过将两张图像及其标签按一定比例混合,增加模型对不同图像混合的泛化能力。这里设置为 0.0 也是为了保持与之前版本关闭mosaic时相同的行为。
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
# 调用 build_transforms 方法,传入更新后的 hyp 字典,重新构建 数据增强和格式化转换流程 ,并将结果赋值给 self.transforms 。这一步确保了数据增强流程中不再包含mosaic、copy-paste和mixup等操作。
self.transforms = self.build_transforms(hyp)
# close_mosaic 方法的主要目的是关闭mosaic数据增强,并更新数据增强流程,以确保在后续的数据处理中不再使用这些增强技术。通过设置 hyp 字典中的相关值为 0.0 ,方法确保了mosaic、copy-paste和mixup数据增强被关闭。最后,通过调用 build_transforms 方法,更新了数据增强流程,确保新的设置生效。
# 这段代码定义了 YOLODataset 类中的 update_labels_info 方法,其主要功能是更新标签信息,特别是处理分割信息和定向边界框(OBB)。
# 定义了 update_labels_info 方法,该方法接受一个参数。
# 1.label :这是一个字典,包含图像的标签信息。
def update_labels_info(self, label):
# 在此处自定义您的标签格式。
# 注意:
# cls 现在不包含 bboxes,分类和语义分割需要独立的 cls 标签。
# 还可以通过添加或删除字典键来支持分类和语义分割。
"""
Custom your label format here.
Note:
cls is not with bboxes now, classification and semantic segmentation need an independent cls label
Can also support classification and semantic segmentation by adding or removing dict keys there.
"""
# 从 label 字典中移除并获取 "bboxes" 键对应的值,即 边界框信息 ,存储在变量 bboxes 中。
bboxes = label.pop("bboxes")
# 从 label 字典中移除并获取 "segments" 键对应的值,即 分割信息 。如果该键不存在,则默认为空列表 [] ,存储在变量 segments 中。
segments = label.pop("segments", [])
# 从 label 字典中移除并获取 "keypoints" 键对应的值,即 关键点信息 。如果该键不存在,则默认为 None ,存储在变量 keypoints 中。
keypoints = label.pop("keypoints", None)
# 从 label 字典中移除并获取 "bbox_format" 键对应的值,即 边界框的格式 ,存储在变量 bbox_format 中。
bbox_format = label.pop("bbox_format")
# 从 label 字典中移除并获取 "normalized" 键对应的值,即 边界框是否归一化 ,存储在变量 normalized 中。
normalized = label.pop("normalized")
# NOTE: do NOT resample oriented boxes 注意:不要重新采样定向框。
# 根据是否使用定向边界框(OBB),设置分割信息的重采样数量。如果使用OBB,则 重采样数量 为100;否则为1000。这里特别注明不重采样定向边界框。
segment_resamples = 100 if self.use_obb else 1000
# 如果 分割信息不为空 ,则对每个分割进行重采样,使其具有相同数量的点( segment_resamples )。使用 resample_segments 函数进行重采样,并将结果堆叠成一个三维数组。
if len(segments) > 0:
# list[np.array(1000, 2)] * num_samples
# (N, 1000, 2)
# def resample_segments(segments, n=1000): -> 用于对多边形线段(segments)进行重采样,使其具有指定数量的点(n=1000)。返回重采样后的线段列表。 -> return segments
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
# 如果分割信息为空,则创建一个形状为 (0, segment_resamples, 2) 的零数组。
else:
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
# 创建一个 Instances 对象,将边界框、分割信息、关键点信息、边界框格式和归一化信息封装在一起,并将该对象存储在 label 字典的 "instances" 键中。 Instances 是一个自定义类,用于表示图像中的实例信息。
# class Instances:
# -> 用于表示和操作图像实例,包括边界框、分割掩码和关键点等信息。
# -> def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
# 返回更新后的 label 字典。
return label
# update_labels_info 方法的主要目的是更新和标准化标签信息,特别是处理分割信息和定向边界框。该方法从输入的标签字典中提取必要的信息,进行重采样处理(如果需要),并封装成一个 Instances 对象,以便后续处理。通过这种方式,方法确保了标签信息的一致性和标准化,为模型训练提供了准备。
# 在 YOLODataset 类中,并没有直接对定向框(OBB)进行重采样的流程。定向框通常是指具有方向信息的边界框,用于表示旋转或倾斜的对象。在目标检测任务中,定向框可以提供更精确的对象定位,尤其是在处理具有明显方向性的对象(如飞机、船只等)时。
# 为什么不对定向框进行重采样?
# 定向框的重采样通常不是必要的,原因如下 :
# 保持方向信息 :定向框的一个关键特性是它们的方向信息。重采样可能会改变这些方向信息,从而影响模型对对象方向的准确预测。例如,如果一个定向框表示一个倾斜的飞机,重采样可能会使飞机的方向变得不准确。
# 数据一致性 :在数据增强和预处理过程中,保持定向框的一致性非常重要。如果对定向框进行重采样,可能会导致边界框和分割信息之间的不一致,从而影响模型的训练效果。
# 计算复杂性 :重采样定向框可能会增加计算复杂性,尤其是在处理大量数据时。保持定向框的原始信息可以简化数据处理流程,提高数据加载和预处理的效率。
# 重采样分割信 :
# 息在 update_labels_info 方法中,对分割信息进行了重采样,而不是定向框。
# 分割信息的重采样是为了确保所有分割具有相同数量的点,这有助于模型在处理不同形状和大小的分割时保持一致性。具体步骤如下 :
# 确定重采样数量 :根据是否使用定向边界框(OBB),设置分割信息的重采样数量。如果使用OBB,则重采样数量为100;否则为1000。
# 重采样分割信息 :如果分割信息不为空,则对每个分割进行重采样,使其具有相同数量的点。使用 resample_segments 函数进行重采样,并将结果堆叠成一个三维数组。
# 处理空分割 :如果分割信息为空,则创建一个形状为 (0, segment_resamples, 2) 的零数组。
# 这段代码定义了 YOLODataset 类中的一个静态方法 collate_fn ,其主要功能是将一个批次(batch)中的多个样本(sample)合并成一个批次数据,以便模型可以一次性处理这些数据。
@staticmethod
# 定义了一个静态方法 collate_fn ,该方法接受一个参数。
# 1.batch :一个列表,其中每个元素是一个字典,表示一个样本的数据。
def collate_fn(batch):
# 将数据样本整理成批次。
"""Collates data samples into batches."""
# 初始化一个空字典 new_batch ,用于 存储合并后的批次数据 。
new_batch = {}
# 获取第一个样本字典的键(keys),假设所有样本字典的键是相同的。这些键可能包括 "img" (图像)、 "masks" (掩码)、 "keypoints" (关键点)、 "bboxes" (边界框)、 "cls" (类别)、 "segments" (分割)、 "obb" (定向边界框)等。
keys = batch[0].keys()
# 将 每个样本字典 的值(values)提取出来,并使用 zip 函数将相同键的值组合在一起,形成一个列表的列表。这样, values 中的每个元素是一个元组,包含 所有样本中对应键的值 。
values = list(zip(*[list(b.values()) for b in batch]))
# 遍历每个键 k 及其索引 i 。
for i, k in enumerate(keys):
# 获取当前键 k 对应的值列表 value 。
value = values[i]
# 如果键是 "img" ,表示图像数据,使用 torch.stack 函数将图像数据堆叠成一个四维张量。 torch.stack 函数会增加一个新的维度,用于表示批次大小。
if k == "img":
value = torch.stack(value, 0)
# 如果键是 "masks" 、 "keypoints" 、 "bboxes" 、 "cls" 、 "segments" 或 "obb" ,表示这些数据是需要合并的张量,使用 torch.cat 函数将它们在第一个维度(批次维度)上进行拼接。
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
value = torch.cat(value, 0)
# 将处理后的值 value 存储在 new_batch 字典中,键为 k 。
new_batch[k] = value
# 将 "batch_idx" 键对应的值转换为列表。 "batch_idx" 通常用于记录 每个样本在批次中的索引 。
new_batch["batch_idx"] = list(new_batch["batch_idx"])
# 遍历 "batch_idx" 列表,将每个索引值加上 其在列表中的位置 i 。这一步是为了在构建目标时,能够 正确地将每个样本的标签信息与对应的图像关联起来 。
for i in range(len(new_batch["batch_idx"])):
new_batch["batch_idx"][i] += i # add target image index for build_targets()
# 将 "batch_idx" 列表转换为一个一维张量。
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
# 返回 合并后的批次数据 new_batch 。
return new_batch
# collate_fn 方法的主要目的是将一个批次中的多个样本合并成一个批次数据,以便模型可以一次性处理这些数据。该方法处理了不同类型的张量数据,包括图像、掩码、关键点、边界框、类别、分割和定向边界框,并将它们适当地堆叠或拼接。通过处理 "batch_idx" ,确保每个样本的标签信息与对应的图像正确关联,这对于后续的目标构建和损失计算非常重要。
# YOLODataset 类是一个用于处理YOLO模型数据集的自定义数据集类,继承自 BaseDataset 。它提供了数据加载、预处理、数据增强、标签处理和批次合并等功能。通过灵活的配置,支持多种任务类型,如目标检测、分割和姿态估计。类中的方法确保了数据的一致性和标准化,为模型训练提供了高质量的数据输入。
3.class ClassificationDataset(torchvision.datasets.ImageFolder):
# Classification dataloaders -------------------------------------------------------------------------------------------
# 这段代码定义了一个名为 ClassificationDataset 的类,它继承自 torchvision.datasets.ImageFolder ,用于处理图像分类数据集。
# 定义了 ClassificationDataset 类,继承自 torchvision.datasets.ImageFolder 。
class ClassificationDataset(torchvision.datasets.ImageFolder):
# 扩展 torchvision ImageFolder 以支持 YOLO 分类任务,提供图像增强、缓存和验证等功能。它旨在高效处理用于训练深度学习模型的大型数据集,并具有可选的图像转换和缓存机制以加快训练速度。
# 此类允许使用 torchvision 和 Albumentations 库进行增强,并支持在 RAM 或磁盘上缓存图像以减少训练期间的 IO 开销。此外,它还实现了强大的验证过程以确保数据的完整性和一致性。
"""
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
learning models, with optional image transformations and caching mechanisms to speed up training.
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
to ensure data integrity and consistency.
Attributes:
cache_ram (bool): Indicates if caching in RAM is enabled.
cache_disk (bool): Indicates if caching on disk is enabled.
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
torch_transforms (callable): PyTorch transforms to be applied to the images.
"""
# 这段代码是 ClassificationDataset 类的初始化方法 __init__ ,用于设置数据集的各种属性和参数。
# 定义了 ClassificationDataset 类的初始化方法,接受以下参数 :
# 1.root :数据集的根目录。
# 2.args :一个包含各种配置参数的对象。
# 3.augment :一个布尔值,表示是否启用数据增强,默认为 False 。
# 4.prefix :一个字符串,用于日志信息的前缀,默认为空字符串。
def __init__(self, root, args, augment=False, prefix=""):
# 使用 root、图像大小、增强和缓存设置初始化 YOLO 对象。
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
Args:
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
debugging. Default is an empty string.
"""
# 调用父类 ImageFolder 的初始化方法,传入 根目录 root 。这一步初始化了数据集的基本属性,如样本列表 self.samples 。
super().__init__(root=root)
# 如果 启用数据增强 且 args.fraction 小于1.0,则减少训练数据的比例。 args.fraction 表示要使用的数据比例,这在调试或快速训练时非常有用。
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
# 设置日志信息的前缀,如果 prefix 不为空,则使用 colorstr 函数添加颜色。 colorstr 函数用于在终端中显示彩色文本,增强日志的可读性。
# def colorstr(*input):
# -> 用于生成带有颜色和其他格式化选项的字符串,通常用于在终端或命令行界面中输出彩色文本。构建并返回最终的字符串。它通过遍历 args 中的每个元素,查找 colors 字典中对应的 ANSI 转义序列,并将它们与 string 连接起来。最后,它添加一个结束转义序列 colors["end"] 来重置终端的颜色和样式设置。
# -> return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
# 检查 是否将图像缓存到RAM中 。如果 args.cache 为 True 或等于 "ram" ,则启用RAM缓存。这可以显著加快数据加载速度,但会占用大量内存。
self.cache_ram = args.cache is True or args.cache == "ram" # cache images into RAM
# 检查 是否将图像缓存到硬盘上 。如果 args.cache 等于 "disk" ,则启用硬盘缓存。这会将图像保存为未压缩的 .npy 文件,可以加快后续的加载速度。
self.cache_disk = args.cache == "disk" # cache images on hard drive as uncompressed *.npy files
# 调用 verify_images 方法,过滤掉损坏的图像。这一步确保数据集中的图像都是有效的,避免在训练过程中出现错误。
self.samples = self.verify_images() # filter out bad images
# 更新 self.samples 列表,为每个样本添加一个 .npy 文件路径和一个 None 值,用于 缓存图像 。这一步是为了 支持硬盘缓存 。
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
# 设置 图像缩放范围 。 args.scale 通常是一个小的浮点数,表示最小缩放比例。例如,如果 args.scale 为0.08,则缩放范围为 (0.92, 1.0) 。
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
# 根据是否启用数据增强,选择合适的转换流程。
self.torch_transforms = (
# 如果启用数据增强,使用 classify_augmentations 函数生成数据增强转换流程。
# def classify_augmentations(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, scale=None, ratio=None, hflip=0.5, vflip=0.0, auto_augment=None, hsv_h=0.015, hsv_s=0.4, hsv_v=0.4, force_color_jitter=False, erasing=0.0, interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,):
# -> 用于生成图像分类任务的数据增强流程。这个函数结合了多种数据增强技术,以提高模型的泛化能力和鲁棒性。使用 T.Compose 将所有数据增强步骤( primary_tfl 、 secondary_tfl 和 final_tfl )组合成一个完整的数据增强流程。返回这个组合后的数据增强流程,可以作为一个整体应用于图像数据。
# -> return T.Compose(primary_tfl + secondary_tfl + final_tfl)
classify_augmentations(
size=args.imgsz,
scale=scale,
hflip=args.fliplr,
vflip=args.flipud,
erasing=args.erasing,
auto_augment=args.auto_augment,
hsv_h=args.hsv_h,
hsv_s=args.hsv_s,
hsv_v=args.hsv_v,
)
if augment
# 如果不启用数据增强,使用 classify_transforms 函数生成基本的转换流程。
# def classify_transforms(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, crop_fraction: float = DEFAULT_CROP_FTACTION,):
# -> 用于生成图像分类任务的数据预处理流程。使用 T.Compose 将所有预处理步骤组合成一个完整的预处理流程,并返回该流程。
# -> return T.Compose(tfl)
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
)
# __init__ 方法的主要目的是初始化 ClassificationDataset 类,设置数据集的各种属性和参数。该方法支持数据增强、图像缓存(RAM或硬盘)和图像验证,确保数据集的完整性和加载效率。通过灵活的配置,可以适应不同的训练需求,如调试、快速训练或大规模训练。
# 这段代码定义了 ClassificationDataset 类中的 __getitem__ 方法,其主要功能是获取数据集中的第 i 个样本,并进行必要的处理和转换。
# 定义了 __getitem__ 方法,该方法接受一个参数。
# 1.i :要获取的样本索引。
def __getitem__(self, i):
# 返回与给定索引相对应的数据子集和目标。
"""Returns subset of data and targets corresponding to given indices."""
# 从 self.samples 列表中提取第 i 个样本的 文件名 f 、 索引 j 、 .npy 文件路径 fn 和 图像 im 。
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
# 如果启用了RAM缓存且图像 im 未缓存,则读取图像并缓存到 self.samples 中。 cv2.imread(f) 读取图像文件,返回一个NumPy数组。
if self.cache_ram and im is None:
im = self.samples[i][3] = cv2.imread(f)
# 如果启用了硬盘缓存且 .npy 文件不存在。
elif self.cache_disk:
if not fn.exists(): # load npy
# 则读取图像并保存为 .npy 文件,然后加载 .npy 文件。 np.save 函数将图像保存为未压缩的 .npy 文件, np.load 函数加载 .npy 文件。
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
im = np.load(fn)
# 如果未启用缓存,则直接读取图像文件。 cv2.imread(f) 读取图像文件,返回一个NumPy数组,图像格式为BGR。
else: # read image
im = cv2.imread(f) # BGR
# Convert NumPy array to PIL image
# 将图像从BGR格式转换为RGB格式,并从NumPy数组转换为PIL图像。 cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 将BGR图像转换为RGB图像, Image.fromarray 将NumPy数组转换为PIL图像。
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
# 应用转换流程 self.torch_transforms 。 self.torch_transforms 是一个 torchvision 转换流程,可以包含数据增强、归一化等操作。
sample = self.torch_transforms(im)
# 返回一个字典,包含 处理后的图像 sample 和 类别标签 j 。
return {"img": sample, "cls": j}
# __getitem__ 方法的主要目的是获取数据集中的第 i 个样本,并进行必要的处理和转换。该方法支持RAM缓存和硬盘缓存,可以显著加快数据加载速度。图像从BGR格式转换为RGB格式,并从NumPy数组转换为PIL图像,以适应 torchvision 的转换流程。通过应用转换流程,可以对图像进行数据增强和预处理,确保图像格式和内容符合模型输入的要求。
# 这段代码定义了 ClassificationDataset 类中的 __len__ 方法,其主要功能是返回数据集的长度,即数据集中样本的数量。
# 定义了 __len__ 方法,该方法返回一个整数,表示数据集的长度。
def __len__(self) -> int:
# 返回数据集中的样本总数。
"""Return the total number of samples in the dataset."""
# 返回 self.samples 列表的长度, self.samples 列表包含数据集中的所有样本。每个样本是一个元组,包含文件路径、类别索引、 .npy 文件路径和图像数据(如果缓存)。
return len(self.samples)
# __len__ 方法的主要目的是返回数据集中的样本数量。这个方法在PyTorch数据加载器( DataLoader )中非常有用,因为它需要知道数据集的长度来正确地批量加载数据。通过返回 self.samples 列表的长度, __len__ 方法提供了数据集大小的准确信息,确保数据加载器可以正确地迭代数据集。
# 这段代码定义了 ClassificationDataset 类中的 verify_images 方法,其主要功能是验证数据集中的图像文件,确保它们是有效的,并且过滤掉损坏的图像。
# 定义了 verify_images 方法,该方法不接受额外的参数。
def verify_images(self):
# 验证数据集中的所有图像。
"""Verify all images in dataset."""
# 生成 描述字符串 desc ,用于进度条显示,说明正在扫描的目录。
desc = f"{self.prefix}Scanning {self.root}..." # {self.prefix}正在扫描 {self.root}...
# 生成 缓存文件的路径 path ,缓存文件的扩展名为 .cache 。
path = Path(self.root).with_suffix(".cache") # *.cache file path
# 使用 contextlib.suppress 上下文管理器,忽略在加载缓存文件过程中可能抛出的 FileNotFoundError 、 AssertionError 和 AttributeError 异常。
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
# 尝试加载缓存文件,缓存文件包含数据集的验证结果。
# def load_dataset_cache_file(path): -> 加载一个缓存文件,该文件通常是一个包含数据集信息的字典。返回加载的缓存字典。 -> return cache
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
# 断言缓存文件的版本号与当前代码中定义的版本号 DATASET_CACHE_VERSION 相匹配。如果版本号不匹配,抛出 AssertionError 异常。
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
# 断言缓存文件中的哈希值与当前数据集文件路径的哈希值相匹配。如果哈希值不匹配,抛出 AssertionError 异常。
# def get_hash(paths): -> 计算一个路径列表(可以是文件或目录)的单个哈希值。返回哈希对象的十六进制摘要,即路径列表的哈希值。该哈希值可以用于验证路径列表的内容是否发生变化。 -> return h.hexdigest() # return hash
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
# 从缓存字典中提取验证结果,包括找到的 图像数量 nf 、 损坏的图像数量 nc 、 总图像数量 n 和 有效的样本列表 samples 。
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
# 检查当前进程的排名 LOCAL_RANK 是否为 -1 或 0 。在分布式训练中, LOCAL_RANK 为 0 通常表示主进程, -1 通常表示单进程环境。这行代码确保只有主进程或单进程环境会执行后续的显示操作,避免多个进程重复显示相同的信息。
if LOCAL_RANK in (-1, 0):
# 生成描述字符串 d ,用于显示验证结果,包括扫描的目录、找到的图像数量 nf 和损坏的图像数量 nc 。
d = f"{desc} {nf} images, {nc} corrupt" # {desc} {nf} 图像,{nc} 损坏。
# 使用 TQDM 进度条库显示验证结果。这里 TQDM 的用法稍有不同,通常 TQDM 用于显示进度,但在这里它被用来显示一个静态的消息,因为 total 和 initial 参数相等,表示进度条已经完成。 desc 参数用于设置进度条的描述信息。
TQDM(None, desc=d, total=n, initial=n)
# 如果缓存字典 cache 中存在 "msgs" 键,并且该键对应的值(消息列表)不为空,则将这些消息合并为一个字符串,并使用日志记录器 LOGGER 记录这些信息。这些消息通常是警告或错误信息,用于提示用户数据集中的潜在问题。
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# 返回有效的样本列表 samples 。
return samples
# Run scan if *.cache retrieval failed
# 如果缓存文件加载失败,初始化 计数器 和 列表 ,准备 进行图像验证 。
nf, nc, msgs, samples, x = 0, 0, [], [], {}
# 使用线程池,线程数为 NUM_THREADS ,用于并行验证图像文件。
with ThreadPool(NUM_THREADS) as pool:
# 使用线程池的 imap 方法,将 verify_image 函数应用于每个样本和前缀的组合。 verify_image 函数用于验证图像文件的有效性。
# def verify_image(args): -> 验证单个图像文件的有效性。返回一个元组,包含 图像文件路径 和 类别 、 找到的图像数量 、 损坏的图像数量 和 消息字符串 。 -> return (im_file, cls), nf, nc, msg
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
# 创建一个进度条,用于显示验证进度。
pbar = TQDM(results, desc=desc, total=len(self.samples))
# 遍历进度条 pbar 的结果,每个结果是一个元组,包含以下元素 :
# sample :当前样本的信息。
# nf_f :当前样本是否有效(1表示有效,0表示无效)。
# nc_f :当前样本是否损坏(1表示损坏,0表示未损坏)。
# msg :与当前样本相关的消息(如警告或错误信息)。
for sample, nf_f, nc_f, msg in pbar:
# 如果当前样本有效( nf_f 为1),则将该样本添加到 有效样本列表 samples 中。
if nf_f:
samples.append(sample)
# 如果存在与当前样本相关的消息( msg 不为空),则将该消息添加到 消息列表 msgs 中。
if msg:
msgs.append(msg)
# 更新 找到的图像数量 nf ,将其与当前样本的 有效性标志 nf_f 相加。
nf += nf_f
# 更新 损坏的图像数量 nc ,将其与当前样本的 损坏标志 nc_f 相加。
nc += nc_f
# 更新进度条的描述信息,显示当前找到的图像数量 nf 和损坏的图像数量 nc 。这一步确保进度条实时显示最新的验证结果。
pbar.desc = f"{desc} {nf} images, {nc} corrupt" # {desc} {nf} 图像,{nc} 损坏。
# 关闭进度条。
pbar.close()
# 如果存在警告信息,记录这些信息。
if msgs:
LOGGER.info("\n".join(msgs))
# 生成 新的缓存字典 x ,包含 哈希值 、 验证结果 和 警告信息 。
# def get_hash(paths): -> 计算一个路径列表(可以是文件或目录)的单个哈希值。返回哈希对象的十六进制摘要,即路径列表的哈希值。该哈希值可以用于验证路径列表的内容是否发生变化。 -> return h.hexdigest() # return hash
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
# 将新的缓存字典保存到文件中。
# def save_dataset_cache_file(prefix, path, x): -> 将一个Ultralytics数据集的缓存字典 x 保存到指定路径 path 。
save_dataset_cache_file(self.prefix, path, x)
# 返回 有效的样本列表 samples 。
return samples
# verify_images 方法的主要目的是验证数据集中的图像文件,确保它们是有效的,并且过滤掉损坏的图像。该方法首先尝试加载缓存文件,如果缓存文件存在且有效,则直接返回验证结果。如果缓存文件加载失败,方法将并行验证每个图像文件的有效性,生成新的缓存文件,并返回有效的样本列表。通过并行处理和缓存机制,方法提高了验证效率,确保数据集的完整性和一致性。
# ClassificationDataset 类继承自 torchvision.datasets.ImageFolder ,用于处理图像分类数据集。该类支持数据增强、图像缓存(RAM或硬盘)和图像验证。通过缓存机制,可以加速数据加载,特别是在处理大型数据集时。通过验证机制,可以过滤掉损坏的图像,确保数据集的完整性。
4.def load_dataset_cache_file(path):
# 这段代码定义了一个函数 load_dataset_cache_file ,其主要功能是加载一个缓存文件,该文件通常是一个包含数据集信息的字典。
# 定义了 load_dataset_cache_file 函数,该函数接受一个参数。
# 1.path :表示缓存文件的路径。
def load_dataset_cache_file(path):
# 从路径加载 Ultralytics *.cache 字典。
"""Load an Ultralytics *.cache dictionary from path."""
# 导入Python的垃圾回收器模块 gc 。
import gc
# 禁用垃圾回收器。在加载大型文件时,禁用垃圾回收器可以提高加载速度,因为垃圾回收器在加载过程中不会干扰内存分配。
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
# np.load(file, mmap_mode=None, allow_pickle=False, fix_imports=False, encoding='bytes')
# np.load() 函数是 NumPy 库中用于加载 .npy 或 .npz 文件的函数。这些文件格式用于存储单一的 NumPy 数组或多个数组(分别对应 .npy 和 .npz 文件)。 np.load() 函数可以读取这些文件中存储的数组数据。
# 参数 :
# file :要加载的文件路径,可以是字符串路径或文件对象。
# mmap_mode :(可选)内存映射模式,用于控制如何将文件内容映射到内存中。默认为 None ,表示不使用内存映射。
# allow_pickle :(可选)布尔值,指示是否允许加载 pickle 对象。默认为 False ,出于安全考虑,防止执行不受信任的数据。
# fix_imports :(可选)布尔值,指示是否修复导入路径。默认为 False ,仅在加载 pickle 对象时相关。
# encoding :(可选)字符串,指定文件编码。默认为 'bytes' ,表示文件内容被读取为字节字符串。
# 返回值 :
# 返回加载的 NumPy 数组或包含多个数组的字典(对于 .npz 文件)。
# np.load() 函数是 NumPy 数据持久化和读取数据的重要工具,特别适用于需要保存和恢复大型数组数据的场景。
# 使用 numpy 的 np.load 函数加载缓存文件。 str(path) 确保路径是字符串格式。 allow_pickle=True 允许使用 pickle 模块反序列化对象,因为缓存文件可能包含复杂的数据结构(如字典)。 np.load 返回一个数组,调用 .item() 方法将其转换为Python字典。
cache = np.load(str(path), allow_pickle=True).item() # load dict
# 重新启用垃圾回收器,确保内存管理恢复正常。
gc.enable()
# 返回加载的缓存字典。
return cache
# load_dataset_cache_file 函数的主要目的是加载一个包含数据集信息的缓存文件。该函数通过禁用和重新启用垃圾回收器,优化了加载过程,特别是在处理大型文件时。使用 numpy 的 np.load 函数加载缓存文件,并将其转换为Python字典,以便后续处理。这种加载机制在数据预处理和模型训练中非常有用,可以快速加载和使用缓存的数据集信息。
5.def save_dataset_cache_file(prefix, path, x):
# 这段代码定义了一个函数 save_dataset_cache_file ,其主要功能是将一个Ultralytics数据集的缓存字典 x 保存到指定路径 path 。
# 定义了 save_dataset_cache_file 函数,该函数接受三个参数。
# 1.prefix :前缀,用于日志信息。
# 2.path :缓存文件的路径。
# 3.x :要保存的缓存字典。
def save_dataset_cache_file(prefix, path, x):
# 将 Ultralytics 数据集 *.cache 字典 x 保存到路径。
"""Save an Ultralytics dataset *.cache dictionary x to path."""
# 在缓存字典 x 中添加一个 "version" 键,值为 DATASET_CACHE_VERSION ,表示 缓存的版本 。这有助于在加载缓存时验证版本的一致性。
x["version"] = DATASET_CACHE_VERSION # add cache version
# 检查缓存文件的父目录是否可写。 is_dir_writeable 是一个自定义函数,用于检查目录是否具有写权限。
# def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
# -> 检查指定目录是否具有写权限。如果目录具有写权限, os.access 返回 True ,否则返回 False 。
# -> return os.access(str(dir_path), os.W_OK)
if is_dir_writeable(path.parent):
# 如果缓存文件已经存在,则使用 path.unlink() 方法删除它,以便重新保存新的缓存文件。
if path.exists():
path.unlink() # remove *.cache file if exists
# 使用 numpy 的 np.save 函数将缓存字典 x 保存到指定路径 path 。 np.save 函数会自动添加 .npy 后缀。
np.save(str(path), x) # save cache for next time
# 将保存的文件重命名,去除 .npy 后缀,使其符合预期的缓存文件格式(例如, dataset.cache )。
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
# 记录一条信息,提示新的缓存文件已创建。
LOGGER.info(f"{prefix}New cache created: {path}") # {prefix} 创建新缓存:{path}。
# 如果缓存文件的父目录不可写,则记录一条警告信息,提示缓存文件未保存。
else:
LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") # {prefix}警告 ⚠️ 缓存目录 {path.parent} 不可写入,缓存未保存。
# save_dataset_cache_file 函数的主要目的是将数据集的缓存字典保存到指定路径,以便后续快速加载和使用。该函数在保存缓存文件前,会检查父目录的写权限,确保文件可以成功保存。通过记录日志信息,函数提供了操作的反馈,帮助开发者了解缓存文件的保存状态。
6.class SemanticDataset(BaseDataset):
# TODO: support semantic segmentation TODO:支持语义分割。
# 这段代码定义了一个名为 SemanticDataset 的类,它继承自 BaseDataset 。这个类目前只是一个占位符,用于处理语义分割任务的数据集。
# 定义了 SemanticDataset 类,继承自 BaseDataset 。 BaseDataset 是一个包含基本数据集操作的基类,如数据加载、预处理等。
class SemanticDataset(BaseDataset):
# 语义分割数据集。
# 此类负责处理用于语义分割任务的数据集。它从 BaseDataset 类继承功能。
# 注意:
# 此类当前为占位符,需要填充方法和属性以支持语义分割任务。
"""
Semantic Segmentation Dataset.
This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
from the BaseDataset class.
Note:
This class is currently a placeholder and needs to be populated with methods and attributes for supporting
semantic segmentation tasks.
"""
# 定义了 SemanticDataset 类的初始化方法 __init__ ,该方法不接受额外的参数。
def __init__(self):
# 初始化 SemanticDataset 对象。
"""Initialize a SemanticDataset object."""
# 调用父类 BaseDataset 的初始化方法。这一步确保了 BaseDataset 中的初始化逻辑被正确执行,初始化了 BaseDataset 中的属性和方法。
super().__init__()
# SemanticDataset 类目前只是一个简单的继承自 BaseDataset 的类,没有添加额外的属性或方法。通过调用 super().__init__() ,确保了父类的初始化逻辑被正确执行。这个类可以进一步扩展,添加特定于语义分割数据集的属性和方法,如数据增强、标签处理等。
# 扩展解释 :
# 初始化方法 : __init__ 方法接受数据集根目录 root 、图像转换 transform 和标签转换 target_transform 。
# 加载数据集 : load_dataset 方法遍历根目录,加载图像和标签的路径。
# 获取样本 : __getitem__ 方法加载第 index 个样本的图像和标签,并应用转换。
# 数据集长度 : __len__ 方法返回数据集中的样本数量。通过这样的扩展, SemanticDataset 类可以处理语义分割任务中的图像和标签数据。