第97步 深度学习图像目标检测:RetinaNet建模

基于WIN10的64位系统演示

一、写在前面

本期开始,我们继续学习深度学习图像目标检测系列,RetinaNet模型。

二、RetinaNet简介

RetinaNet 是由 Facebook AI Research (FAIR) 的研究人员在 2017 年提出的一种目标检测模型。它是一种单阶段(one-stage)的目标检测方法,但通过引入一个名为 Focal Loss 的创新损失函数,RetinaNet 解决了单阶段检测器常面临的正负样本不平衡问题。以下是 RetinaNet 的主要特点:

(1)Focal Loss:

传统的交叉熵损失往往由于背景类(负样本)的数量远大于目标类(正样本)而导致训练不稳定。为了解决这一不平衡问题,RetinaNet 引入了 Focal Loss。Focal Loss 被设计为更重视那些难以分类的负样本,而减少对容易分类的背景类的关注。这有助于提高模型对目标的检测精度。

(2)特征金字塔网络 (FPN):

RetinaNet 使用了特征金字塔网络 (FPN) 作为其骨干网络,这是一个为多尺度目标检测设计的卷积网络结构。FPN 可以从单张图像中提取多尺度的特征,使得模型能够有效地检测不同大小的物体。

(3)预定义锚框:

与其他一阶段检测器相似,RetinaNet 在其特征图上使用预定义的锚框来预测目标的位置和类别。

三、数据源

来源于公共数据,文件设置如下:

大概的任务就是:用一个框框标记出MTB的位置。

四、RetinaNet实战

直接上代码:

import os
import random
import torch
import torchvision
from torchvision.models.detection import retinanet_resnet50_fpn 
from torchvision.transforms import functional as F
from PIL import Image
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

# Function to parse XML annotations
def parse_xml(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    boxes = []
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        xmin = int(bndbox.find("xmin").text)
        ymin = int(bndbox.find("ymin").text)
        xmax = int(bndbox.find("xmax").text)
        ymax = int(bndbox.find("ymax").text)

        # Check if the bounding box is valid
        if xmin < xmax and ymin < ymax:
            boxes.append((xmin, ymin, xmax, ymax))
        else:
            print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")

    return boxes

# Function to split data into training and validation sets
def split_data(image_dir, split_ratio=0.8):
    all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]
    random.shuffle(all_images)
    split_idx = int(len(all_images) * split_ratio)
    train_images = all_images[:split_idx]
    val_images = all_images[split_idx:]
    
    return train_images, val_images


# Dataset class for the Tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, annotation_dir, image_list, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.image_list = image_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_list[idx])
        image = Image.open(image_path).convert("RGB")
        
        xml_path = os.path.join(self.annotation_dir, self.image_list[idx].replace(".jpg", ".xml"))
        boxes = parse_xml(xml_path)
        
        # Check for empty bounding boxes and return None
        if len(boxes) == 0:
            return None
        
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((len(boxes),), dtype=torch.int64)
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        target["iscrowd"] = iscrowd
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
    
        return image, target

# Define the transformations using torchvision
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # Convert PIL image to tensor
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])


# Adjusting the DataLoader collate function to handle None values
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return tuple(zip(*batch))


def get_retinanet_model_for_finetuning(num_classes):
    # Load a RetinaNet model with a ResNet-50-FPN backbone without pre-trained weights
    model = retinanet_resnet50_fpn(pretrained=False, num_classes=num_classes)
    return model


# Function to save the model
def save_model(model, path="RetinaNet_mtb.pth", save_full_model=False):
    if save_full_model:
        torch.save(model, path)
    else:
        torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

# Function to compute Intersection over Union
def compute_iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    
    iou = interArea / float(boxAArea + boxBArea - interArea)
    return iou

# Adjusting the DataLoader collate function to handle None values and entirely empty batches
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        # Return placeholder batch if entirely empty
        return [torch.zeros(1, 3, 224, 224)], [{}]
    return tuple(zip(*batch))

