基于PyTorch的ResNet垃圾图片分类
数据集预处理
画图片的宽高分布散点图
import os
import matplotlib.pyplot as plt
import PIL.Image as Image
def plot_resolution(dataset_root_path):
    """Show a scatter plot of image width versus height for a dataset.

    Recursively walks ``dataset_root_path``, reads the pixel size of every
    file with PIL, prints the collected ``(width, height)`` tuples and then
    displays a scatter plot of them.

    Args:
        dataset_root_path: Root directory of the image dataset.
    """
    image_size_list = []  # (width, height) of every readable image
    for root, _dirs, files in os.walk(dataset_root_path):
        for file in files:
            image_full_path = os.path.join(root, file)
            try:
                # Image.open() is lazy and keeps the file handle open; the
                # context manager closes it so large datasets do not
                # exhaust the process's file descriptors.
                with Image.open(image_full_path) as image:
                    image_size_list.append(image.size)
            except OSError as err:
                # Raised for non-image or unreadable files (thumbnail
                # caches, label files, ...). Report and keep scanning
                # instead of aborting the whole run.
                print("skipped", image_full_path, "-", err)
    print(image_size_list)

    image_width_list = [width for width, _height in image_size_list]
    image_height_list = [height for _width, height in image_size_list]

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render the Chinese labels
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign from rendering as a garbled glyph
    plt.scatter(image_width_list, image_height_list, s=1)
    plt.xlabel('宽')
    plt.ylabel('高')
    plt.title('图像宽高分布散点图')
    plt.show()
if __name__ == '__main__':
    # Raw string: in a normal literal the backslashes before the Chinese
    # characters form invalid escape sequences, which newer Python versions
    # warn about. The resulting path value is unchanged.
    dataset_root_path = r"F:\细粒度识别项目\清洁用品"
    plot_resolution(dataset_root_path)
运行结果:
注: os.walk 会自上而下递归遍历目录树,每次迭代返回一个 (root, dirs, files) 三元组,分别是当前目录路径、其下的子目录名列表和文件名列表;详细解释参考 Python 官方文档中 os.walk 的说明。
画出数据集的各个类别图片数量的条形图
文件组织结构: 数据集根目录下每个类别对应一个子文件夹,图片直接存放在各自的类别文件夹中(根目录本身不含图片)。
def _count_files_per_class(dataset_root_path):
    """Return ``(class_names, file_counts)`` for the root's immediate sub-directories."""
    # The first os.walk() entry describes the root itself; its directory
    # list is the set of class folders. A missing root yields no entries.
    _root, class_names, _files = next(
        os.walk(dataset_root_path), (dataset_root_path, [], [])
    )
    # Count per class folder (recursively) so that names and counts always
    # stay aligned, even when a class folder contains nested directories.
    file_counts = [
        sum(len(files) for _, _, files in os.walk(os.path.join(dataset_root_path, name)))
        for name in class_names
    ]
    return class_names, file_counts


def plot_bar(dataset_root_path):
    """Show a bar chart of the number of files in each class folder.

    Every immediate sub-directory of ``dataset_root_path`` is treated as one
    class. The mean count is printed and drawn as a red line over the bars.
    Files placed directly in the root directory are not counted.

    Args:
        dataset_root_path: Root directory containing one folder per class.
    """
    file_name_list, file_num_list = _count_files_per_class(dataset_root_path)
    if not file_name_list:
        # Nothing to plot, and a mean over zero classes is undefined.
        print("no class folders found under", dataset_root_path)
        return

    mean = sum(file_num_list) / len(file_num_list)
    print("mean= ", mean)

    # Configure fonts before the figure is created so every text element
    # (ticks, labels, title) is built with the same settings.
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render the Chinese labels
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign from rendering as a garbled glyph

    bar_positions = list(range(len(file_name_list)))
    _fig, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)  # x positions, bar heights, bar width
    ax.plot(bar_positions, [mean] * len(bar_positions), color="red")  # mean line
    ax.set_xticks(bar_positions)
    # NOTE(review): rotation=98 is kept from the original; 90 may have been intended.
    ax.set_xticklabels(file_name_list, rotation=98)
    ax.set_ylabel("类别数量")
    # The title previously called this chart a scatter plot; it is a bar chart.
    ax.set_title("各个类别数量分布条形图")
    plt.show()