utils.py
ultralytics\data\utils.py
目录
utils.py
1.所需的库和模块
2.def img2label_paths(img_paths):
3.def get_hash(paths):
4.def exif_size(img: Image.Image):
5.def verify_image(args):
6.def verify_image_label(args):
7.def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
8.def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
9.def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
10.def find_dataset_yaml(path: Path) -> Path:
11.def check_det_dataset(dataset, autodownload=True):
12.def check_cls_dataset(dataset, split=""):
13.class HUBDatasetStats:
14.def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
15.def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
1.所需的库和模块
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import hashlib
import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile
import cv2
import numpy as np
from PIL import Image, ImageOps
from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
DATASETS_DIR,
LOGGER,
NUM_THREADS,
ROOT,
SETTINGS_YAML,
TQDM,
clean_url,
colorstr,
emojis,
yaml_load,
yaml_save,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes
# 这段代码定义了一些常量和环境变量,用于配置和指导数据集的格式化以及数据加载器的行为。
# 定义了一个字符串常量 HELP_URL ,包含一个URL,指向Ultralytics文档中关于物体检测数据集格式化的指导页面。这个URL提供了一个链接,用户可以访问以获取如何格式化数据集的详细信息。
HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance." # 有关数据集格式指南,请参阅https://docs.ultralytics.com/datasets/detect。
# 定义了一个集合 IMG_FORMATS ,包含支持的图像文件格式的后缀。这些后缀用于验证图像文件的有效性,确保数据集中的图像文件格式是受支持的。
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
# 定义了一个集合 VID_FORMATS ,包含支持的视频文件格式的后缀。这些后缀用于验证视频文件的有效性,确保数据集中的视频文件格式是受支持的。
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
# 定义了一个布尔变量 PIN_MEMORY ,用于控制数据加载器( DataLoader )是否使用内存锁定(pin memory)。内存锁定可以提高数据传输到GPU的效率。 os.getenv("PIN_MEMORY", True) 从环境变量中获取 PIN_MEMORY 的值,如果环境变量未设置,则默认为 True 。然后,将获取的值转换为字符串并转换为小写,最后检查是否等于 "true" 。
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
# HELP_URL 提供了一个链接,指向Ultralytics文档中关于物体检测数据集格式化的指导页面。 IMG_FORMATS 和 VID_FORMATS 分别定义了支持的图像和视频文件格式的后缀,用于验证文件的有效性。 PIN_MEMORY 是一个布尔变量,用于控制数据加载器是否使用内存锁定,以提高数据传输到GPU的效率。这些常量和环境变量在数据处理和模型训练中非常有用,确保数据集的格式正确,并优化数据加载的性能。
2.def img2label_paths(img_paths):
# 这段代码定义了一个函数 img2label_paths ,其主要功能是将图像文件路径转换为对应的标签文件路径。
# 定义了 img2label_paths 函数,该函数接受一个参数。
# 1.img_paths :这是一个包含图像文件路径的列表。
def img2label_paths(img_paths):
# 将标签路径定义为图像路径的函数。
"""Define label paths as a function of image paths."""
# 定义两个字符串 sa 和 sb ,分别表示 图像目录 和 标签目录 的子字符串。 os.sep 是操作系统特定的路径分隔符(在Windows中为 \ ,在Unix/Linux中为 / )。
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
# str.rsplit(separator=None, maxsplit=-1)
# 在Python中, rsplit() 是字符串( str )对象的一个方法,用于在字符串末尾进行分割操作。这个方法从字符串的右侧(末尾)开始分割,而不是默认的左侧(开头)。
# 参数 :
# separator : 分隔符,用于指定分隔字符串的字符或字符串。如果未指定或为 None ,则任何空白字符(如空格、换行 \n 、制表符 \t 等)都被视为分隔符。
# maxsplit :最大分割次数。如果设置为 -1 (默认值),则没有分割次数限制,字符串会被完全分割直到没有分隔符为止。如果设置为其他整数,则在达到指定的分割次数后停止分割。
# 返回值 :
# rsplit() 方法返回一个列表,包含分割后的子字符串。
# 使用列表推导式,将每个 图像文件路径 转换为 对应的标签文件路径 。
# x.rsplit(sa, 1) :从右侧开始分割路径 x ,最多分割一次,分割符为 sa (即 /images/ )。这将路径分为两部分,前部分是目录路径,后部分是文件名。
# sb.join(...) :将前部分目录路径和 sb (即 /labels/ )重新连接起来,形成新的路径。
# rsplit(".", 1)[0] :从右侧开始分割文件名,最多分割一次,分割符为 . 。这将文件名分为两部分,前部分是文件名(不包括扩展名),后部分是扩展名。取前部分,即文件名(不包括扩展名)。
# + ".txt" :将文件名(不包括扩展名)与 .txt 扩展名连接起来,形成标签文件的完整路径。
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
# img2label_paths 函数的主要目的是将图像文件路径转换为对应的标签文件路径。该函数通过字符串操作,将 /images/ 替换为 /labels/ ,并更改文件扩展名为 .txt ,从而生成标签文件的路径。这种转换在数据预处理和加载标签时非常有用,特别是在处理大规模数据集时,可以自动化地生成标签文件路径。
3.def get_hash(paths):
# 这段代码定义了一个函数 get_hash ,其主要功能是计算一个路径列表(可以是文件或目录)的单个哈希值。
# 定义了 get_hash 函数,该函数接受一个参数。
# 1.paths :这是一个路径列表,可以包含文件或目录的路径。
def get_hash(paths):
# 返回路径列表(文件或目录)的单个哈希值。
"""Returns a single hash value of a list of paths (files or dirs)."""
# 计算路径列表中 所有存在的文件的总大小 。使用 os.path.getsize(p) 获取每个文件的大小,并使用 sum 函数计算总大小。如果路径不存在,则忽略该路径。
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
# hashlib.sha256([data])
# hashlib.sha256() 函数是 Python 标准库 hashlib 模块中的一个函数,用于创建一个 SHA-256 哈希对象。SHA-256 是一种加密哈希函数,可以将任意长度的数据转换为一个固定长度(256位,即32字节)的唯一哈希值。
# 参数 :
# data :(可选)一个初始字节序列,用于初始化哈希对象。如果不提供此参数,则创建一个空的哈希对象。
# 返回值 :
# 返回一个新的 sha256 哈希对象。
# 方法 :
# 创建 sha256 哈希对象后,你可以使用以下方法 :
# update(data) :向哈希对象添加数据。 data 必须是字节序列或字节数组。
# digest() :返回当前哈希对象的二进制(十六进制编码)哈希值。
# hexdigest() :返回当前哈希对象的十六进制编码哈希值。
# copy() :返回当前哈希对象的一个副本。
# hashlib.sha256() 函数在需要确保数据完整性、生成数据签名或验证数据未被篡改时非常有用。它是许多安全应用程序和协议中的一个基本组件。
# 使用 hashlib.sha256 创建一个 SHA-256哈希对象 ,并将文件总大小的字符串表示(编码为字节)更新到哈希对象中。这一步是为了 在哈希计算中包含文件大小信息 。
h = hashlib.sha256(str(size).encode()) # hash sizes
# 将 路径列表 连接成一个字符串,并将其编码为字节,然后更新到哈希对象中。这一步是为了 在哈希计算中包含路径信息 。
h.update("".join(paths).encode()) # hash paths
# 返回哈希对象的十六进制摘要,即路径列表的哈希值。
return h.hexdigest() # return hash
# get_hash 函数的主要目的是计算一个路径列表的单个哈希值,该哈希值可以用于验证路径列表的内容是否发生变化。该函数通过计算文件的总大小和路径信息的哈希值,生成一个唯一的哈希值。这种哈希值可以用于缓存机制,确保数据的一致性和完整性,特别是在数据预处理和模型训练中。
4.def exif_size(img: Image.Image):
# 这段代码定义了一个函数 exif_size ,其主要功能是获取并根据EXIF信息(特别是图像的旋转信息)修正PIL图像的尺寸。
# 定义了 exif_size 函数,该函数接受一个参数。
# 1.img :类型为 Image.Image ,即PIL库中的图像对象。
def exif_size(img: Image.Image):
# 返回经过 exif 校正的 PIL 大小。
"""Returns exif-corrected PIL size."""
# 获取 图像的原始尺寸 ,存储在变量 s 中。 img.size 返回一个元组,包含图像的宽度和高度。
s = img.size # (width, height)
# 检查图像格式是否为JPEG。EXIF信息通常存在于JPEG图像中,因此这里只处理JPEG格式的图像。
if img.format == "JPEG": # only support JPEG images
# 使用 contextlib.suppress(Exception) 上下文管理器,忽略在获取和处理EXIF信息过程中可能抛出的任何异常。这确保了即使EXIF信息不存在或有误,函数也能正常运行。
with contextlib.suppress(Exception):
# PIL.Image.Image.getexif()
# getexif() 函数是 Python Imaging Library (PIL) 的一个扩展库 Pillow 中的一个方法,它用于从 PIL 图像对象中提取 EXIF 数据。EXIF 数据包含了数字照片的元数据,例如拍摄设备、拍摄参数、拍摄日期等信息。
# 参数 :
# Image :PIL 或 Pillow 库中的 Image 类的一个实例。
# 返回值 :
# 返回一个 Exif 对象,该对象包含了图像的 EXIF 数据。这个对象可以像字典一样被访问,其中的键是 EXIF 标签的 ID,值是对应的数据。
# 功能 :
# getexif() 方法读取图像文件中的 EXIF 信息,并将其作为一个可读写的 Exif 对象返回。这个对象允许你访问、修改和删除 EXIF 信息。
# 注意事项 :
# EXIF 数据可能包含多个 IFD(Image File Directory), getexif() 返回的对象默认只访问 IFD0。如果需要访问其他 IFD,可以使用 ExifTags 模块来获取标签的名称,并使用 get_ifd() 方法来访问特定的 IFD。
# 某些图像可能不包含 EXIF 数据,或者在图像处理过程中 EXIF 数据可能被剥离,这种情况下 getexif() 方法可能返回 None 。
# 尝试获取图像的EXIF信息,存储在变量 exif 中。 img.getexif() 返回一个字典,包含图像的EXIF元数据。
exif = img.getexif()
# 检查是否成功获取到EXIF信息。
if exif:
# 从EXIF信息中获取图像的旋转信息。EXIF键274对应于 图像的方向标签 。如果该键不存在,返回 None 。
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
# 检查旋转信息是否为6或8,这分别表示图像旋转了270度或90度。
if rotation in [6, 8]: # rotation 270 or 90
# 如果图像旋转了90度或270度,交换图像的宽度和高度,以修正尺寸。
s = s[1], s[0]
# 返回修正后的图像尺寸。
return s
# exif_size 函数的主要目的是获取PIL图像的尺寸,并根据EXIF信息中的旋转标签修正尺寸。该函数特别处理JPEG格式的图像,忽略其他格式的图像。通过处理EXIF信息,确保图像的尺寸在考虑旋转后是正确的,这对于后续的图像处理和模型训练非常重要。
5.def verify_image(args):
# 这段代码定义了一个名为 verify_image 的函数,其主要功能是验证单个图像文件的有效性。
# 定义了 verify_image 函数,该函数接受一个参数。
# 1.args :通常是一个元组,包含图像文件路径和类别信息,以及一个前缀字符串。
def verify_image(args):
# 验证一张图片。
"""Verify one image."""
# 解包 args 元组,提取 图像文件路径 im_file 、 类别 cls 和 前缀 prefix 。
(im_file, cls), prefix = args
# Number (found, corrupt), message
# 初始化计数器和消息字符串。
# nf :找到的图像数量。
# nc :损坏的图像数量。
# msg :消息字符串,用于记录验证过程中的警告或错误信息。
nf, nc, msg = 0, 0, ""
# 开始一个 try 块,用于捕获可能发生的异常。
try:
# 使用 PIL 库的 Image.open 函数打开图像文件,并调用 im.verify() 方法验证图像文件的完整性。这一步会检查图像文件是否损坏。
im = Image.open(im_file)
im.verify() # PIL verify
# 调用 exif_size 函数获取图像的实际尺寸(考虑EXIF信息中的旋转),并将尺寸转换为 (height, width) 格式。
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
# 断言图像的高度和宽度都大于9像素。如果图像尺寸小于10像素,抛出异常,提示图像尺寸过小。
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" # 图像大小{形状} <10像素。
# 断言图像格式在支持的格式列表 IMG_FORMATS 中。如果图像格式不支持,抛出异常,提示图像格式无效。
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" # 无效的图像格式 {im.format}。
# 如果图像格式为JPEG,检查文件末尾的两个字节是否为JPEG的结束标记 b"\xff\xd9" 。
if im.format.lower() in ("jpg", "jpeg"):
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
# 如果不符合,表示JPEG文件可能损坏,使用 ImageOps.exif_transpose 重新打开图像,考虑EXIF信息中的旋转,并以高质量(100)和无子采样( subsampling=0 )保存图像,修复损坏的JPEG文件。
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" # {prefix}警告⚠️{im_file}:损坏的 JPEG 已恢复并保存。
# 如果图像验证通过,将 nf 设置为1,表示找到了一个有效的图像。
nf = 1
# 捕获在验证过程中可能发生的任何异常。
except Exception as e:
# 如果发生异常,将 nc 设置为1,表示图像损坏。
nc = 1
# 生成一个警告消息,提示图像或标签损坏,记录异常信息。
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" # {prefix}警告⚠️{im_file}:忽略损坏的图像/标签:{e}。
# 返回一个元组,包含 图像文件路径 和 类别 、 找到的图像数量 、 损坏的图像数量 和 消息字符串 。
return (im_file, cls), nf, nc, msg
# verify_image 函数的主要目的是验证单个图像文件的有效性,确保图像文件没有损坏,并且尺寸和格式符合要求。该函数通过 PIL 库的 Image.open 和 im.verify 方法检查图像文件的完整性。通过断言和异常处理,确保图像的高度和宽度都大于9像素,图像格式在支持的格式列表中。如果图像格式为JPEG且文件损坏,尝试修复并保存图像。函数返回一个元组,包含图像文件路径和类别、找到的图像数量、损坏的图像数量和消息字符串,为调用者提供了详细的验证结果。
6.def verify_image_label(args):
# 这段代码定义了一个名为 verify_image_label 的函数,其主要功能是验证一个图像-标签对的有效性。
# 定义了 verify_image_label 函数,该函数接受一个参数。
# 1.args :通常是一个元组,包含验证所需的各种参数。
def verify_image_label(args):
# 验证一对图像-标签。
"""Verify one image-label pair."""
# 解包 args 元组,提取以下参数 :
# im_file :图像文件路径。
# lb_file :标签文件路径。
# prefix :前缀,用于日志信息。
# keypoint :是否使用关键点。
# num_cls :类别数量。
# nkpt :关键点数量。
# ndim :关键点维度(2或3)。
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
# Number (missing, found, empty, corrupt), message, segments, keypoints
# 初始化计数器和变量。
# nm :缺失的图像数量。
# nf :找到的图像数量。
# ne :为空的图像数量。
# nc :损坏的图像数量。
# msg :消息字符串。
# segments :分割信息列表。
# keypoints :关键点信息。
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
# 开始一个 try 块,用于捕获可能发生的异常。
try:
# 这段代码是 verify_image_label 函数中的一部分,主要功能是验证图像文件的完整性和有效性,并进行必要的修复。
# Verify images
# 使用 PIL 库的 Image.open 函数打开图像文件,存储在变量 im 中。
im = Image.open(im_file)
# im.verify()
# 在 Python 的 PIL(Python Imaging Library)库中, .verify() 方法用于验证图像文件的完整性。当处理图像文件时,这个方法尝试确认文件是否未损坏并且可以被正确解码。
# im :一个 PIL Image 对象。
# 功能 :
# .verify() 方法检查图像文件是否完整且未损坏。如果文件损坏或无法被识别,这个方法会抛出一个 IOError (输入/输出错误)异常。
# 注意事项 :
# .verify() 方法只适用于某些图像格式,特别是那些 PIL 支持的格式。
# 这个方法不会检查图像的元数据或内容,只检查图像文件的完整性和可读性。
# 在处理大量图像文件时,使用 .verify() 方法可以帮助识别和排除损坏的文件,以避免在后续处理中出现问题。
# 调用 im.verify() 方法,验证图像文件的完整性。这个方法会检查图像文件是否损坏,如果图像文件损坏,会抛出异常。
im.verify() # PIL verify
# 调用 exif_size 函数,获取 图像的实际尺寸 (考虑EXIF信息中的旋转)。 exif_size 函数返回一个元组 (width, height) 。
# def exif_size(img: Image.Image): -> 获取并根据EXIF信息(特别是图像的旋转信息)修正PIL图像的尺寸。返回修正后的图像尺寸。 -> return s
shape = exif_size(im) # image size
# 将图像尺寸转换为 (height, width) 格式,存储在变量 shape 中。
shape = (shape[1], shape[0]) # hw
# 断言图像的高度和宽度都大于9像素。如果图像尺寸小于10像素,抛出异常,提示图像尺寸过小。
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" # 图像大小{形状} <10像素。
# 断言图像格式在支持的格式列表 IMG_FORMATS 中。如果图像格式不支持,抛出异常,提示图像格式无效。
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}" # 无效的图像格式 {im.format}。
# 检查图像格式是否为JPEG。如果是JPEG格式,进行进一步的检查和修复。
if im.format.lower() in ("jpg", "jpeg"):
# 以二进制读取模式打开图像文件。
with open(im_file, "rb") as f:
# 将文件指针移动到文件末尾前2个字节的位置。这是为了检查JPEG文件的结束标记。
f.seek(-2, 2)
# 读取文件末尾的2个字节,检查是否为JPEG文件的结束标记 b"\xff\xd9" 。如果不符合,表示JPEG文件可能损坏。
if f.read() != b"\xff\xd9": # corrupt JPEG
# PIL.ImageOps.exif_transpose(image, *, in_place=False)
# ImageOps.exif_transpose() 函数是 Python Imaging Library (PIL) 的一个扩展库 Pillow 中的一个函数,它用于根据图像的 EXIF 定向标签来调整图像的方向,使得图像按照 EXIF 标签中指定的方向进行正确的显示。如果图像没有 EXIF 定向标签或者标签值为 1(表示图像已经是正确的方向),则返回图像的一个副本。
# 参数 :
# image :要调整方向的 PIL 图像对象。
# in_place :(关键字参数)布尔值,如果设置为 True ,则在原图像对象上进行修改,并返回 None ;如果设置为 False (默认值),则返回一个新的图像对象,原图像对象不变。
# 返回值 :
# 如果 in_place 参数为 False (默认),返回一个新的图像对象,该对象根据 EXIF 定向标签调整了方向。
# 如果 in_place 参数为 True ,则原图像对象被修改,函数返回 None 。
# 功能 :
# 读取图像的 EXIF 信息,特别是 Orientation 标签。
# 根据 Orientation 标签的值,确定如何调整图像的方向。
# 应用相应的变换(例如旋转、翻转)来调整图像的方向。
# 如果 in_place 为 False ,则返回调整方向后的新图像对象;否则,修改原图像对象。
# 如果JPEG文件损坏,使用 ImageOps.exif_transpose 函数重新打开图像,考虑EXIF信息中的旋转,并以高质量(100)和无子采样( subsampling=0 )保存图像,修复损坏的JPEG文件。
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
# 生成一个警告消息,提示损坏的JPEG文件已修复并保存。
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" # {prefix}警告⚠️{im_file}:损坏的 JPEG 已恢复并保存。
# 这段代码的主要目的是验证图像文件的完整性和有效性,确保图像文件没有损坏,并且尺寸和格式符合要求。如果发现JPEG文件损坏,会尝试修复并保存图像,确保图像文件的完整性。通过断言和异常处理,确保图像文件的尺寸和格式符合要求,为后续的数据处理和模型训练提供准备。
# 这段代码是 verify_image_label 函数中的一部分,主要功能是验证标签文件的有效性。
# Verify labels
# 检查标签文件是否存在。如果存在,继续进行验证。
if os.path.isfile(lb_file):
# 如果标签文件存在,将 nf (label found)设置为1,表示找到了标签文件。
nf = 1 # label found
# 以只读模式打开标签文件。
with open(lb_file) as f:
# 读取标签文件的内容,去除首尾空白字符,按行分割,然后将每行分割成列表。 lb 是一个二维列表,其中每个子列表表示一个标签。
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
# 检查是否有任何标签的长度大于6,并且当前任务不是关键点检测( keypoint 为 False ),这表示标签文件中可能包含分割信息。
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
# 提取每个标签的 类别信息 ,存储在 classes 数组中。
classes = np.array([x[0] for x in lb], dtype=np.float32)
# 将每个标签的 分割信息 提取出来,重塑为 (n, 2) 的数组,其中 n 是 分割点的数量 。 segments 是一个列表,每个元素是一个二维数组,表示一个分割的点集。
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
# 将 类别信息 和 分割信息 转换为边界框格式,使用 segments2boxes 函数将 分割信息 转换为 边界框信息 ,并与 类别信息 拼接成一个新的数组 lb 。
# def segments2boxes(segments):
# -> 用于将多边形线段(segments)转换为矩形框(boxes),并将矩形框的格式从左上角和右下角坐标(xyxy)转换为中心点坐标加宽高(xywh)。将 boxes 列表转换为NumPy数组,并调用 xyxy2xywh 函数将矩形框的格式从左上角和右下角坐标(xyxy)转换为中心点坐标加宽高(xywh),然后返回转换后的矩形框数组。
# -> return xyxy2xywh(np.array(boxes)) # cls, xywh
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
# 将 标签列表 lb 转换为浮点数数组。
lb = np.array(lb, dtype=np.float32)
# 获取 标签的数量 nl 。
nl = len(lb)
# 如果 标签数量 nl 大于0,表示标签文件不为空。
if nl:
# 如果当前任务是关键点检测( keypoint 为 True ),检查标签的列数是否为 (5 + nkpt * ndim) ,其中 nkpt 是关 键点的数量 , ndim 是 关键点的维度 (2或3)。提取 关键点信息 ,重塑为 (n, ndim) 的数组,并取前两列(x, y坐标)。
if keypoint:
assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each" # 每个标签需要 {(5 + nkpt * ndim)} 列。
points = lb[:, 5:].reshape(-1, ndim)[:, :2]
# 如果当前任务不是关键点检测,检查标签的列数是否为5(类别, x, y, w, h)。提取边界框信息。
else:
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" # 标签需要 5 列,检测到 {lb.shape[1]} 列。
points = lb[:, 1:]
# 检查所有坐标值是否在0到1之间(归一化坐标)。如果坐标值大于1,表示坐标未归一化或超出边界。
assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}" # 非规范化或超出范围的坐标 {points[points > 1]}。
# 检查所有标签值是否非负。如果有负值,表示标签文件中存在无效值。
assert lb.min() >= 0, f"negative label values {lb[lb < 0]}" # 负标签值 {lb[lb < 0]}。
# 这段代码的主要目的是验证标签文件的有效性,确保标签文件存在且格式正确。对于分割任务,将分割信息转换为边界框格式。对于关键点检测任务,检查标签的列数和坐标值的有效性。通过断言和异常处理,确保标签文件的格式和内容符合要求,为后续的数据处理和模型训练提供准备。
# 这段代码继续 verify_image_label 函数中对标签的验证过程,主要处理标签的类别范围、去除重复标签、处理空标签或缺失标签的情况,以及对关键点信息的处理。
# All labels
# 计算 标签数组 lb 中类别列(第一列)的最大值,存储在变量 max_cls 中,表示 最大的类别标签 。
max_cls = lb[:, 0].max() # max label count
# 断言 最大类别标签 max_cls 不超过 数据集的类别总数 num_cls 。如果超过,抛出异常,提示标签类别超出数据集的类别范围。
assert max_cls <= num_cls, (
f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. " # 标签类别 {int(max_cls)} 超出数据集类别数量 {num_cls}。
f"Possible class labels are 0-{num_cls - 1}" # 可能的类别标签为 0-{num_cls - 1}。
)
# 使用 np.unique 函数找出 lb 中唯一的行,并返回这些行的索引 i 。这一步用于 检查和去除重复的标签 。
_, i = np.unique(lb, axis=0, return_index=True)
# 如果 去重后的标签数量 少于 原始标签数量 nl ,表示 存在重复标签 。
if len(i) < nl: # duplicate row check
# 使用索引 i 去除重复的标签,更新 lb 数组。
lb = lb[i] # remove duplicates
# 如果存在分割信息,也相应地去除重复的分割信息。
if segments:
segments = [segments[x] for x in i]
# 生成一个警告消息,提示去除了多少个重复的标签。
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed" # {prefix}警告 ⚠️ {im_file}: {nl - len(i)} 重复标签已删除。
# 如果标签数量 nl 为0,表示 标签文件为空 ,设置 ne (label empty)为1,并生成一个 空的标签数组 lb 。标签数组的列数根据是否使用关键点而定。
else:
ne = 1 # label empty
lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
# 如果标签文件不存在,设置 nm (label missing)为1,并生成一个 空的标签数组 lb 。这里有一个小错误,应该是 keypoint 而不是 keypoints 。
else:
nm = 1 # label missing
lb = np.zeros((0, (5 + nkpt * ndim) if keypoints else 5), dtype=np.float32) # ❌⚠️ 这行代码应修改为 : lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
# 如果使用关键点,提取关键点信息,重塑为 (nl, nkpt, ndim) 的数组。
if keypoint:
keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
# 如果关键点的维度为2,生成一个掩码数组 kpt_mask ,用于标记无效的关键点(x或y坐标小于0)。将掩码信息添加到关键点数组中,使其维度变为 (nl, nkpt, 3) 。
if ndim == 2:
kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
# 将 标签数组 lb 截取为前5列,即 类别 和 边界框信息 。
lb = lb[:, :5]
# 返回验证结果,包括 图像文件路径 、 标签数组 、 图像尺寸 、 分割信息 、 关键点信息 、 缺失 、 找到 、 为空 、 损坏的计数 以及 消息字符串 。
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
# 这段代码的主要目的是完成标签的验证,包括检查类别范围、去除重复标签、处理空标签或缺失标签的情况,以及对关键点信息的处理。通过这些步骤,确保标签文件的完整性和一致性,为后续的数据处理和模型训练提供准备。
# 这段代码是 verify_image_label 函数中的异常处理部分,用于捕获并处理在验证图像和标签过程中可能出现的任何异常。
# 捕获在验证过程中可能发生的任何异常,并将异常对象存储在变量 e 中。
except Exception as e:
# 将 nc (corrupt count,损坏计数)设置为1,表示当前图像或标签文件损坏。
nc = 1
# 生成一个警告消息,包含前缀 prefix 、图像文件路径 im_file 和异常信息 e ,提示当前图像或标签文件损坏,将被忽略。
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" # {prefix}警告⚠️{im_file}:忽略损坏的图像/标签:{e} 。
# 返回一个包含 None 值的列表,表示验证失败。列表中的元素分别对应 图像文件路径 、 标签数组 、 图像尺寸 、 分割信息 、 关键点信息 、 缺失计数 nm 、 找到计数 nf 、 为空计数 ne 、 损坏计数 nc 和 消息 msg 。
return [None, None, None, None, None, nm, nf, ne, nc, msg]
# 这段代码的主要目的是处理验证过程中可能出现的异常,确保即使在遇到损坏的图像或标签文件时,函数也能正常返回,不会导致整个程序崩溃。通过设置损坏计数 nc 为1和生成警告消息,记录了异常情况,方便后续的调试和数据清理。返回的列表中包含 None 值,表示当前图像和标签文件在验证过程中被忽略,为调用者提供了明确的反馈。
# verify_image_label 函数的主要目的是验证一个图像-标签对的有效性,包括图像文件的完整性和标签文件的格式正确性。该函数处理了图像文件的损坏修复、标签文件的解析和验证,确保图像和标签数据的一致性和有效性。通过返回详细的验证结果,该函数为数据预处理和缓存生成提供了基础,帮助后续的数据加载和模型训练。
7.def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
# 这段代码定义了一个名为 polygon2mask 的函数,用于将多边形转换为掩码图像。
# 定义函数 polygon2mask ,它有四个参数。
# 1.imgsz :图像尺寸,是一个元组,如 (height, width)。
# 2.polygons :多边形的顶点坐标列表。
# 3.color :填充多边形的颜色,默认为1。
# 4.downsample_ratio :下采样比例,默认为1,用于调整最终掩码图像的尺寸。
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
# 将多边形列表转换为指定图像大小的二进制掩码。
# polygons (list[np.ndarray]): polygons 列表。每个多边形都是一个形状为 [N, M] 的数组,其中 N 是多边形的数量,M 是点的数量,使得 M % 2 = 0。
"""
Convert a list of polygons to a binary mask of the specified image size.
Args:
imgsz (tuple): The size of the image as (height, width).
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.
Returns:
(np.ndarray): A binary mask of the specified image size with the polygons filled in.
"""
# 创建一个与图像尺寸 imgsz 相同大小的掩码图像 mask ,初始值全为0,数据类型为 uint8 ,即无符号8位整数。
mask = np.zeros(imgsz, dtype=np.uint8)
# 将输入的多边形顶点坐标列表 polygons 转换为 numpy 数组,并指定数据类型为 int32 ,即32位整数,以满足后续 cv2.fillPoly 函数对顶点坐标数据类型的要求。
polygons = np.asarray(polygons, dtype=np.int32)
# 将多边形顶点坐标数组 polygons 重塑为 (n, -1, 2) 的形状,其中 n 为多边形的数量, -1 表示自动计算该维度的大小,2表示每个顶点坐标由两个数值(x和y)组成,这样可以使其符合 cv2.fillPoly 函数对顶点坐标输入格式的要求。
polygons = polygons.reshape((polygons.shape[0], -1, 2))
# cv2.fillPoly(img, pts, color)
# cv2.fillPoly 是 OpenCV 库中的一个函数,它用于在图像中填充一个或多个多边形。
# 参数说明 :
# img :目标图像,它是一个 NumPy 数组,可以是灰度图或彩色图。这个函数会直接在原图像上进行操作,所以如果需要保留原始图像,应该先复制一份。
# pts :一个或多个多边形的顶点坐标。它是一个列表,其中每个元素是一个 NumPy 数组,包含多边形顶点的坐标。对于单个多边形, pts 可以是 (numpy_array,) 的形式,其中 numpy_array 包含多边形顶点的坐标。对于多个多边形, pts 可以是 (numpy_array1, numpy_array2, ...) 的形式,其中每个 numpy_array 包含一个多边形的顶点坐标。
# color :填充多边形的颜色。对于灰度图像,它是一个单通道的灰度值;对于彩色图像,它是一个 (B, G, R) 元组,分别代表蓝色、绿色和红色通道的值。
# 返回值 :该函数不返回任何值,它直接在输入的图像 img 上进行操作。
# 使用 cv2.fillPoly 函数在掩码图像 mask 上根据多边形顶点坐标 polygons 填充指定颜色 color ,将多边形区域内的像素值设置为 color 。
cv2.fillPoly(mask, polygons, color=color)
# 根据下采样比例 downsample_ratio 计算 最终掩码图像 的高 nh 和宽 nw ,即原图像尺寸 imgsz 的高和宽分别除以下采样比例。
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
# Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1 注意:先 fillPoly 然后 resize 是为了在 mask-ratio=1 时保持相同的损失计算方法。
# 使用 cv2.resize 函数将填充后的掩码图像 mask 调整为新的尺寸 (nw, nh) ,并返回调整后的掩码图像。
return cv2.resize(mask, (nw, nh))
# 这段代码通过创建一个初始全0的掩码图像,然后根据输入的多边形顶点坐标在掩码图像上填充指定颜色,最后根据下采样比例调整掩码图像的尺寸,实现了将多边形转换为掩码图像的功能,常用于计算机视觉任务中,如目标检测、图像分割等场景,用于生成用于训练或评估的掩码标签。
8.def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
# 这段代码定义了一个名为 polygons2masks 的函数,用于将多个多边形批量转换为掩码图像数组。
# 定义函数 polygons2masks ,它有四个参数。
# 1.imgsz :图像尺寸,是一个元组,如 (height, width)。
# 2.polygons :一个包含多个多边形顶点坐标列表的列表。
# 3.color :填充多边形的颜色。
# 4.downsample_ratio :下采样比例,默认为1,用于调整最终掩码图像的尺寸。
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
# 将多边形列表转换为一组指定图像大小的二进制掩码。
# polygons (list[np.ndarray]):多边形列表。每个多边形都是一个形状为 [N, M] 的数组,其中 N 是多边形的数量,M 是点的数量,使得 M % 2 = 0。
"""
Convert a list of polygons to a set of binary masks of the specified image size.
Args:
imgsz (tuple): The size of the image as (height, width).
polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
N is the number of polygons, and M is the number of points such that M % 2 = 0.
color (int): The color value to fill in the polygons on the masks.
downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.
Returns:
(np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
"""
# 使用列表推导式对输入的每个多边形 x 进行处理。
# x.reshape(-1) 将每个多边形的顶点坐标数组 x 重塑为一维数组,因为 polygon2mask 函数内部需要将多边形顶点坐标重塑为 (n, -1, 2) 的形状,这里先将其变为一维数组是为了方便后续处理。
# [x.reshape(-1)] 将重塑后的一维数组 x 放入一个列表中,因为 polygon2mask 函数的 polygons 参数需要是一个列表,即使只处理一个多边形也需要以列表形式传入。
# polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) 调用之前定义的 polygon2mask 函数,将单个多边形转换为掩码图像。
# 列表推导式对所有多边形进行上述操作,生成一个包含所有掩码图像的列表。
# np.array(...) 将这个列表转换为 numpy 数组,最终返回这个数组,数组中的每个元素是一个掩码图像,对应输入的每个多边形。
return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])
# 这段代码通过调用 polygon2mask 函数,实现了将多个多边形批量转换为掩码图像数组的功能。它利用列表推导式对每个多边形进行处理,并将结果收集到一个 numpy 数组中返回,方便后续在计算机视觉任务中对多个掩码图像进行统一处理和操作,如批量训练模型或进行批量评估等。
9.def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
# 这段代码定义了一个名为 polygons2masks_overlap 的函数,用于处理多个可能存在重叠的多边形,并生成一个掩码图像,同时返回多边形的排序索引。
# 定义函数 polygons2masks_overlap ,它有三个参数。
# 1.imgsz :图像尺寸,是一个元组,如 (height, width) 。
# 2.segments :一个包含多个多边形顶点坐标列表的列表。
# 3.downsample_ratio :下采样比例,默认为1,用于调整最终掩码图像的尺寸。
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
# 返回(640,640)重叠掩码。
"""Return a (640, 640) overlap mask."""
# 创建一个与下采样后的图像尺寸相同大小的掩码图像 masks ,初始值全为0。
masks = np.zeros(
(imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
# 数据类型根据多边形的数量 len(segments) 决定,如果多边形数量超过255,则使用 np.int32 ,否则使用 np.uint8 ,以确保能够容纳足够的标签值。
dtype=np.int32 if len(segments) > 255 else np.uint8,
)
# 初始化两个列表。
# areas 用于存储 每个多边形掩码的面积 。
areas = []
# ms 用于存储 每个多边形的掩码图像 。
ms = []
# 遍历每个多边形 segments[si] 。
for si in range(len(segments)):
# 使用 polygon2mask 函数生成单个多边形的掩码图像 mask ,颜色设置为1。
mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
# 将生成的 掩码图像 mask 添加到列表 ms 中。
ms.append(mask)
# 计算掩码图像 mask 的面积(即掩码中值为1的像素数量),并将面积添加到列表 areas 中。
areas.append(mask.sum())
# 将面积列表 areas 转换为 numpy 数组。
areas = np.asarray(areas)
# np.argsort(a, axis=-1, kind=None, order=None)
# np.argsort() 是 NumPy 库中的一个函数,它返回数组元素从小到大的索引排序。这个函数对于一维和多维数组都适用,并且可以沿着指定的轴进行操作。、
# 参数 :
# a :要排序的数组。
# axis :指定要排序的轴。默认为 -1 ,即最后一个轴。如果数组是一维的,则此参数将被忽略。
# kind :指定排序算法的类型。默认为 None ,即让 NumPy 自动选择。其他选项包括 'quicksort'、'mergesort' 等。
# order :当数组元素是结构化数组时,用于指定排序的字段。
# 返回值 :
# 返回一个数组,其中包含输入数组 a 中元素从小到大排序后的索引。
# 异常 :
# 如果输入数组 a 不是 NumPy 数组,可能会抛出异常。
# np.argsort 是一个非常有用的函数,它允许你在不改变原始数组的情况下,快速获取排序后的元素索引,这在很多数据处理和分析任务中都非常有用。
# 使用 np.argsort(-areas) 获取按面积 从大到小 排序的索引 index 。
index = np.argsort(-areas)
# 根据索引 index 对掩码图像列表 ms 进行排序,确保面积较大的多边形先被处理。
ms = np.array(ms)[index]
# 遍历排序后的掩码图像列表 ms 。
for i in range(len(segments)):
# 将当前掩码图像 ms[i] 乘以 (i + 1) ,为每个多边形分配一个唯一的标签值。
mask = ms[i] * (i + 1)
# 将处理后的掩码图像 mask 加到总掩码图像 masks 上。
masks = masks + mask
# 使用 np.clip 函数将 masks 中的值限制在 [0, i + 1] 范围内,确保不会出现超出预期的标签值。
masks = np.clip(masks, a_min=0, a_max=i + 1)
# 返回最终的 掩码图像 masks 和 多边形的排序索引 index 。
return masks, index
# 这段代码通过处理多个可能存在重叠的多边形,生成一个掩码图像,其中每个多边形被分配一个唯一的标签值。通过先计算每个多边形的面积并按面积排序,确保面积较大的多边形先被处理,从而在重叠区域中优先保留面积较大的多边形的标签。最终返回的掩码图像和排序索引可以用于后续的图像处理和分析任务,如目标检测、实例分割等。
10.def find_dataset_yaml(path: Path) -> Path:
# 这段代码定义了一个名为 find_dataset_yaml 的函数,用于在指定路径下查找YAML文件。
# 定义函数 find_dataset_yaml ,接收一个 Path 类型的参数,并且返回值也是 Path 类型,即返回找到的YAML文件的路径。
# 1.path :要查找的路径
def find_dataset_yaml(path: Path) -> Path:
# 查找并返回与 Detect、Segment 或 Pose 数据集关联的 YAML 文件。
# 此函数首先在提供的目录的根级别搜索 YAML 文件,如果未找到,则执行递归搜索。它首选具有与提供的路径相同的词干的 YAML 文件。如果未找到 YAML 文件或找到多个 YAML 文件,则会引发 AssertionError。
"""
Find and return the YAML file associated with a Detect, Segment or Pose dataset.
This function searches for a YAML file at the root level of the provided directory first, and if not found, it
performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
is raised if no YAML file is found or if multiple YAML files are found.
Args:
path (Path): The directory path to search for the YAML file.
Returns:
(Path): The path of the found YAML file.
"""
# 首先尝试在路径 path 的根目录下查找所有扩展名为 .yaml 的文件,使用 path.glob("*.yaml") 实现。如果在根目录下没有找到,则使用 path.rglob("*.yaml") 进行递归查找,即在 path 及其所有子目录下查找 .yaml 文件。将找到的文件列表赋值给变量 files 。
files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive
# 使用 assert 语句断言 files 不为空,即至少找到一个YAML文件。如果 files 为空,说明没有找到YAML文件,则会抛出异常,并输出提示信息,显示在路径 path.resolve() 下没有找到YAML文件。
assert files, f"No YAML file found in '{path.resolve()}'" # 在“{path.resolve()}”中未找到 YAML 文件。
# 如果找到的YAML文件数量大于1,则进入该条件判断语句。
if len(files) > 1:
# Path(path)
# 在Python中, Path 是 pathlib 模块中的一个类,用于表示文件系统路径。 pathlib 是一个现代的文件路径操作库,它提供了面向对象的方式来处理文件和目录路径。
# 导入 Path 类 : from pathlib import Path。
# 可以使用 Path 类来创建一个路径对象。这个对象可以是一个文件或者目录的路径。
# 路径操作 :
# Path 对象提供了许多方法来操作路径。
# p.exists() :检查路径是否存在。
# p.is_file() :检查路径是否指向一个文件。
# p.is_dir() :检查路径是否指向一个目录。
# p.resolve() :解析路径,返回绝对路径。
# p.parent :返回路径的父目录。
# p.name :返回路径的最后一部分(文件名)。
# p.suffix :返回文件的后缀名。
# p.stem :返回文件名不包括后缀的部分。
# 在找到的多个YAML文件中,筛选出文件名(不包含扩展名)与路径 path 的文件名(不包含扩展名)相同的文件,即优先选择与路径同名的YAML文件,将筛选后的文件列表重新赋值给 files 。
files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match
# 再次使用 assert 语句断言 files 的长度为1,即期望找到一个YAML文件。如果找到的文件数量不是1,则抛出异常,并输出提示信息,显示期望在路径 path.resolve() 下找到1个YAML文件,但实际找到了 len(files) 个,并列出这些文件。
assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}" # '{path.resolve()}' 中预期有 1 个 YAML 文件,但找到了 {len(files)}。\n{files} 。
# 返回 找到的YAML文件的路径 ,即 files 列表中的第一个元素。
return files[0]
# 此函数的主要目的是在给定路径下查找YAML文件。它先在根目录查找,若未找到则递归查找。若找到多个文件,则优先选择与路径同名的文件,最终确保返回一个YAML文件的路径。通过断言语句保证了查找过程的正确性和唯一性,若不符合预期则会抛出异常并给出提示,方便调试和使用。
11.def check_det_dataset(dataset, autodownload=True):
# 这段代码定义了一个名为 check_det_dataset 的函数,用于检查目标检测数据集的完整性和正确性,并进行一些必要的处理。
# 定义函数 check_det_dataset ,接收两个参数。
# 1.dataset :数据集的路径或名称。
# 2.autodownload :一个布尔值,默认为 True ,表示是否自动下载数据集。
def check_det_dataset(dataset, autodownload=True):
# 如果在本地找不到数据集,则下载、验证和/或解压缩数据集。
# 此函数检查指定数据集的可用性,如果未找到,则可选择下载并解压缩数据集。然后,它会读取并解析随附的 YAML 数据,确保满足关键要求,并解析与数据集相关的路径。
"""
Download, verify, and/or unzip a dataset if not found locally.
This function checks the availability of a specified dataset, and if not found, it has the option to download and
unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
resolves paths related to the dataset.
Args:
dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.
Returns:
(dict): Parsed dataset information and paths.
"""
# 调用 check_file 函数检查 dataset 指定的文件是否存在,返回文件的路径。
# def check_file(file, suffix="", download=True, hard=True):
# -> 用于检查文件的存在性,并根据需要下载或搜索文件。如果满足上述条件之一,则返回文件名。返回下载后的文件名。根据搜索结果返回文件名或抛出错误。
# -> return file / return file / return files[0] if len(files) else [] if hard else file # return file
file = check_file(dataset)
# 这段代码是 check_det_dataset 函数的一部分,主要负责检查数据集文件是否为压缩文件,若是则进行下载和解压操作,并读取数据集的YAML配置文件。
# Download (optional)
# 定义一个空字符串 extract_dir ,用于后续存储解压后的目录路径。
extract_dir = ""
# is_zip = zipfile.is_zipfile(filename)
# is_zipfile() 函数是 Python zipfile 模块中的一个函数,用于检查一个文件是否是有效的 ZIP 文件格式。
# 参数 :
# filename : 要检查的文件的路径,可以是字符串、文件对象或路径对象。
# 返回值 :
# is_zipfile() 函数返回一个布尔值。 如果文件是有效的 ZIP 文件,则返回 True 。 如果文件不是有效的 ZIP 文件或文件不存在,则返回 False 。
# is_zipfile() 函数的实现依赖于文件的“魔术数字”(文件开头的字节序列),这是许多文件格式用来标识自己的一种方式。ZIP 文件的魔术数字是 PK ( 0x50 0x4B ),这个序列出现在所有 ZIP 文件的开头。
# 如果一个文件以这个序列开头, is_zipfile() 函数就会返回 True ,表明该文件是一个 ZIP 文件。这个函数在处理文件上传、归档和解压缩任务时非常有用,因为它可以帮助程序确定如何处理特定的文件。
# is_tar = tarfile.is_tarfile(name)
# is_tarfile() 函数是 Python tarfile 模块中的一个函数,用于检查一个文件是否是有效的 tar 归档文件格式。
# 参数 :
# name : 要检查的文件的路径,可以是字符串、文件对象或路径对象。
# 返回值 :
# is_tarfile() 函数返回一个布尔值。 如果文件是有效的 tar 归档文件,则返回 True 。 如果文件不是有效的 tar 归档文件或文件不存在,则返回 False 。
# is_tarfile() 函数的实现依赖于文件的“魔术数字”(文件开头的字节序列),这是许多文件格式用来标识自己的一种方式。tar 归档文件的魔术数字是 ustar (在文件的 257 字节处),这个序列出现在所有 tar 归档文件的特定位置。
# 如果一个文件在指定位置包含这个序列, is_tarfile() 函数就会返回 True ,表明该文件是一个 tar 归档文件。这个函数在处理文件上传、归档和解压缩任务时非常有用,因为它可以帮助程序确定如何处理特定的文件。
# 判断 file 变量所指向的文件是否为zip格式或tar格式的压缩文件。 zipfile.is_zipfile 和 is_tarfile 分别是用来检测文件是否为zip和tar压缩文件的函数。
if zipfile.is_zipfile(file) or is_tarfile(file):
# 如果文件是压缩文件,则调用 safe_download 函数。该函数会下载文件到指定目录 DATASETS_DIR ,并进行解压( unzip=True 表示解压, delete=False 表示不解压后不删除压缩文件),返回解压后的新目录名称赋值给 new_dir 变量。
# def safe_download(url, file=None, dir=None, unzip=True, delete=False, curl=False, retry=3, min_bytes=1e0, exist_ok=False, progress=True,): -> 用于安全地下载文件,并在需要时解压。返回解压目录。 -> return unzip_dir
new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
# 调用 find_dataset_yaml 函数,在解压后的目录 DATASETS_DIR / new_dir 中查找数据集的YAML配置文件,并将该文件的路径赋值给 file 变量。
# def find_dataset_yaml(path: Path) -> Path: -> 用于在指定路径下查找YAML文件。返回 找到的YAML文件的路径 ,即 files 列表中的第一个元素。 -> return files[0]
file = find_dataset_yaml(DATASETS_DIR / new_dir)
# 将找到的 YAML配置文件的父目录 路径赋值给 extract_dir ,并将 autodownload 变量设置为 False ,表示 后续不再自动下载 。
extract_dir, autodownload = file.parent, False
# Read YAML
# 调用 yaml_load 函数读取 file 变量所指向的YAML配置文件。 append_filename=True 参数表示在读取过程中将文件名附加到数据中。读取后的配置数据以字典形式存储在 data 变量中。
# def yaml_load(file="data.yaml", append_filename=False): -> 用于加载YAML文件并返回其内容作为字典。返回解析后的字典。 -> return data
data = yaml_load(file, append_filename=True) # dictionary
# 这段代码首先检查给定的数据集文件是否为压缩格式,若是,则通过 safe_download 函数进行下载和解压操作,并在解压后的目录中查找YAML配置文件。找到YAML文件后,通过 yaml_load 函数读取其内容,并将读取到的配置数据存储在字典 data 中,为后续的数据集检查和处理操作提供基础数据。同时,通过设置 extract_dir 和 autodownload 变量,为后续的路径解析和下载逻辑做准备。
# 这段代码是对数据集的YAML配置文件内容进行一系列检查和处理,以确保数据集配置的完整性和正确性。
# Checks
# 遍历字符串 "train" 和 "val" ,分别检查这两个键是否存在于数据集的配置字典 data 中。
for k in "train", "val":
# 如果当前键( "train" 或 "val" )不在 data 字典中。
if k not in data:
# 进一步判断,如果当前键不是 "val" ,或者 "validation" 键也不在 data 中。
if k != "val" or "validation" not in data:
# 则抛出语法错误异常,提示用户 "train" 和 "val" 键在所有数据集的YAML配置文件中都是必需的。
raise SyntaxError(
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.") # {dataset} '{k}:' 键缺少 ❌。\n所有数据 YAML 中都需要 'train' 和 'val'。
)
# 如果 "val" 键不存在,但 "validation" 键存在,则记录一条警告信息,提示将 "validation" 键重命名为 "val" ,以符合YOLO格式。
LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.") # 警告⚠️将数据 YAML“验证”键重命名为“val”以匹配 YOLO 格式。
# 将 "validation" 键的值赋给 "val" 键,并从字典中删除 "validation" 键。
data["val"] = data.pop("validation") # replace 'validation' key with 'val' key
# 检查 "names" 和 "nc" 这两个键是否都不在 data 字典中。
if "names" not in data and "nc" not in data:
# 如果都不在,则抛出语法错误异常,提示用户在所有数据集的YAML配置文件中, "names" 或 "nc" 至少有一个是必需的。
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs.")) # {dataset} 键缺失 ❌。\n 所有数据 YAML 中都需要“names”或“nc”。
# 如果 "names" 和 "nc" 键都存在,但 "names" 列表的长度与 "nc" 的值不相等。
if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
# 则抛出语法错误异常,提示用户 "names" 的长度和 "nc" 的值必须匹配。
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match.")) # {dataset} ‘names’ 长度 {len(data['names'])} 和 ‘nc: {data['nc']}’ 必须匹配。
# 如果 "names" 键不存在。
if "names" not in data:
# 则根据 "nc" 的值生成一个默认的类别名称列表,列表中的每个元素格式为 "class_i" ,其中 i 是类别索引。
data["names"] = [f"class_{i}" for i in range(data["nc"])]
# 如果 "names" 键存在。
else:
# 则将 "names" 列表的长度赋值给 "nc" 键,确保 "nc" 的值与类别名称列表的长度一致。
data["nc"] = len(data["names"])
# 调用 check_class_names 函数对 "names" 列表中的类别名称进行检查,确保它们的有效性,并将检查后的结果重新赋值给 "names" 键。
# def check_class_names(names): -> 验证和转换类名列表或字典,以确保它们符合特定的格式要求。返回处理后的 names 字典。 -> return names
data["names"] = check_class_names(data["names"])
# 这段代码通过一系列的检查和处理,确保了数据集配置的完整性和正确性。它首先检查 "train" 和 "val" 这两个关键的训练和验证集配置是否缺失,并处理了 "validation" 键到 "val" 键的兼容性问题。接着,它检查类别名称列表 "names" 和类别数量 "nc" 的配置是否正确,包括它们是否存在、是否匹配等,并在必要时进行自动修正或生成默认值。最后,通过调用 check_class_names 函数对类别名称进行有效性检查,确保数据集配置能够正确地用于后续的模型训练和验证过程。
# 这段代码主要负责解析和设置数据集相关路径,确保路径的正确性和完整性。
# Resolve paths
# 这行代码用于确定 数据集的根目录路径 。首先尝试使用 extract_dir (之前解压目录的路径),如果 extract_dir 为空,则尝试从 data 字典中获取 "path" 键的值,如果 "path" 键也不存在,则使用YAML文件的父目录作为数据集根目录。这里使用 Path 类来处理路径,确保路径操作的正确性。
path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root
# 判断 path 是否为绝对路径。
if not path.is_absolute():
# 如果不是绝对路径,则将其与 DATASETS_DIR (数据集目录)拼接,并 解析为绝对路径 。这样可以确保路径的正确性,即使在YAML配置文件中使用了相对路径。
path = (DATASETS_DIR / path).resolve()
# Set paths
# 将 解析后的数据集根目录路径 赋值给 data 字典中的 "path" 键。这一步是为了更新配置字典,使其包含正确的数据集根目录路径,便于后续的下载脚本等操作使用。
data["path"] = path # download scripts
# 遍历 "train" 、 "val" 和 "test" 这三个键,分别处理 训练集 、 验证集 和 测试集 的路径。
for k in "train", "val", "test":
# 如果当前键( "train" 、 "val" 或 "test" )在 data 字典中存在,则进行路径处理。
if data.get(k): # prepend path
# 如果当前键的值是字符串类型,说明该数据集的路径是一个单一的路径。
if isinstance(data[k], str):
# 将数据集根目录路径与当前键的值拼接,并 解析为绝对路径 ,赋值给变量 x 。
x = (path / data[k]).resolve()
# 如果解析后的路径 x 不存在,并且当前键的值以 "../" 开头,说明可能存在路径解析问题。
if not x.exists() and data[k].startswith("../"):
# 去掉 "../" 后重新解析路径,尝试修正路径问题。
x = (path / data[k][3:]).resolve()
# 将修正后的路径转换为字符串,重新赋值给 data 字典中的当前键。
data[k] = str(x)
# 如果当前键的值不是字符串类型,说明可能是路径列表。
else:
# 遍历路径列表中的每个路径,将数据集根目录路径与每个路径拼接,并 解析为绝对路径 ,然后将所有解析后的路径转换为字符串列表,重新赋值给 data 字典中的当前键。
data[k] = [str((path / x).resolve()) for x in data[k]]
# 这段代码通过解析和设置路径,确保了数据集配置中路径的正确性和完整性。它首先确定数据集的根目录路径,然后根据根目录路径更新训练集、验证集和测试集的路径。对于路径的处理,代码考虑了相对路径和可能的路径解析问题,通过拼接和解析操作,确保最终的路径是绝对路径,并且路径存在。这样可以避免在后续的数据加载和处理过程中出现路径错误的问题。
# 这段代码主要负责解析YAML配置文件中的特定字段,检查数据集文件是否存在,以及在需要时自动下载数据集。最后,它还会检查并下载所需的字体文件。
# Parse YAML
# 从 data 字典中获取 "val" 和 "download" 键的值,分别赋值给变量 val 和 s 。 val 变量用于 存储验证集的路径 , s 变量用于 存储下载指令或URL 。
val, s = (data.get(x) for x in ("val", "download"))
# 如果 val 变量不为空,说明配置文件中指定了验证集的路径。
if val:
# 将 val 变量中的路径(可能是单个路径或路径列表) 解析为绝对路径 。如果 val 是单个路径,则将其转换为列表,然后解析每个路径。
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
# 检查 val 列表中的 所有路径是否都存在 。如果存在任何一个路径不存在。
if not all(x.exists() for x in val):
# 调用 clean_url 函数清理数据集名称中的URL认证信息,得到 干净的数据集名称 赋值给 name 变量。
# def clean_url(url): -> 用于清理和规范化URL。使用 split("?") 将URL分割为查询参数前的部分和查询参数部分,取第一部分,即不包含查询参数的URL。 -> return urllib.parse.unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth
name = clean_url(dataset) # dataset name with URL auth stripped
# os.path.exists(path)
# exists() 函数是 Python os.path 模块中的一个函数,用于检查指定路径的文件或目录是否存在。
# 参数 :
# path :要检查的文件或目录的路径。
# 返回值 :
# 返回 True :如果指定的路径存在。
# 返回 False :如果指定的路径不存在。
# 功能 :
# exists() 函数用于检查给定的路径是否指向一个存在的文件或目录。
# 它不会区分路径是指向文件还是目录,只要路径存在,它就返回 True 。
# 注意事项 :
# exists() 函数不会抛出异常,即使路径不存在,它也会安静地返回 False 。
# 如果你需要区分路径是文件还是目录,可以使用 os.path.isfile() 和 os.path.isdir() 函数。
# 这个函数不适用于检查路径是否可访问(例如,是否有读取或写入权限),它只检查路径是否存在。
# 生成一条提示信息,说明缺失的验证集路径。
m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'" # 未找到数据集“{name}”图像⚠️,缺少路径“{[x for x in val if not x.exists()][0]}” 。
# 如果 s 变量不为空且 autodownload 参数为 True ,说明配置了自动下载。
if s and autodownload:
# 记录一条警告信息,提示用户数据集文件未找到,但将尝试自动下载。
LOGGER.warning(m)
# 如果未配置自动下载。
else:
# 在提示信息中添加数据集下载目录的说明,并提示用户可以在 SETTINGS_YAML 文件中更新下载目录。
m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'" # 注意数据集下载目录为“{DATASETS_DIR}”。您可以在“{SETTINGS_YAML}”中更新此目录。
# 抛出 FileNotFoundError 异常,提示用户数据集文件未找到。
raise FileNotFoundError(m)
# 记录当前时间,用于计算下载操作的耗时。
t = time.time()
# 初始化变量 r 为 None ,表示成功。
r = None # success
# 如果 s 是一个以 .zip 结尾的HTTP URL。
if s.startswith("http") and s.endswith(".zip"): # URL
# 调用 safe_download 函数下载并解压该URL指向的压缩文件到 DATASETS_DIR 目录,并在解压后删除压缩文件。
# def safe_download(url, file=None, dir=None, unzip=True, delete=False, curl=False, retry=3, min_bytes=1e0, exist_ok=False, progress=True,): -> 用于安全地下载文件,并在需要时解压。返回解压目录。 -> return unzip_dir
safe_download(url=s, dir=DATASETS_DIR, delete=True)
# 如果 s 是一个以 "bash " 开头的bash脚本。
elif s.startswith("bash "): # bash script
# 记录一条信息,提示正在运行bash脚本。
LOGGER.info(f"Running {s} ...") # 正在运行 {s} ...
# os.system(command)
# os.system() 函数是 Python 的 os 模块中的一个函数,用于执行指定的命令行字符串。这个函数会调用系统的命令行解释器(通常是 shell)来执行命令,并返回命令执行后的退出状态码。
# 参数说明 :
# command :一个字符串,包含要执行的命令。
# 返回值 :返回命令执行后的退出状态码。在 Unix 和类 Unix 系统中,通常 0 表示成功,非 0 表示失败。在 Windows 系统中,返回值的具体含义取决于命令本身。
# 异常 :
# 如果发生错误(例如,无法找到命令解释器),可能会抛出 OSError 异常。
# 需要注意的是, os.system() 函数会创建一个新的 shell 来执行命令,这意味着它可能会受到当前工作目录的影响,并且不会继承当前 Python 进程的环境变量。
# 此外,由于安全原因,通常不推荐在程序中使用 os.system() ,因为它可能会执行任意命令,存在安全风险。
# 在可能的情况下,可以考虑使用 subprocess 模块中的函数,如 subprocess.run() 或 subprocess.call() ,因为它们提供了更多的控制和安全性。
# 使用 os.system 函数执行bash脚本,并将执行结果赋值给 r 变量。
r = os.system(s)
# 如果 s 是一个Python脚本。
else: # python script
# exec(object, globals=None, locals=None)
# 在Python中, exec() 函数用于执行存储在字符串或对象中的Python代码。这个函数非常强大,但也需谨慎使用,因为它会执行任意的代码,这可能导致安全风险。
# 参数说明 :
# object :必需,要执行的代码,可以是字符串或代码对象。
# globals :可选,用于执行代码时的全局变量字典。如果为 None ,则使用当前环境的全局变量。
# locals :可选,用于执行代码时的局部变量字典。如果为 None ,则使用 globals 作为局部变量环境。
# 返回值 :
# exec() 函数没有返回值(即返回 None ),因为它直接执行代码,而不是返回执行结果。
# 安全注意事项 :
# 由于 exec() 可以执行任意代码,因此它可能会被用来执行恶意代码。因此,只有在完全信任代码来源的情况下才应该使用 exec() ,并且永远不要对用户提供的输入使用 exec() ,除非经过了严格的验证和清理。
# 使用 exec 函数执行Python脚本,并将 data 字典作为局部变量传递给脚本。
exec(s, {"yaml": data})
# 计算 下载操作的耗时 ,并格式化为字符串。
dt = f"({round(time.time() - t, 1)}s)"
# 根据下载操作的结果生成一条成功或失败的提示信息。如果 r 为 0 或 None ,表示成功;否则表示失败。
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" # 成功✅ {dt},保存至 {colorstr('bold', DATASETS_DIR)}。 失败 {dt} ❌。
# 记录一条信息,显示数据集下载的结果。
LOGGER.info(f"Dataset download {s}\n") # 数据集下载 {s} 。
# 调用 check_font 函数检查并下载所需的字体文件。如果类别名称列表 data["names"] 中的所有名称都是ASCII字符,则下载 Arial.ttf 字体;否则下载 Arial.Unicode.ttf 字体。
# def check_font(font="Arial.ttf"):
# -> 用于检查和获取指定字体文件的路径。如果文件存在,则直接返回该路径,表示找到了字体文件。如果列表不为空,表示找到了至少一个匹配的字体文件路径,则返回列表中的第一个路径,表示找到了字体文件。返回字体文件的路径。无论是在用户配置目录中找到字体文件,还是从系统字体目录中找到匹配的字体文件,或者从网络下载字体文件,最终都返回字体文件的路径,供后续使用。
# -> return file / return matches[0] / return file
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
# 返回更新后的 data 字典,其中包含了数据集的配置信息和解析后的路径等。
return data # dictionary
# 这段代码通过解析YAML配置文件中的 "val" 和 "download" 字段,检查验证集文件是否存在,并在需要时自动下载数据集。它还处理了不同类型的下载指令,包括HTTP URL下载、bash脚本执行和Python脚本执行。最后,代码检查并下载所需的字体文件,确保在后续的数据可视化等操作中能够正确显示类别名称。通过这些操作,代码确保了数据集的可用性和完整性,为后续的模型训练和验证提供了基础。
# check_det_dataset 函数是一个用于检查和处理目标检测数据集的工具函数。它首先验证数据集文件是否存在,若为压缩文件则进行下载和解压。接着读取数据集的YAML配置文件,检查其中的关键配置项如训练集、验证集路径以及类别名称等是否完整且正确,必要时进行自动修正或生成默认值。然后解析和设置数据集相关路径,确保路径的正确性和完整性。最后,根据配置检查数据集文件是否存在,若不存在且配置了自动下载,则尝试下载数据集,并检查并下载所需的字体文件,以确保数据集的可用性和后续操作的顺利进行。整个函数通过一系列严谨的检查和处理步骤,为模型训练和验证提供了可靠的数据支持。
12.def check_cls_dataset(dataset, split=""):
# 这段代码定义了一个名为 check_cls_dataset 的函数,用于检查和处理分类数据集,确保数据集的完整性和可用性,并返回数据集的相关信息。
# 定义函数 check_cls_dataset ,接受两个参数。
# 1.dataset :数据集名称或路径。
# 2.split :数据集分割类型,如 "train" 、 "val" 或 "test" 。
def check_cls_dataset(dataset, split=""):
# 检查分类数据集,例如 Imagenet。
# 此函数接受 `dataset` 名称并尝试检索相应的数据集信息。
# 如果在本地找不到数据集,它会尝试从互联网上下载数据集并将其保存在本地。
"""
Checks a classification dataset such as Imagenet.
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
Args:
dataset (str | Path): The name of the dataset.
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
Returns:
(dict): A dictionary containing the following keys:
- 'train' (Path): The directory path containing the training set of the dataset.
- 'val' (Path): The directory path containing the validation set of the dataset.
- 'test' (Path): The directory path containing the test set of the dataset.
- 'nc' (int): The number of classes in the dataset.
- 'names' (dict): A dictionary of class names in the dataset.
"""
# Download (optional if dataset=https://file.zip is passed directly) 说明如果直接传递的是URL,将尝试下载数据集。
# 判断 dataset 是否以 http:/ 或 https:/ 开头,即是否为URL。
if str(dataset).startswith(("http:/", "https:/")):
# 如果是URL,则调用 safe_download 函数下载数据集到 DATASETS_DIR 目录,解压并保留压缩文件。
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
# 如果 dataset 是一个本地文件路径,并且文件后缀为 .zip 、 .tar 或 .gz 。
elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
# 调用 check_file 函数检查文件是否存在。
file = check_file(dataset)
# 调用 safe_download 函数下载并解压文件到 DATASETS_DIR 目录。
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
# 将 dataset 转换为 Path 对象,便于后续路径操作。
dataset = Path(dataset)
# 如果 dataset 是一个目录,则直接使用;否则,将其视为文件名,拼接到 DATASETS_DIR 目录下,并解析为绝对路径。
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
# 如果 data_dir 不是一个目录。
if not data_dir.is_dir():
# 记录警告信息,提示数据集未找到,将尝试下载。
LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...") # 未找到数据集⚠️,缺少路径 {data_dir},正在尝试下载...
# 记录当前时间,用于计算下载耗时。
t = time.time()
# 如果 dataset 是 "imagenet" 。
if str(dataset) == "imagenet":
# subprocess.run(args, *, stdin=None, input=None, stdout=None, stderr=None, capture_output=False,
# check=False, timeout=None, device=None, text=None, encoding=None, errors=None,
# shell=False, cwd=None, env=None, universal_newlines=False, bufsize=-1,
# start_new_session=False, restore_signals=True, preexec_fn=None,
# pass_fds=(), warn=False)
# subprocess.run() 是 Python 标准库 subprocess 模块中的一个函数,用于执行外部命令和程序。这个函数在 Python 3.5 中被引入,作为替代旧的 subprocess.call() 、 subprocess.check_call() 和 subprocess.check_output() 函数的一个更现代的接口。
# 参数说明 :
# args :命令及其参数。可以是字符串列表或单个字符串。如果 shell=True ,则 args 应该是单个字符串。
# stdin :指定 stdin 的文件描述符或文件对象。
# input :如果提供,写入到 stdin 。
# stdout 和 stderr :指定 stdout 和 stderr 的文件描述符或文件对象。
# capture_output :如果为 True ,则捕获 stdout 和 stderr 。
# check :如果为 True ,则在命令返回非零退出状态时引发 CalledProcessError 异常。
# timeout :等待命令完成的秒数。超时则引发 TimeoutExpired 异常。
# device 、 text 、 encoding 、 errors :控制输出和错误输出的编码。
# shell :如果为 True ,则通过 shell 执行 args 。
# cwd :子进程的当前工作目录。
# env :子进程的环境变量。
# universal_newlines :如果为 True ,则将输出和错误解释为文本。
# bufsize :缓冲区大小。
# start_new_session :如果为 True ,则在子进程中启动一个新的会话。
# restore_signals :如果为 True ,则在子进程中恢复信号。
# preexec_fn :在子进程中执行的函数。
# pass_fds :需要传递给子进程的文件描述符列表。
# warn :(已弃用)是否发出警告。
# 返回值 :
# 返回一个 CompletedProcess 实例,它包含执行命令后的信息,如 args 、 returncode 、 stdout 、 stderr 等。
# 运行一个bash脚本下载ImageNet数据集。
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
# 如果 dataset 不是 "imagenet" 。
else:
# 构建下载URL。
url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
# 调用 download 函数下载数据集。
# def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False): -> 用于从指定URL下载文件,并提供了多种选项来控制下载过程,包括解压、删除源文件、使用 curl 下载、多线程下载等。
download(url, dir=data_dir.parent)
# 生成下载成功的提示信息。
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" # 数据集下载成功✅ ({time.time() - t:.1f}s),保存至{colorstr('bold', data_dir)} 。
# 记录信息,提示下载成功。
LOGGER.info(s)
# 定义 训练集路径 。
train_set = data_dir / "train"
# 定义 验证集路径 ,优先使用 "val" 目录,如果不存在则使用 "validation" 目录,如果都不存在则为 None 。
val_set = (
data_dir / "val"
if (data_dir / "val").exists()
else data_dir / "validation"
if (data_dir / "validation").exists()
else None
) # data/test or data/val
# 定义 测试集路径 ,如果 "test" 目录存在则使用,否则为 None 。
test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test
# 如果 split 为 "val" 但验证集路径不存在。
if split == "val" and not val_set:
# 记录警告信息,提示使用测试集代替验证集。
LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") # 警告⚠️未找到数据集“split=val”,请改用“split=test”。
# 如果 split 为 "test" 但测试集路径不存在。
elif split == "test" and not test_set:
# 记录警告信息,提示使用验证集代替测试集。
LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.") # 警告⚠️未找到数据集“split=test”,请改用“split=val”。
# 计算训练集目录下的子目录数量,即 类别数量 。
nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes
# path.iterdir()
# path.iterdir() 是 Python pathlib 模块中 Path 类的一个方法,它用于遍历指定路径下的目录内容。这个方法返回一个迭代器,该迭代器产生路径下的所有文件和子目录的 Path 对象。
# 参数 :
# path :一个 Path 对象,表示你想要遍历的目录。
# 返回值 :
# 返回一个迭代器,产生路径下每个文件和子目录的 Path 对象。
# iterdir() 方法是处理文件系统时非常有用的工具,它提供了一种简洁的方式来访问目录内容,并且能够以面向对象的方式操作路径。
# 获取训练集目录下的子目录名称,即 类别名称列表 。
names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list
# 将 类别名称列表 排序并转换为字典,键为 类别索引 ,值为 类别名称 。
names = dict(enumerate(sorted(names)))
# Print to console
# 遍历 训练集 、 验证集 和 测试集 的 路径 。
for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
# 生成前缀字符串,包含数据集类型和路径。
prefix = f'{colorstr(f"{k}:")} {v}...'
# 如果路径为 None 。
if v is None:
# 记录信息,提示路径不存在。
LOGGER.info(prefix)
# 如果路径存在。
else:
# 查找路径下的 所有图像文件 。
files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
# 计算 图像文件数量 。
nf = len(files) # number of files
# 计算 包含图像文件的目录数量 。
nd = len({file.parent for file in files}) # number of directories
# 如果图像文件数量为0。
if nf == 0:
# 如果是训练集。
if k == "train":
# 抛出文件未找到错误,提示训练集图像文件未找到。
raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ ")) # {dataset} ‘{k}:’未找到训练图像❌。
# 如果不是训练集。
else:
# 记录警告信息,提示未找到图像文件。
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found") # {prefix} 在 {nd} 个类别中找到 {nf} 个图像:警告 ⚠️ 未找到图像。
# 如果目录数量与类别数量不匹配。
elif nd != nc:
# 记录警告信息,提示类别数量不匹配。
LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}") # {prefix} 在 {nd} 个类别中找到了 {nf} 个图像:错误 ❌️ 需要 {nc} 个类别,而不是 {nd} 。
# 如果图像文件数量和目录数量都正确。
else:
# 记录信息,提示图像文件和类别数量正确。
LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ") # {prefix} 在 {nd} 个类别中找到 {nf} 张图片✅。
# 返回一个字典,包含 训练集 、 验证集 、 测试集的路径 , 类别数量 和 类别名称 字典。
return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
# check_cls_dataset 函数通过一系列检查和处理步骤,确保分类数据集的完整性和可用性。它首先检查数据集是否为URL或压缩文件,并进行下载和解压。然后,它确定数据集的根目录路径,并检查训练集、验证集和测试集的路径是否存在。如果路径不存在,函数会尝试自动下载数据集。接着,函数计算类别数量和名称,并检查每个数据集分割中的图像文件数量和类别数量是否匹配。最后,函数打印相关信息到控制台,并返回一个包含数据集路径、类别数量和类别名称的字典。通过这些操作,函数为后续的模型训练和验证提供了可靠的数据支持。
13.class HUBDatasetStats:
# 这段代码定义了一个名为 HUBDatasetStats 的类,用于处理和统计Ultralytics HUB数据集的相关信息,包括数据集的基本信息、图像统计和图像压缩等操作。
# 定义了一个名为 HUBDatasetStats 的类。
class HUBDatasetStats:
# 用于生成 HUB 数据集 JSON 和 `-hub` 数据集目录的类。
"""
A class for generating HUB dataset JSON and `-hub` dataset directory.
Args:
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
Example:
Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
```python
from ultralytics.data.utils import HUBDatasetStats
stats = HUBDatasetStats('path/to/coco8.zip', task='detect') # detect dataset
stats = HUBDatasetStats('path/to/coco8-seg.zip', task='segment') # segment dataset
stats = HUBDatasetStats('path/to/coco8-pose.zip', task='pose') # pose dataset
stats = HUBDatasetStats('path/to/imagenet10.zip', task='classify') # classification dataset
stats.get_json(save=True)
stats.process_images()
```
"""
# 这段代码是 HUBDatasetStats 类的初始化方法 __init__ ,用于初始化类的实例并进行一系列数据集检查和处理操作。
# 定义类的初始化方法,接受三个参数。
# 1.path :数据集路径或配置文件路径,默认为 "coco8.yaml" 。
# 2.task :任务类型,可以是 "detect" 、 "segment" 、 "pose" 或 "classify" ,默认为 "detect" 。
# 3.autodownload :是否自动下载数据集,默认为 False 。
def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
"""Initialize class."""
# Path(path)
# 在Python中, Path 是 pathlib 模块中的一个类,用于表示文件系统路径。 pathlib 是一个现代的文件路径操作库,它提供了面向对象的方式来处理文件和目录路径。
# 导入 Path 类 : from pathlib import Path。
# 可以使用 Path 类来创建一个路径对象。这个对象可以是一个文件或者目录的路径。
# 路径操作 :
# Path 对象提供了许多方法来操作路径。
# p.exists() :检查路径是否存在。
# p.is_file() :检查路径是否指向一个文件。
# p.is_dir() :检查路径是否指向一个目录。
# p.resolve() :解析路径,返回绝对路径。
# p.parent :返回路径的父目录。
# p.name :返回路径的最后一部分(文件名)。
# p.suffix :返回文件的后缀名。
# p.stem :返回文件名不包括后缀的部分。
# 将 path 转换为 Path 对象并 解析为绝对路径 ,确保路径的正确性。
path = Path(path).resolve()
# 记录一条信息,提示开始对指定路径的数据集进行检查。
LOGGER.info(f"Starting HUB dataset checks for {path}....") # 正在启动 HUB 数据集检查 {path}....
# 将 任务类型 赋值给实例变量 self.task ,用于后续的条件判断和处理。
self.task = task # detect, segment, pose, classify
# 如果任务类型是 分类 ( "classify" )。
if self.task == "classify":
# 调用 unzip_file 函数解压数据集文件,返回 解压后的目录路径 赋值给 unzip_dir 。
# def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): -> 用于解压ZIP文件到指定路径,并提供了多种选项来控制解压过程。返回解压后的目录路径。 -> return path # return unzip dir
unzip_dir = unzip_file(path)
# 调用 check_cls_dataset 函数检查分类数据集,返回 数据集的相关信息字典 赋值给 data 。
# def check_cls_dataset(dataset, split=""): -> 用于检查和处理分类数据集,确保数据集的完整性和可用性,并返回数据集的相关信息。返回一个字典,包含 训练集 、 验证集 、 测试集的路径 , 类别数量 和 类别名称 字典。 -> return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
data = check_cls_dataset(unzip_dir)
# 将解压后的目录路径赋值给 data 字典中的 "path" 键,确保路径信息的正确性。
data["path"] = unzip_dir
# 如果任务类型是 检测 ( "detect" )、 分割 ( "segment" )或 姿态 ( "pose" )。
else: # detect, segment, pose
# 调用实例方法 _unzip 解压数据集文件,返回 解压状态 、 解压目录 和 YAML配置文件路径 。
_, data_dir, yaml_path = self._unzip(Path(path))
# 开始一个 try 块,用于捕获和处理可能的异常。
try:
# Load YAML with checks
# 调用 yaml_load 函数加载YAML配置文件,返回 配置数据字典 赋值给 data 。
# def yaml_load(file="data.yaml", append_filename=False): -> 用于加载YAML文件并返回其内容作为字典。返回解析后的字典。 -> return data
data = yaml_load(yaml_path)
# 将 data 字典中的 "path" 键值设置为空字符串, 确保YAML配置文件位于数据集根目录 。
data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets
# 调用 yaml_save 函数将更新后的配置数据保存回YAML文件。
# def yaml_save(file="data.yaml", data=None, header=""): -> 用于将数据以YAML格式保存到文件中。
yaml_save(yaml_path, data)
# 调用 check_det_dataset 函数检查检测/分割/姿态数据集,返回 数据集的相关信息字典 赋值给 data 。
# def check_det_dataset(dataset, autodownload=True): -> 用于检查目标检测数据集的完整性和正确性,并进行一些必要的处理。返回更新后的 data 字典,其中包含了数据集的配置信息和解析后的路径等。 -> return data # dictionary
data = check_det_dataset(yaml_path, autodownload) # dict
# 将解压 目录路径 赋值给 data 字典中的 "path" 键,确保路径信息的正确性。
data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute)
# 捕获 try 块中可能抛出的任何异常。
except Exception as e:
# 抛出一个新的异常,提示初始化过程中发生错误,并将原始异常作为上下文。
raise Exception("error/HUB/dataset_stats/init") from e
# 根据 data 字典中的 "path" 键值, 构建HUB数据集目录路径 赋值给实例变量 self.hub_dir 。
self.hub_dir = Path(f'{data["path"]}-hub')
# 构建HUB数据集中的 图像目录路径 赋值给实例变量 self.im_dir 。
self.im_dir = self.hub_dir / "images"
# 初始化实例变量 self.stats ,存储 数据集的类别数量 和 类别名称列表 。
self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary
# 将 数据集的相关信息字典 赋值给实例变量 self.data ,供后续方法使用。
self.data = data
# HUBDatasetStats 类的初始化方法 __init__ 通过一系列检查和处理步骤,确保数据集的完整性和可用性。它首先解析和检查数据集路径,根据任务类型调用相应的检查函数(分类任务调用 check_cls_dataset ,检测/分割/姿态任务调用 check_det_dataset ),并处理YAML配置文件。然后,它初始化实例变量,存储数据集的路径、类别信息和统计信息,为后续的数据集处理和统计操作提供基础。通过这些操作,类的实例能够有效地管理和处理Ultralytics HUB数据集。
# 这段代码定义了 HUBDatasetStats 类中的一个静态方法 _unzip ,用于解压数据集文件并返回相关信息。
@staticmethod
# 定义一个静态方法 _unzip ,接受一个参数。
# 1.path :数据集文件或配置文件的路径。
def _unzip(path):
"""Unzip data.zip."""
# 检查 path 是否以 .zip 结尾。
if not str(path).endswith(".zip"): # path is data.yaml
# 如果不是,假设 path 是 一个YAML配置文件路径 ,返回 False 、 None 和 原始路径 path 。这表示路径不是一个压缩文件,不需要解压。
return False, None, path
# 如果 path 是一个 .zip 文件,调用 unzip_file 函数解压该文件。解压目录设置为 path 的父目录,并将 解压后的目录路径 赋值给 unzip_dir 。
# def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): -> 用于解压ZIP文件到指定路径,并提供了多种选项来控制解压过程。返回解压后的目录路径。 -> return path # return unzip dir
unzip_dir = unzip_file(path, path=path.parent)
# 使用 assert 语句检查 unzip_dir 是否是一个目录。如果不是,抛出一个断言错误,提示解压失败,并说明正确的解压目录结构。例如, path/to/abc.zip 应该解压到 path/to/abc/ 。
assert unzip_dir.is_dir(), (
f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/" # 解压 {path} 时出错,未找到 {unzip_dir}。“f”path/to/abc.zip 必须解压至 path/to/abc/ 。
)
# 如果解压成功,返回三个值 : True 表示路径是一个压缩文件并已解压。 str(unzip_dir) 解压后的目录路径,转换为字符串。 find_dataset_yaml(unzip_dir) 在解压后的目录中查找YAML配置文件的路径。
# # def find_dataset_yaml(path: Path) -> Path: -> 用于在指定路径下查找YAML文件。返回 找到的YAML文件的路径 ,即 files 列表中的第一个元素。 -> return files[0]
return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path
# 静态方法 _unzip 用于处理数据集文件的解压操作。它首先检查输入路径是否为压缩文件,如果不是则直接返回。如果是压缩文件,它会解压该文件到指定目录,并确保解压后的目录存在。最后,它返回解压状态、解压目录路径和YAML配置文件路径,为后续的数据集检查和处理提供必要的信息。这个方法在 HUBDatasetStats 类的初始化过程中被调用,确保数据集文件的正确解压和路径的正确设置。
# 这段代码定义了 HUBDatasetStats 类中的一个实例方法 _hub_ops ,用于处理单个图像文件,将其压缩并保存到HUB数据集的图像目录中。
# 定义一个实例方法 _hub_ops ,接受一个参数。
# 1.f :图像文件的路径。
def _hub_ops(self, f):
# 保存压缩图像以供 HUB 预览。
"""Saves a compressed image for HUB previews."""
# 调用 compress_one_image 函数,将图像文件 f 压缩并保存到 self.im_dir 目录下。 self.im_dir 是HUB数据集的图像目录, Path(f).name 获取图像文件的名称,确保保存的文件名与原文件名一致。
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
# 方法 _hub_ops 用于处理单个图像文件,将其压缩并保存到HUB数据集的图像目录中。这个方法在 HUBDatasetStats 类中用于批量处理数据集中的图像文件,生成HUB预览所需的压缩图像。通过这种方式,可以减少图像文件的大小,提高数据集在HUB上的加载速度。这个方法通常在多线程或异步环境中调用,以提高处理效率。
# 这段代码定义了 HUBDatasetStats 类中的一个实例方法 get_json ,用于生成和返回Ultralytics HUB所需的数据集JSON统计信息。这个方法还支持将统计信息保存到文件,并可以选择性地打印详细信息。
# 定义实例方法 get_json ,接受两个参数。
# 1.save :布尔值,表示是否将统计信息保存到文件,默认为 False 。
# 2.verbose :布尔值,表示是否打印详细信息,默认为 False 。
def get_json(self, save=False, verbose=False):
# 返回 Ultralytics HUB 的数据集 JSON。
"""Return dataset JSON for Ultralytics HUB."""
# 这段代码定义了一个内部函数 _round ,用于将标签数据转换为整数类别和四位小数的浮点数。这个函数在 HUBDatasetStats 类的 get_json 方法中被调用,用于处理数据集的标签信息,确保标签格式符合Ultralytics HUB的要求。
# 定义内部函数 _round ,接受一个参数。
# 1.labels :标签数据字典。
def _round(labels):
# 将标签更新为整数类和 4 位小数浮点数。
"""Update labels to integer class and 4 decimal place floats."""
# 如果任务类型是检测( "detect" ),则从 labels 字典中获取 "bboxes" 键对应的 边界框坐标 。
if self.task == "detect":
coordinates = labels["bboxes"]
# 如果任务类型是分割( "segment" ),则从 labels 字典中获取 "segments" 键对应的 分割掩码 ,并将其展平为一维数组。
elif self.task == "segment":
# np.ndarray.flatten(order='C')
# np.flatten() 是 NumPy 提供的一个方法,用于将多维数组(ndarray)转换为一维数组。与 ravel() 方法不同, flatten() 总是返回一个新数组,即原始数据的副本,而不是视图。这意味着对返回数组的修改不会影响原始数组。
# 参数说明 :
# order :(可选)一个字符串,指定数组的遍历顺序。'C' 表示按行优先顺序(C语言风格),这是默认值。'F' 表示按列优先顺序(Fortran风格)。'A' 表示按数组的原始布局顺序,如果数组不是F顺序,则与'C'相同,如果是F顺序,则与'F'相同。
# 返回值 :
# 返回一个新的一维数组,它是原始多维数组数据的副本。
coordinates = [x.flatten() for x in labels["segments"]]
# 如果任务类型是姿态估计( "pose" ),则从 labels 字典中获取 "bboxes" 和 "keypoints" 键对应的 边界框 和 关键点坐标 ,并将它们拼接在一起。 labels["keypoints"].reshape(n, -1) 将关键点坐标展平为二维数组,确保 每个关键点的坐标在一行中 。
elif self.task == "pose":
n = labels["keypoints"].shape[0]
coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
# 如果任务类型不是上述三种之一,抛出一个 ValueError 异常,提示任务类型未定义。
else:
raise ValueError("Undefined dataset task.") # 未定义的数据集任务。
# 将 类别标签 labels["cls"] 和 坐标 coordinates 打包成一个迭代器,每个元素是一个元组,包含类别标签和对应的坐标。
zipped = zip(labels["cls"], coordinates)
# 遍历 zipped 迭代器,将每个类别标签转换为整数,并将每个坐标值转换为四位小数的浮点数。最终返回一个列表,每个元素是一个列表,包含 整数类别标签 和 四位小数的坐标值 。
return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
# 函数 _round 用于处理数据集的标签信息,确保标签格式符合Ultralytics HUB的要求。它根据任务类型处理不同的标签数据,并将标签转换为整数类别和四位小数的浮点数。这个函数在 HUBDatasetStats 类的 get_json 方法中被调用,用于生成数据集的JSON统计信息。通过这种方式,可以确保数据集的标签信息在HUB上的一致性和准确性。
# 这段代码是 HUBDatasetStats 类中 get_json 方法的一部分,用于遍历数据集的各个分割( "train" 、 "val" 、 "test" ),检查每个分割是否存在,并获取其中的图像文件。
# 遍历三个字符串 "train" 、 "val" 和 "test" ,分别代表 训练集 、 验证集 和 测试集 。
for split in "train", "val", "test":
# 在遍历每个分割之前,将 self.stats 字典中对应分割的值初始化为 None 。这一步是为了 预定义统计信息 ,确保在后续处理中不会出现键不存在的错误。
self.stats[split] = None # predefine
# 从 self.data 字典中 获取当前分割的路径 。 self.data 字典包含数据集的配置信息,其中每个分割的路径存储在对应的键中( "train" 、 "val" 、 "test" )。
path = self.data.get(split)
# Check split
# 检查当前分割的路径是否为 None 。如果路径为 None ,说明该分割不存在,跳过当前迭代,继续处理下一个分割。
if path is None: # no split
continue
# 使用 Path 对象的 rglob 方法递归查找当前分割路径下的所有文件,并过滤出图像文件。 IMG_FORMATS 是一个包含图像文件后缀的列表(如 ["jpg", "jpeg", "png", "bmp"] ), f.suffix[1:].lower() 获取文件后缀并转换为小写,确保文件后缀匹配。
files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split
# 检查当前分割路径下是否找到图像文件。
if not files: # no images
# 如果图像文件列表为空,说明该分割中没有图像文件,跳过当前迭代,继续处理下一个分割。
continue
# 这段代码通过遍历数据集的各个分割,检查每个分割是否存在,并获取其中的图像文件。如果分割路径不存在或其中没有图像文件,跳过当前分割,继续处理下一个分割。这一步确保了后续处理的图像文件列表是有效的,为生成数据集的统计信息提供了基础。
# 这段代码是 HUBDatasetStats 类中 get_json 方法的一部分,用于处理分类任务( "classify" )的数据集统计信息。
# Get dataset statistics
# 检查当前任务是否为分类任务( "classify" )。如果是,执行以下步骤。
if self.task == "classify":
# 从 torchvision.datasets 模块导入 ImageFolder 类,用于 加载和处理图像数据集 。
from torchvision.datasets import ImageFolder
# torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=None, is_valid_file=None)
# torchvision.datasets.ImageFolder 是 PyTorch 的一个类,它提供了一种方便的方式来加载结构化存储的图像数据集。这种结构化存储意味着图像被组织在不同的文件夹中,每个文件夹的名称对应一个类别。
# 参数 :
# root :数据集的根目录路径,其中包含所有类别的子文件夹。
# transform :一个可选的函数或可调用对象,用于对图像进行预处理或数据增强。它在图像加载后、返回前应用于图像。
# target_transform :一个可选的函数或可调用对象,用于对标签进行预处理。它在标签加载后、返回前应用于标签。
# loader :一个函数,用于加载图像文件。默认情况下,使用 PIL 库加载图像。
# is_valid_file :一个函数,用于检查文件名是否有效。如果提供,它将被用于过滤文件。
# 返回值 :
# 返回一个 ImageFolder 实例,该实例包含图像数据集的加载和预处理逻辑。
# ImageFolder 类是 PyTorch 中处理图像分类任务时常用的工具之一,它简化了数据加载和预处理的过程,使得用户可以专注于模型的训练和评估。
# torchvision.datasets.ImageFolder 类的实例通常包含以下常见的属性 :
# root : 字符串,表示数据集的根目录路径。
# samples : 列表,包含数据集中所有图像的元组信息,通常每个元组包含图像的路径和对应的标签索引。
# classes : 列表,包含数据集中所有类别的名称,顺序与 samples 中的标签索引相对应。
# class_to_idx : 字典,映射类别名称到它们在 classes 列表中的索引。
# imgs : 列表,与 samples 类似,包含图像的路径和标签,但在某些版本的 torchvision 中可能不直接提供。
# targets : 列表,包含与 imgs 列表相对应的标签。
# transform : 函数或可调用对象,用于对图像进行预处理或数据增强,如果提供了 transform 参数,则在加载图像时应用。
# target_transform : 函数或可调用对象,用于对标签进行预处理,如果提供了 target_transform 参数,则在加载标签时应用。
# loader : 函数,用于加载图像文件,默认使用 PIL 库。
# is_valid_file : 函数,用于检查文件名是否有效,如果提供了 is_valid_file 参数,则在加载图像时用于过滤文件。
# 这些属性使得 ImageFolder 实例能够方便地访问和操作图像数据集,同时提供了灵活的预处理和数据加载选项。通过这些属性,用户可以轻松地对数据集进行迭代、应用变换、加载图像和标签等操作。
# 使用 ImageFolder 类创建一个数据集实例,传入 当前分割的路径 self.data[split] 。 ImageFolder 会自动加载路径下的所有图像文件,并根据文件夹结构生成类别标签。
dataset = ImageFolder(self.data[split])
# 初始化一个长度为类别数量的零数组 x ,用于 统计每个类别的图像数量 。 dataset.classes 是一个包含所有类别名称的列表。
x = np.zeros(len(dataset.classes)).astype(int)
# 遍历 dataset.imgs , dataset.imgs 是一个列表,每个元素是一个元组,包含 图像路径和类别索引 。 im[1] 是类别索引, x[im[1]] += 1 将对应类别的计数加1。
for im in dataset.imgs:
x[im[1]] += 1
# 更新统计信息。
# 将统计信息更新到 self.stats 字典中,键为当前分割名称( "train" 、 "val" 、 "test" )。
self.stats[split] = {
# 包含 总图像数量 和 每个类别的图像数量 。
"instance_stats": {"total": len(dataset), "per_class": x.tolist()},
# 包含 总图像数量 、 未标记图像数量 (这里假设为0)和 每个类别的图像数量 。
"image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
# 包含 每个图像的路径 和 类别索引 ,格式为 {图像文件名: 类别索引} 。
"labels": [{Path(k).name: v} for k, v in dataset.imgs],
}
# 这段代码通过 ImageFolder 类加载和处理分类任务的数据集,统计每个类别的图像数量,并将统计信息更新到 self.stats 字典中。这一步确保了分类任务的数据集统计信息的准确性和完整性,为生成数据集的JSON统计信息提供了基础。
# 这段代码是 HUBDatasetStats 类中 get_json 方法的一部分,用于处理检测( "detect" )、分割( "segment" )和姿态估计( "pose" )任务的数据集统计信息。
# 如果当前任务 不是分类任务 ( "classify" ),则处理检测、分割和姿态估计任务。
else:
# 从 ultralytics.data 模块导入 YOLODataset 类,用于加载和处理图像数据集。
from ultralytics.data import YOLODataset
# 使用 YOLODataset 类创建一个数据集实例,传入当前分割的图像路径 self.data[split] 、数据集配置 self.data 和任务类型 self.task 。
# class YOLODataset(BaseDataset):
# -> 用于处理YOLO模型的数据集。
# -> def __init__(self, *args, data=None, task="detect", **kwargs):
dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
# 统计每个类别的图像数量。
# np.array 将这个列表转换为一个二维数组。假设数据集有128张图像,每个图像有80个类别,最终结果是一个形状为 (128 x 80) 的二维数组。
# 示例 :
# 假设数据集有3张图像,每个图像有3个类别(类别0、类别1、类别2):
# 图像1的类别标签 : [0, 1, 1] 。
# 图像2的类别标签 : [2, 2, 2] 。
# 图像3的类别标签 : [0, 1, 2]。
# 计算每个图像的类别计数数组 :
# 图像1 : [1, 2, 0] (类别0出现1次,类别1出现2次,类别2出现0次) 。
# 图像2 : [0, 0, 3] (类别0出现0次,类别1出现0次,类别2出现3次) 。
# 图像3 : [1, 1, 1] (类别0出现1次,类别1出现1次,类别2出现1次) 。
# 最终结果是一个二维数组 :
# [
# [1, 2, 0],
# [0, 0, 3],
# [1, 1, 1]
# ]
# 形状为 (3 x 3) ,表示3张图像和3个类别的计数。
x = np.array(
[
# np.bincount(x, minlength=None)
# np.bincount 是 NumPy 库中的一个函数,它用于计算非负整数数组中每个值的出现次数。
# 参数 :
# x :输入数组,其中的元素必须是非负整数。
# minlength (可选) :输出数组的最小长度。如果提供,数组 x 中小于 minlength 的值将被忽略,而 x 中等于或大于 minlength 的值将导致数组被扩展以包含这些值。如果未提供或为 None ,则输出数组的长度将与 x 中的最大值加一相匹配。
# 返回值 :
# 返回一个数组,其中第 i 个元素代表输入数组 x 中值 i 出现的次数。
# 功能 :
# np.bincount 函数对输入数组 x 中的每个值进行计数,返回一个一维数组,其长度至少与 x 中的最大值一样大。
# 如果 x 中的某个值没有出现,那么在返回的数组中对应的位置将为 0。
# 例:
# x = np.array([1, 2, 3, 3, 0, 1, 4])
# np.bincount(x)
# '''array([1, 2, 1, 2, 1], dtype=int64)'''
# 输出 : [1 2 1 2 1]。统计索引出现次数:索引0出现1次,1出现2次,2出现1次,3出现2次,4出现1次。
# label["cls"].astype(int).flatten() 将类别标签展平为一维数组, np.bincount 计算每个类别的出现次数, minlength=self.data["nc"] 确保数组长度为类别数量。最终结果是一个二维数组 x ,形状为 (图像数量 x 类别数量) 。
np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
# 使用 TQDM 进度条 遍历数据集中的每个标签 ,计算每个图像中每个类别的出现次数。
for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
]
) # shape(128x80)
# 更新统计信息。
# 将统计信息更新到 self.stats 字典中,键为当前分割名称( "train" 、 "val" 、 "test" )。
self.stats[split] = {
# 包含总实例数量和每个类别的实例数量。
"instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
# 包含总图像数量、未标记图像数量和每个类别的图像数量。
"image_stats": {
"total": len(dataset),
"unlabelled": int(np.all(x == 0, 1).sum()),
"per_class": (x > 0).sum(0).tolist(),
},
# 包含每个图像的路径和处理后的标签信息,格式为 {图像文件名: 处理后的标签} 。
"labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
}
# 这段代码通过 YOLODataset 类加载和处理检测、分割和姿态估计任务的数据集,统计每个类别的图像数量,并将统计信息更新到 self.stats 字典中。这一步确保了检测、分割和姿态估计任务的数据集统计信息的准确性和完整性,为生成数据集的JSON统计信息提供了基础。通过这种方式,可以确保数据集的标签信息在HUB上的一致性和准确性。
# 这段代码是 HUBDatasetStats 类中 get_json 方法的一部分,用于保存、打印和返回数据集的统计信息。
# Save, print and return
# 检查参数 save 是否为 True 。如果为 True ,则执行保存操作。
if save:
# 使用 mkdir 方法创建HUB数据集目录 self.hub_dir 。 parents=True 表示如果需要,会创建所有父目录。 exist_ok=True 表示如果目录已存在,不会抛出异常。
self.hub_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/
# 构建 统计信息文件的路径 ,文件名为 "stats.json" ,存储在 self.hub_dir 目录下。
stats_path = self.hub_dir / "stats.json"
# 记录一条信息,提示正在保存统计信息文件的路径。 stats_path.resolve() 返回路径的绝对路径。
LOGGER.info(f"Saving {stats_path.resolve()}...") # 正在保存 {stats_path.resolve()}...
# 使用 with 语句打开文件 stats_path ,以写入模式( "w" )打开。使用 json.dump 将 self.stats 字典写入文件,保存为JSON格式。
with open(stats_path, "w") as f:
json.dump(self.stats, f) # save stats.json
# 检查参数 verbose 是否为 True 。如果为 True ,则执行打印操作。
if verbose:
# 记录一条信息,打印 self.stats 字典的JSON格式字符串。 indent=2 表示缩进为2个空格, sort_keys=False 表示不排序键。
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
# 返回 self.stats 字典,包含数据集的统计信息。
return self.stats
# 这段代码通过检查参数 save 和 verbose ,决定是否保存和打印数据集的统计信息。如果 save 为 True ,则创建HUB数据集目录并保存统计信息到 stats.json 文件。如果 verbose 为 True ,则打印统计信息的JSON格式字符串。最后,返回包含数据集统计信息的字典。这一步确保了统计信息的持久化和可视化,便于后续的数据分析和验证。
# 方法 get_json 通过一系列检查和处理步骤,生成Ultralytics HUB所需的数据集JSON统计信息。它首先遍历数据集的各个分割(训练集、验证集、测试集),根据任务类型处理标签和图像文件,计算统计信息并更新 self.stats 。然后,根据参数 save 和 verbose ,将统计信息保存到文件并打印详细信息。最后,返回生成的统计信息字典。这个方法为HUB数据集的管理和分析提供了重要的支持。
# 这段代码定义了 HUBDatasetStats 类中的一个实例方法 process_images ,用于压缩数据集中的图像并保存到Ultralytics HUB的图像目录中。
# 定义实例方法 process_images ,不接受额外参数。
def process_images(self):
# 压缩 Ultralytics HUB 的图像。
"""Compress images for Ultralytics HUB."""
# 从 ultralytics.data 模块导入 YOLODataset 类,用于加载和处理图像数据集。
from ultralytics.data import YOLODataset # ClassificationDataset
# 使用 mkdir 方法创建HUB数据集的图像目录 self.im_dir 。 parents=True 表示如果需要,会创建所有父目录。 exist_ok=True 表示如果目录已存在,不会抛出异常。
self.im_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/images/
# 遍历三个字符串 "train" 、 "val" 和 "test" ,分别代表 训练集 、 验证集 和 测试集 。
for split in "train", "val", "test":
# 检查当前分割的路径是否为 None 。如果路径为 None ,说明该分割不存在,跳过当前迭代,继续处理下一个分割。
if self.data.get(split) is None:
continue
# 使用 YOLODataset 类创建一个 数据集实例 ,传入 当前分割的图像路径 self.data[split] 和 数据集配置 self.data 。
dataset = YOLODataset(img_path=self.data[split], data=self.data)
# 使用 ThreadPool 创建一个线程池, NUM_THREADS 是线程池的线程数量。
with ThreadPool(NUM_THREADS) as pool:
# pool.imap 方法用于并行处理 dataset.im_files 中的每个图像文件,调用 self._hub_ops 方法进行压缩操作。 TQDM 用于显示进度条, total=len(dataset) 设置进度条的总长度, desc=f"{split} images" 设置进度条的描述信息。
for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
pass
# 记录一条信息,提示所有图像已压缩并保存到 self.im_dir 目录。
LOGGER.info(f"Done. All images saved to {self.im_dir}") # 完成。所有图像已保存至 {self.im_dir} 。
# 返回HUB数据集的 图像目录路径 self.im_dir 。
return self.im_dir
# 方法 process_images 通过并行处理,压缩数据集中的图像并保存到HUB数据集的图像目录中。这一步确保了图像文件的压缩和存储,便于在Ultralytics HUB上进行高效的加载和展示。通过使用线程池和进度条,提高了处理效率并提供了可视化反馈。
# HUBDatasetStats 类是一个用于处理和统计Ultralytics HUB数据集信息的工具类。它支持多种任务类型,包括目标检测、分割、姿态估计和分类。类的初始化方法 __init__ 负责检查和准备数据集,包括解压数据集文件、加载YAML配置文件,并初始化数据集的路径和统计信息。 get_json 方法用于生成数据集的JSON统计信息,支持保存到文件和打印详细信息。 process_images 方法则用于压缩数据集中的图像并保存到HUB数据集的图像目录中,提高数据集在HUB上的加载速度。通过这些方法, HUBDatasetStats 类为数据集的管理和分析提供了全面的支持,确保数据集的完整性和可用性。
14.def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
# 这段代码定义了一个名为 compress_one_image 的函数,用于压缩单个图像文件。
# 定义函数 compress_one_image ,接受四个参数。
# 1.f :原始图像文件的路径。
# 2.f_new :压缩后图像文件的路径,默认为 None ,表示覆盖原始文件。
# 3.max_dim :图像的最大维度,默认为1920像素。
# 4.quality :JPEG图像的压缩质量,默认为50。
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
# 使用 Python 图像库 (PIL) 或 OpenCV 库将单个图像文件压缩为较小尺寸,同时保留其纵横比和质量。如果输入图像小于最大尺寸,则不会调整其大小。
"""
Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python
Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
resized.
Args:
f (str): The path to the input image file.
f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
quality (int, optional): The image compression quality as a percentage. Default is 50%.
Example:
```python
from pathlib import Path
from ultralytics.data.utils import compress_one_image
for f in Path('path/to/dataset').rglob('*.jpg'):
compress_one_image(f)
```
"""
# 尝试使用PIL(Python Imaging Library)进行图像压缩。
try: # use PIL
# 使用PIL的 Image.open 方法打开图像文件 f 。
im = Image.open(f)
# 计算图像的最大维度与指定最大维度 max_dim 的比值 r 。
r = max_dim / max(im.height, im.width) # ratio
# 如果图像的尺寸超过 max_dim ,则按比例缩小图像尺寸。
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
# 将压缩后的图像保存到指定路径 f_new 或原始路径 f ,使用JPEG格式,指定压缩质量 quality ,并进行优化。
im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save
# 如果PIL操作失败,捕获异常并使用OpenCV进行图像压缩。
except Exception as e: # use OpenCV
# 记录一条警告信息,提示PIL操作失败的原因。
LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}") # 警告 ⚠️ HUB 操作 PIL 失败 {f}:{e} 。
# 使用OpenCV的 cv2.imread 方法读取图像文件 f 。
im = cv2.imread(f)
# 获取图像的高度和宽度。
im_height, im_width = im.shape[:2]
# 计算图像的最大维度与指定最大维度 max_dim 的比值 r 。
r = max_dim / max(im_height, im_width) # ratio
# 如果图像的尺寸超过 max_dim ,则按比例缩小图像尺寸,使用 cv2.INTER_AREA 插值方法。
if r < 1.0: # image too large
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
# 将压缩后的图像保存到指定路径 f_new 或原始路径 f 。
cv2.imwrite(str(f_new or f), im)
# compress_one_image 函数用于压缩单个图像文件,确保图像的尺寸不超过指定的最大维度 max_dim ,并使用指定的JPEG质量进行压缩。函数首先尝试使用PIL进行压缩,如果PIL操作失败,则使用OpenCV进行压缩。通过这种方式,函数提供了两种备用方案,确保图像压缩操作的可靠性和健壮性。这使得函数在不同的环境和情况下都能正常工作,适用于Ultralytics HUB数据集的图像处理。
15.def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
# 这段代码定义了一个名为 autosplit 的函数,用于自动将图像数据集分割为训练集、验证集和测试集。
# 定义函数 autosplit ,接受三个参数。
# 1.path :图像文件夹的路径,默认为 DATASETS_DIR / "coco8/images" 。
# 2.weights :一个元组,表示训练集、验证集和测试集的权重,默认为 (0.9, 0.1, 0.0) 。
# 3.annotated_only :一个布尔值,表示是否只使用已标注的图像,默认为 False 。
def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
# 自动将数据集拆分为训练/验证/测试拆分并将结果拆分保存到 autosplit_*.txt 文件中。
"""
Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
Args:
path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
Example:
```python
from ultralytics.data.utils import autosplit
autosplit()
```
"""
# 将 path 转换为 Path 对象,便于路径操作。
path = Path(path) # images dir
# 使用 rglob 方法递归查找路径下的所有文件,并过滤出图像文件。 IMG_FORMATS 是一个包含图像文件后缀的列表(如 ["jpg", "jpeg", "png", "bmp"] ), x.suffix[1:].lower() 获取文件后缀并转换为小写,确保文件后缀匹配。结果按字母顺序排序。
files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only
# 计算图像文件的数量。
n = len(files) # number of files
# 设置随机种子为0,确保每次运行结果的可重复性。
random.seed(0) # for reproducibility
# random.choices(population, weights=None, cum_weights=None, k=1)
# random.choices 是 Python 标准库 random 模块中的一个函数,它用于从给定的序列中随机选择元素,可以有放回地选择多次,也可以指定每个元素被选择的概率。
# population :一个序列,表示可供选择的元素集合。
# weights :(可选)一个与 population 中元素数量相同的序列,表示每个元素被选择的相对概率。
# cum_weights :(可选)一个与 population 中元素数量相同的序列,表示累积概率。这在创建加权随机选择时非常有用,其中每个元素的选择概率取决于前面所有元素的累积概率。
# k :(可选)一个整数,表示需要选择的元素数量。
# 返回值 :
# 返回一个列表,包含从 population 中随机选择的 k 个元素。
# 功能 :
# random.choices 函数允许你从 population 中选择多个元素,可以是有放回的选择。如果提供了 weights 或 cum_weights ,则选择会根据这些权重进行加权,使得某些元素有更高的概率被选中。
# 注意事项 :
# 如果同时提供了 weights 和 cum_weights , random.choices 将使用 weights 而忽略 cum_weights 。
# k 的值可以大于 population 的长度,这意味着可能会选择重复的元素。
# 如果没有提供 weights 或 cum_weights ,则所有元素被选择的概率相等。
# 使用 random.choices 方法根据权重 weights 随机分配每个图像到训练集、验证集或测试集。 k=n 表示生成的随机索引数量与图像文件数量相同。
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
# 定义三个文本文件名,分别用于存储 训练集 、 验证集 和 测试集 的图像路径。
txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files
# 遍历文本文件名列表,如果文件已存在,则删除。
for x in txt:
if (path.parent / x).exists():
(path.parent / x).unlink() # remove existing
# 记录一条信息,提示正在自动分割图像,并根据 annotated_only 的值决定是否只使用已标注的图像。
LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only) # 自动分割来自 {path}" + " 的图像,仅使用 *.txt 标签图像" * annotated_only 。
# 使用 TQDM 进度条遍历 每个图像文件 及 其对应的分割索引 。 zip(indices, files) 将图像文件和分割索引配对, total=n 设置进度条的总长度为图像文件的数量。
for i, img in TQDM(zip(indices, files), total=n):
# 检查是否需要只使用已标注的图像( annotated_only 为 True )。如果 annotated_only 为 False ,则直接通过。如果 annotated_only 为 True ,则调用 img2label_paths 函数获取图像对应的标签文件路径,并检查该标签文件是否存在。 Path(img2label_paths([str(img)])[0]).exists() 返回 True 表示标签文件存在。
# def img2label_paths(img_paths): -> 将图像文件路径转换为对应的标签文件路径。使用列表推导式,将每个 图像文件路径 转换为 对应的标签文件路径 。 -> return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
# 使用 with 语句打开文本文件,以追加模式( "a" )打开。 path.parent / txt[i] 表示文本文件的路径, txt[i] 根据分割索引选择对应的文本文件名( "autosplit_train.txt" 、 "autosplit_val.txt" 、 "autosplit_test.txt" )。
with open(path.parent / txt[i], "a") as f:
# 图像路径写入文本文件。 img.relative_to(path.parent).as_posix() 获取图像路径相对于父目录的相对路径,并转换为POSIX路径格式。 f"./{...}" + "\n" 将路径前缀为 "./" ,并添加换行符,确保每行一个路径。
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
# autosplit 函数用于自动将图像数据集分割为训练集、验证集和测试集。它首先获取图像文件列表,根据权重随机分配每个图像到不同的分割,并将结果写入对应的文本文件。如果 annotated_only 为 True ,则只使用已标注的图像。通过这种方式,函数为数据集的分割提供了自动化和灵活的解决方案,确保数据集的合理分配和使用。