#Training function with modifications for collecting IoU and loss
def train_model(model, train_loader, optimizer, device, num_epochs=10):
    model.train()
    model.to(device)
    loss_values = []
    iou_values = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        total_ious = 0
        num_boxes = 0
        for images, targets in train_loader:
            # Skip batches with placeholder data
            if len(targets) == 1 and not targets[0]:
                continue
            # Skip batches with empty targets
            if any(len(target["boxes"]) == 0 for target in targets):
                continue
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            
            epoch_loss += losses.item()
            
            # Compute IoU for evaluation
            with torch.no_grad():
                model.eval()
                predictions = model(images)
                for i, prediction in enumerate(predictions):
                    pred_boxes = prediction["boxes"].cpu().numpy()
                    true_boxes = targets[i]["boxes"].cpu().numpy()
                    for pred_box in pred_boxes:
                        for true_box in true_boxes:
                            iou = compute_iou(pred_box, true_box)
                            total_ious += iou
                            num_boxes += 1
                model.train()
        
        avg_loss = epoch_loss / len(train_loader)
        avg_iou = total_ious / num_boxes if num_boxes != 0 else 0
        loss_values.append(avg_loss)
        iou_values.append(avg_iou)
        print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")
    
    # Plotting loss and IoU values
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(loss_values, label="Training Loss")
    plt.title("Training Loss across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    
    plt.subplot(1, 2, 2)
    plt.plot(iou_values, label="IoU")
    plt.title("IoU across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("IoU")
    plt.show()

    # Save model after training
    save_model(model)

# Validation function
def validate_model(model, val_loader, device):
    model.eval()
    model.to(device)
    
    with torch.no_grad():
        for images, targets in val_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            model(images)

# Paths to your data
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"

# Split data
train_images, val_images = split_data(image_dir)

# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)

# Updated DataLoader with new collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

# Model and optimizer
model = get_retinanet_model_for_finetuning(2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train and validate
train_model(model, train_loader, optimizer, device="cuda", num_epochs=10)
validate_model(model, val_loader, device="cuda")


#######################################Print Metrics######################################
def calculate_metrics(predictions, ground_truths, iou_threshold=0.5):
    TP = 0  # True Positives
    FP = 0  # False Positives
    FN = 0  # False Negatives
    total_iou = 0  # to calculate mean IoU

    for pred, gt in zip(predictions, ground_truths):
        pred_boxes = pred["boxes"].cpu().numpy()
        gt_boxes = gt["boxes"].cpu().numpy()

        # Match predicted boxes to ground truth boxes
        for pred_box in pred_boxes:
            max_iou = 0
            matched = False
            for gt_box in gt_boxes:
                iou = compute_iou(pred_box, gt_box)
                if iou > max_iou:
                    max_iou = iou
                    if iou > iou_threshold:
                        matched = True

            total_iou += max_iou
            if matched:
                TP += 1
            else:
                FP += 1

        FN += len(gt_boxes) - TP

    precision = TP / (TP + FP) if (TP + FP) != 0 else 0
    recall = TP / (TP + FN) if (TP + FN) != 0 else 0
    f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
    mean_iou = total_iou / (TP + FP) if (TP + FP) != 0 else 0

    return precision, recall, f1_score, mean_iou

def evaluate_model(model, dataloader, device):
    model.eval()
    model.to(device)
    all_predictions = []
    all_ground_truths = []

    with torch.no_grad():
        for images, targets in dataloader:
            images = [image.to(device) for image in images]
            predictions = model(images)

            all_predictions.extend(predictions)
            all_ground_truths.extend(targets)

    precision, recall, f1_score, mean_iou = calculate_metrics(all_predictions, all_ground_truths)
    return precision, recall, f1_score, mean_iou


train_precision, train_recall, train_f1, train_iou = evaluate_model(model, train_loader, "cuda")
val_precision, val_recall, val_f1, val_iou = evaluate_model(model, val_loader, "cuda")

print("Training Set Metrics:")
print(f"Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1 Score: {train_f1:.4f}, Mean IoU: {train_iou:.4f}")

print("\nValidation Set Metrics:")
print(f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1 Score: {val_f1:.4f}, Mean IoU: {val_iou:.4f}")

#sheet
header = "| Metric    | Training Set | Validation Set |"
divider = "+----------+--------------+----------------+"

train_metrics = f"| Precision | {train_precision:.4f}      | {val_precision:.4f}          |"
recall_metrics = f"| Recall    | {train_recall:.4f}      | {val_recall:.4f}          |"
f1_metrics = f"| F1 Score  | {train_f1:.4f}      | {val_f1:.4f}          |"
iou_metrics = f"| Mean IoU  | {train_iou:.4f}      | {val_iou:.4f}          |"

print(header)
print(divider)
print(train_metrics)
print(recall_metrics)
print(f1_metrics)
print(iou_metrics)
print(divider)

#######################################Train Set######################################
import numpy as np
import matplotlib.pyplot as plt

def plot_predictions_on_image(model, dataset, device, title):
    # Select a random image from the dataset
    idx = np.random.randint(50, len(dataset))
    image, target = dataset[idx]
    img_tensor = image.clone().detach().to(device).unsqueeze(0)

    # Use the model to make predictions
    model.eval()
    with torch.no_grad():
        prediction = model(img_tensor)

    # Inverse normalization for visualization
    inv_normalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )
    image = inv_normalize(image)
    image = torch.clamp(image, 0, 1)
    image = F.to_pil_image(image)

    # Plot the image with ground truth boxes
    plt.figure(figsize=(10, 6))
    plt.title(title + " with Ground Truth Boxes")
    plt.imshow(image)
    ax = plt.gca()

    # Draw the ground truth boxes in blue
    for box in target["boxes"]:
        rect = plt.Rectangle(
            (box[0], box[1]), box[2]-box[0], box[3]-box[1],
            fill=False, color='blue', linewidth=2
        )
        ax.add_patch(rect)
    plt.show()

    # Plot the image with predicted boxes
    plt.figure(figsize=(10, 6))
    plt.title(title + " with Predicted Boxes")
    plt.imshow(image)
    ax = plt.gca()

    # Draw the predicted boxes in red
    for box in prediction[0]["boxes"].cpu():
        rect = plt.Rectangle(
            (box[0], box[1]), box[2]-box[0], box[3]-box[1],
            fill=False, color='red', linewidth=2
        )
        ax.add_patch(rect)
    plt.show()

# Call the function for a random image from the train dataset
plot_predictions_on_image(model, train_dataset, "cuda", "Selected from Training Set")


#######################################Val Set######################################

# Call the function for a random image from the validation dataset
plot_predictions_on_image(model, val_dataset, "cuda", "Selected from Validation Set")

这回也是需要从头训练的,就不跑了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/187692.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

学习了解抽象思维的应用与实践

目录 一、快速了解抽象思维 &#xff08;一&#xff09;抽象思维的本质理解 &#xff08;二&#xff09;系统架构中的重要性 &#xff08;三&#xff09;软件开发中抽象的基本过程思考 意识和手段 抽象的方式 抽象层次的权衡 二、业务中的应用实践 &#xff08;一&…

Django 通过 Trunc(kind) 和 Extract(lookup_name) 参数进行潜在 SQL 注入 (CVE-2022-34265)

漏洞描述 Django 于 2022 年6月4 日发布了一个安全更新&#xff0c;修复了 Trunc&#xff08;&#xff09; 和 Extract&#xff08;&#xff09; 数据库函数中的 SQL 注入漏洞。 参考链接&#xff1a; Django security releases issued: 4.0.6 and 3.2.14 | Weblog | Djang…

ArkTs变量类型、数据类型

可以参考官网学习路径学习HarmonyOS第一课|应用开发视频教程学习|HarmonyOS应用开发官网 一、变量 1、ArkTS语言 ArkTS是华为自研的开发语言。它在TypeScript&#xff08;简称TS&#xff09;的基础上&#xff0c;匹配ArkUI框架&#xff0c;扩展了声明式UI、状态管理等相应的…

mac 修改 hosts 文件

打开 hosts 所在文件夹 command shift G 快捷键 输入&#xff1a;“/private/etc/hosts” 后回车 如下所示 进入 hosts 文件所在位置&#xff0c;找到 hosts 文件&#xff0c;双击打开 修改 hosts 文件 将所需要的配置信息追加到hosts 文件中&#xff0c;或者修改需要改…

mysql忘记密码,然后重置

数据库版本8.0.26 只针对以下情况 mysql忘记了密码&#xff0c;但是你navicat之前连接上了 解决方法&#xff1a; 第一步&#xff0c;选中mysql这个数据库&#xff0c;点击新建查询 第二步&#xff1a;重置密码 alter user rootlocalhost IDENTIFIED BY 你的密码; 然后就可…

ffmpeg下载与配置环境变量

FFmpeg 是一个强大的多媒体框架&#xff0c;可以让用户处理和操纵音频和视频文件。具有易于使用的界面&#xff0c;用户可以在 Windows、Mac 或 Linux Ubuntu 系统上下载 FFmpeg 并将其提取到文件夹中。然后&#xff0c;该软件可以加入 PATH 环境变量中就可以快捷的使用软件了.…

Android自动化测试中使用ADB进行网络状态管理

Android自动化测试中的网络状态切换是提高测试覆盖率、捕获潜在问题的关键步骤之一&#xff0c;本文将介绍 如何使用ADB检测和管理Android设备的网络状态。 自动化测试中的网络状态切换变得尤为重要。 同时&#xff0c;在这我准备了一份软件测试视频教程&#xff08;含接口、自…

maven 将Jar包安装到本地仓库

window系统&#xff1a; 注意事项&#xff1a;在windows中&#xff0c;使用mvn指令将jar安装到本地仓库时&#xff0c;一定要将相关资源使用“"”包裹上&#xff0c;不然会报下面的错&#xff1a; mvn install:install-file "-DfileD:\BaiduNetdiskDownload\qianzixi…

轻松植入分布式跟踪:Odigos 带你主导应用观测 | 开源日报 No.85

babysor/MockingBird Stars: 31.6k License: NOASSERTION 这个项目是一个实时语音克隆的开源项目&#xff0c;主要功能包括支持中文、使用 PyTorch 进行训练和推理、可以在 Windows 和 Linux 系统上运行以及提供 Web 服务器。该项目的核心优势和特点包括&#xff1a; 支持多种…

视频批量剪辑技巧:掌握视频嵌套合并,轻松成为视频剪辑高手

随着社交媒体的兴起&#xff0c;视频已成为人们分享和交流的重要方式。视频剪辑作为视频制作的关键环节&#xff0c;对于提升视频质量和吸引力至关重要。视频嵌套合并是一种高级视频剪辑技巧&#xff0c;它将两个或多个视频片段叠加在一起&#xff0c;创造出一种独特的效果。这…

XShell新建会话指南

XShell新建会话 我们先登录我们的xshell&#xff0c;连接我们的远程服务器 为了方便我们以后的使用&#xff0c;我们可以新建一个会话记住用户 新建好后&#xff0c;我们可以打开这个会话 我们选择记住用户名 然后继续输密码就可以了 之后我们每次打开xshell的时候&#xff0c…

数据丢失抢救神器之TOP10 Android 数据恢复榜单

在快节奏的数字时代&#xff0c;我们的生活越来越与智能手机交织在一起&#xff0c;使它们成为重要数据和珍贵记忆的存储库。由于意外删除、软件故障或硬件故障而丢失数据可能是一种痛苦的经历。值得庆幸的是&#xff0c;技术领域提供了 Android 数据恢复软件形式的解决方案。这…

02 RANSAC算法 及 Python 实现

文章目录 02 RANSAC算法 及 Python 实现2.1 简介2.2 算法流程2.3 RANSAC 算法实现直线拟合2.4 利用 RANSAC 算法减少 ORB 特征点误匹配 02 RANSAC算法 及 Python 实现 2.1 简介 RANSAC &#xff08;Random Sample Consensus&#xff0c;随机抽样一致&#xff09;算法的 基本假…

中职组网络安全-linux渗透测试-Server2203(环境+解析)

任务环境说明&#xff1a; 服务器场景&#xff1a;Server2203&#xff08;关闭链接&#xff09; 用户名&#xff1a;hacker 密码&#xff1a;123456 1.使用渗透机对服务器信息收集&#xff0c;并将服务器中SSH服务端口号作为flag提交&#xff1b; FLAG:2232 2. 使用渗透机对…

最全的软件测试教程(功能、工具、接口、自动化、性能)

一、软件测试功能测试 测试用例编写是软件测试的基本技能&#xff1b;也有很多人认为测试用例是软件测试的核心&#xff1b;软件测试中最重要的是设计和生成有效的测试用例&#xff1b;测试用例是测试工作的指导&#xff0c;是软件测试的必须遵守的准则。 在这我也准备了一份…

华为ensp:trunk链路

当我们使用trunk链路后&#xff0c;还要选择要放行的vlan那就是全部vlan&#xff08;all&#xff09;&#xff0c;但是all并不包括vlan1&#xff0c;所以我们的trunk链路中的all不对all进行放行 实现相同vlan之间的通信 先将他们加入对应的vlan lsw1 进入e0/0/3接口 interfa…

Linux-基本指令(1.0)

Linux是一个非常流行的操作的知识&#xff0c;并提供实例帮助读者更好地理解。让我们一起来学习吧&#xff01;系统&#xff0c;也是云计算、大数据、人工智能等领域的重要基础。学习Linux命令是Linux系统管理的基础&#xff0c;也是开发过程中必不可少的技能。本博客将介绍Lin…

2023年10月纸巾市场分析(京东天猫淘宝平台纸巾品类数据采集)

双十一大促期间&#xff0c;刚需品的纸巾是必囤商品之一。今年双十一&#xff0c;京东数据显示&#xff0c;10月23日至29日&#xff0c;清洁纸品成交额同比增长40%&#xff0c;由此也拉动了10月纸巾市场的销售。 鲸参谋数据显示&#xff0c;今年10月&#xff0c;京东平台纸巾市…

外汇天眼:香港监管机构对AMTD Global Markets Limited启动法律诉讼

香港证监会&#xff08;SFC&#xff09;已经启动了法律程序&#xff0c;要求首次审裁法院调查AMTD Global Markets Limited&#xff08;AMTD&#xff0c;目前以orientiert XYZ Securities Limited为名&#xff09;及其前高管在与首次公开发行&#xff08;IPO&#xff09;相关的…