1. 导入资源包
import torch.nn as nn
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk,ImageDraw,ImageFont
import torch
from torchvision import transforms, models
from efficientnet_pytorch import EfficientNet
import numpy as np
import cv2
import tkinter.ttk as ttk
import time
from tkinter import filedialog
import subprocess
import os
import re
import threading
注:
1. torch.nn:这是PyTorch的一个模块,用于构建神经网络。
2. tkinte:Python的标准GUI库,用于创建图形用户界面。
3. PIL(Pillow):一个图像处理库,用于打开、处理和保存多种不同格式的图像。
4. torch:PyTorch的核心库,用于构建和训练神经网络。
5. torchvision:包含图像处理、数据集和预训练模型等工具的库。
6. efficientnet_pytorch:一个提供EfficientNet模型的PyTorch实现。
7. numpy:一个强大的数学库,用于进行科学计算。
8. cv2:OpenCV库的Python接口,用于计算机视觉任务。
9. tkinter.ttk:tkinter的一个模块,提供了改进的Tk控件。
10. threading:Python的标准线程库,用于创建和管理线程。
2. 创建EfficientNet 的模型
2.1. 加载预训练模型
class EfficientNetModel(nn.Module):
def __init__(self, num_classes=10, pretrained=True):
super(EfficientNetModel, self).__init__()
# 加载预训练的EfficientNet模型
self.efficientnet = EfficientNet.from_name('efficientnet-b3')
num_ftrs = self.efficientnet._fc.in_features
# 将EfficientNet模型的最后一层全连接层替换为一个新的全连接层,输出特征数量设置为num_classes
self.efficientnet._fc = nn.Linear(num_ftrs, num_classes)
注:类继承自 torch.nn.Module,用于创建一个基于 EfficientNet 的模型,可以用于图像分类任务。这个类接受两个参数:num_classes 表示输出分类的数量,pretrained 表示是否使用预训练的模型权重。
2.2. 向前传播
def forward(self, x):
return self.efficientnet(x)
注:提供的代码片段是Python编程语言中的一个方法定义,这个方法forward是许多深度学习框架中神经网络模型的一个标准部分。特别是在PyTorch框架中,forward方法定义了数据通过网络的前向传播方式。在这个特定的例子中,方法接受一个参数x,这通常代表输入数据,然后通过一个名为efficientnet的层或模块进行处理,并返回结果。
2.3. 定义EfficientNet类
var foo = 'bar';
注:提供的代码片段是JavaScript语言中的一个变量声明和赋值语句。这个语句创建了一个名为foo的变量,并将字符串’bar’赋值给它。在JavaScript中,使用var关键字声明的变量拥有函数作用域或全局作用域,这意味着变量foo可以在声明它的函数内部或者全局环境中被访问和修改。
3. 加载训练好的模型参数
model_path = 'best_EfficientNet_b3_updata6.pth'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()
注:
1. model_path 变量指定了模型权重文件的路径,这个文件通常是一个 PyTorch 模型的状态字典,它包含了模型的参数。
2. torch.load() 函数用于加载模型权重文件。
3. map_location=torch.device(‘cpu’) 参数指定了加载模型权重时使用的设备。在这里,它被设置为 CPU。这意味着即使模型是在 GPU 上训练的,加载时也会将其参数移动到 CPU 上。
4. model.load_state_dict() 方法用于将加载的状态字典加载到模型中,这样模型就有了训练好的参数。
5. model.eval() 方法将模型设置为评估模式。这个方法对模型没有什么影响,但是如果模型中有定义为仅在训练模式下使用的层(如 Dropout 或 BatchNorm),则这个调用会使其在评估模式下正确地行为。
4. 定义图像转换
transform = transforms.Compose([
transforms.Resize((300, 300)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ResNet-50 的标准化
])
注:
1. transforms.Resize((300, 300)):将输入图像的大小调整为 300x300 像素。这个步骤确保所有输入到模型的图像都具有相同的尺寸,这是许多深度学习模型的要求。
2. transforms.ToTensor():将 PIL 图像或 NumPy 数组转换为浮点数张量,并对其进行归一化,使得其数值范围在 [0, 1] 内。此外,它还会将图像的维度从 (H, W, C) 转换为 (C, H, W),这与 PyTorch 的默认图像维度格式相匹配。
3. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):对图像进行标准化处理。这个步骤根据提供的均值和标准差对图像的每个通道进行标准化。这些均值和标准差通常来自预训练模型的训练数据集,对于 ResNet-50 模型,这些值是常用的。标准化可以使得模型在训练和推理时对输入数据的分布有更稳定的预期,这有助于模型的性能和收敛。
5. 类别标签
classes = ['皮卡', '敞篷车', '跑车', '掀背两箱车', '小型面包车', 'SUV', '轿车', '厢式货车', '旅行车', '公共汽车', '消防车', '出租车']
6. 初始化全局变量
// An highlighted block
var foo = 'bar';
注:
1. selected_image_path:这个变量可能用于存储用户选择的图像文件的路径。在 GUI 应用程序中,通常会提供一个文件对话框让用户选择图像文件,然后应用程序将选择的文件的路径存储在这个变量中。
2. label_text:这个变量可能用于存储与图像相关联的标签文本。例如,如果应用程序是一个图像分类器,label_text 可能用于显示图像的最可能的类别。
3. right_canvas_image:这个变量可能用于存储一个图像对象,该对象将在 GUI 的画布上显示。在 Tkinter(Python 的 GUI 库)中,画布(Canvas)是一个可以用来绘制图形和图像的组件。right_canvas_image 可能是一个 ImageTk.PhotoImage 对象,它是一个可以在 Tkinter 中显示的图像。
7. 上传图片
def upload_image():
global selected_image_path, label_text
# 获取'.\output'目录下的所有文件
files = os.listdir(r'.\output')
# 遍历文件
for file in files:
# 拼接文件的完整路径
file_path = os.path.join(r'.\output', file)
# 检查文件是否为图片文件(这里假设只需要删除jpg和png文件)
if file.endswith('.jpg') or file.endswith('.png'):
# 删除文件
os.remove(file_path)
file_path = filedialog.askopenfilename()
if file_path:
selected_image_path = file_path
image = Image.open(file_path)
# 调整图片大小为500x400
image = image.resize((500, 400), Image.Resampling.LANCZOS)
# 居中图片
photo = ImageTk.PhotoImage(image)
canvas_left.create_image(0, 0, anchor='nw', image=photo)
canvas_left.image = photo # Keep a reference!
# 创建图片的标签
if label_text is None:
label_text = tk.Label(window, text="", font=("Arial", 16))
label_text.place(x=1185,y=155)
注:
1. 函数首先获取 .\output 目录下的所有文件,并遍历这些文件。如果文件是图片文件(假设只有 .jpg 和 .png 格式),则将其删除。这可能是为了清理之前的输出,为新的图像处理做准备。
2. 使用 tkinter.filedialog.askopenfilename() 弹出一个文件对话框,让用户选择一个图像文件。如果用户选择了文件,则将文件的路径存储在 selected_image_path 变量中.
3. 使用 PIL.Image.open() 打开用户选择的图像文件.
4. 调整图像的大小为 500x400 像素,使用 Image.Resampling.LANCZOS 重采样方法,这是一种高质量的重采样滤波器。
5. 将调整大小后的图像居中显示在一个名为 canvas_left 的画布上。这里使用了 ImageTk.PhotoImage 来创建一个可以在 Tkinter 中显示的图像对象,并将其附加到画布上。
6. 如果 label_text 是 None,则创建一个新的标签 label_text,用于显示图像的标签文本。标签被放置在窗口的特定位置(坐标为 x=1185, y=155),并设置了字体样式和大小。
8. 车型检查
def start_detection():
global right_canvas_image
if selected_image_path is not None:
image = Image.open(selected_image_path)
input_image = transform(image).unsqueeze(0)
with torch.no_grad():
outputs = model(input_image)
_, predicted = torch.max(outputs, 1)
label = classes[predicted.item()]
probabilities = torch.nn.functional.softmax(outputs, dim=1)
max_probability = probabilities[0][predicted].item() * 100 # 将概率值乘以100
label_text.config(text=f"{label} - {max_probability:.2f}%") # 显示为百分比格式
# 调整图片大小为500x400
image = image.resize((500, 400), Image.Resampling.LANCZOS)
# 显示图片在右侧画布
photo = ImageTk.PhotoImage(image)
# 检查是否已经创建了右侧画布的图片
if right_canvas_image is None:
right_canvas_image = canvas_right.create_image(0, 0, anchor='nw', image=photo)
else:
canvas_right.itemconfig(right_canvas_image, image=photo)
canvas_right.image = photo # 保持引用
else:
messagebox.showwarning("警告", "请先选择一张图像")
# 将标签放置在图片上
label_text.place(x=1185,y=155)
注:
-
函数首先检查 selected_image_path 是否为 None,如果不是,则说明用户已经选择了一张图像。
-
使用 PIL.Image.open() 打开用户选择的图像文件。
-
将图像通过之前定义的 transform 进行预处理,然后添加一个批次维度,因为模型期望输入的是一个批次的数据。
-
使用 torch.no_grad() 上下文管理器,确保在推理过程中不会计算梯度,这会减少计算资源的使用。
-
将预处理后的图像输入到模型中,得到输出 outputs。
-
使用 torch.max() 函数获取最高分数的类别索引 predicted。
-
根据类别索引从 classes 列表中获取类别标签 label。
-
计算 outputs 的softmax概率,并获取最大概率值 max_probability。
-
更新 label_text 的配置,显示类别标签和最大概率值。
-
调整图像大小为 500x400 像素,并使用 ImageTk.PhotoImage 创建一个可以在 Tkinter 中显示的图像对象。
-
如果 right_canvas_image 是 None,则创建一个新的图像对象在 canvas_right 上显示。如果不是 None,则更新现有的图像对象。
-
如果用户没有选择图像,则使用 tkinter.messagebox.showwarning() 显示一个警告消息框。
-
最后,将 label_text 放置在窗口的指定位置。
运行结果: