1 数据集
数据集为10分类问题
2 数据集可视化代码
注意事项: cv2.imread()函数中的路径不能包括中文,否则无法正常读取。
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import math
import os
import cv2
from tqdm import tqdm
from PIL import Image
# folder_path = './/data//小波时频图//train_img//正常'
folder_path = "C:/Users/Administrator/Desktop/ResNet/data/xiaobo/train_img/Normal"
# 可视化图像的个数
N = 36
n = math.floor(np.sqrt(N))
images = []
for each_img in os.listdir(folder_path)[:N]:
img_path = os.path.join(folder_path, each_img)
img_bgr = cv2.imread(img_path)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
images.append(img_rgb)
# print(len(images))
fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(fig, 111, # 类似绘制子图 subplot(111)
nrows_ncols=(n, n), # 创建 n 行 m 列的 axes 网格
axes_pad=0.02, # 网格间距
share_all=True
)
# 遍历每张图像
for ax, im in zip(grid, images):
ax.imshow(im)
ax.axis('off')
数据集可视化6*6图片
参考:https://www.bilibili.com/video/BV1Jd4y1T7rw/?spm_id_from=333.880.my_history.page.click&vd_source=47e66af6a90e9c41c341fd3c692ced14