第90步 深度学习图像分割:U-Net建模

基于WIN10的64位系统演示

一、写在前面

从这一期开始,我们杀个回马枪,继续学习深度学习图像分割系列,以为4090上岗了。

图像分割是计算机视觉的一个重要任务,目的是将数字图像分割成多个部分或区域,这些部分通常对应于现实世界中的物体或其组成部分。

(1)基本原理:图像分割的主要目标是为图像中的每个像素分配一个标签,从而将整个图像划分为多个不同的区域或物体。因此,本质上还是一个分类问题。

(2)常见应用:

(a)医学图像: 用于病灶检测、器官定位和疾病诊断。

(b)自动驾驶: 对周围环境进行实时分析,例如检测行人、车辆和道路。

(3)常见模型:

(a)U-Net: 该模型特别适用于医学图像分割。它有一个收缩的路径和一个对称的扩展路径,形成U型结构。

(b)Mask R-CNN: 在Faster R-CNN的基础上,增加了一个并行的分支来预测图像的分割掩模。

(c)FCN (Fully Convolutional Network): 第一个将深度卷积网络端到端应用于图像分割的方法。它使用上采样层将卷积特征图转换回像素级预测。

本期,我们来尝试一下U-Net。

二、U-Net

U-Net 是为生物医学图像分割而设计的一个深度学习模型,其名字“U-Net”来源于其U型的结构。

(1)架构:U-Net由两部分组成:一个“收缩”(或下采样)路径一个“扩展”(或上采样)路径,这两个路径共同构成了一个U型结构。

(2)收缩路径:这是一个典型的卷积神经网络结构,包含了重复的两个3x3的卷积操作(每个后面都跟着ReLU激活函数),接着是一个2x2的最大池化操作来下采样。随着网络深入,特征通道的数量会加倍。此路径的目的是捕捉图像的上下文信息。

(3)扩展路径:为了得到精确的位置信息,U-Net使用了一个对称的扩展路径。

这个路径首先使用2x2的上采样操作,然后与相应的特征图进行连接,这种连接是为了获取更高分辨率的特征。接着,进行两次3x3的卷积操作,后面跟着ReLU激活函数。特征通道的数量随着网络深入而减半。

(4)跳跃连接:U-Net的一个关键特点是其跳跃连接(或称为“跳级连接”)。

在收缩路径中的每一步都有一个直接连接到扩展路径中相应步骤的连接,这保证了即使在深层网络中也能获取高分辨率的特征。

(5)最后的图层:在网络的最后是一个1x1的卷积层,用来将64个通道的特征向量映射到所需的输出类别数。

(2)数据源:

来源于公共数据,主要目的是使用U-Net分割出电子显微镜下的细胞边缘:

数据分为训练集(train)、训练集的细胞边缘数据(label)以及验证集(test)注意哈,没有提供验证集的细胞边缘数据。因此,后面是算不出验证集的性能参数的。

(2)U-Net实战:

上代码:

(a)数据读取和数据增强

import os
import numpy as np
from skimage.io import imread
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, UpSampling2D, Dropout, Softmax
from tensorflow.python.keras.optimizers import adam_v2
from tensorflow.python.keras.callbacks import ModelCheckpoint
import tensorflow as tf

physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# 设置文件路径
data_folder = 'U-net-master\data_set'
train_images_folder = os.path.join(data_folder, 'train')
label_images_folder = os.path.join(data_folder, 'label')
test_images_folder = os.path.join(data_folder, 'test')

train_images = sorted(os.listdir(train_images_folder))
label_images = sorted(os.listdir(label_images_folder))
test_images = sorted(os.listdir(test_images_folder))

# 读取训练和测试图像
X_train = np.array([imread(os.path.join(train_images_folder, img)) for img in train_images])
X_train = np.stack((X_train,)*3, axis=-1)  # 复制通道以创建三通道图像

X_test = np.array([imread(os.path.join(test_images_folder, img)) for img in test_images])
X_test = np.stack((X_test,)*3, axis=-1)


y_train = np.array([imread(os.path.join(label_images_folder, img)) for img in label_images])
y_train = np.expand_dims(y_train, axis=-1)  # 增加一个类别维度


# 定义数据增强生成器
data_gen_args = dict(rotation_range=0.2,
                     width_shift_range=0.05,
                     height_shift_range=0.05,
                     shear_range=0.05,
                     zoom_range=0.05,
                     horizontal_flip=True,
                     rescale=1./255,
                     fill_mode='nearest')
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

# 将种子提供给随机数生成器
seed = 1
# 将同样的种子应用于图像和标签以确保其转换方式相同
image_datagen.fit(X_train, augment=True, seed=seed)
mask_datagen.fit(y_train, augment=True, seed=seed)

image_generator = image_datagen.flow(X_train, batch_size=8, seed=seed)
mask_generator = mask_datagen.flow(y_train, batch_size=8, seed=seed)

# 将生成器组合成一个生成器,产生图像和标签
train_generator = zip(image_generator, mask_generator)

X_test = np.array([imread(os.path.join(test_images_folder, img)) for img in test_images])
X_test = np.stack((X_test,)*3, axis=-1)  # 复制通道以创建三通道图像

解读:

其他没什么好说的,就是要注意:上述代码的数据需要人工的安排训练集和测试集。严格按照下面格式放置好各个文件,包括文件夹的命名也不要变动:

(b)U-Net建模

# 定义U-Net模型结构
def get_unet(input_shape):
    inputs = Input(input_shape)

    # 下采样部分
    c1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    c1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c1)
    p1 = MaxPooling2D(pool_size=(2, 2))(c1)

    c2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(p1)
    c2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c2)
    p2 = MaxPooling2D(pool_size=(2, 2))(c2)

    c3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(p2)
    c3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c3)
    p3 = MaxPooling2D(pool_size=(2, 2))(c3)

    c4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(p3)
    c4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c4)
    drop4 = Dropout(0.5)(c4)
    p4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    c5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(p4)
    c5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c5)
    drop5 = Dropout(0.5)(c5)

    # 上采样部分
    u6 = UpSampling2D(size=(2, 2))(drop5)
    u6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(u6)
    merge6 = concatenate([drop4, u6], axis=3)
    c6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    c6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c6)

    u7 = UpSampling2D(size=(2, 2))(c6)
    u7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(u7)
    merge7 = concatenate([c3, u7], axis=3)
    c7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    c7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c7)

    u8 = UpSampling2D(size=(2, 2))(c7)
    u8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(u8)
    merge8 = concatenate([c2, u8], axis=3)
    c8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    c8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c8)

    u9 = UpSampling2D(size=(2, 2))(c8)
    u9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(u9)
    merge9 = concatenate([c1, u9], axis=3)
    c9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    c9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(c9)

    c10 = Conv2D(1, 1, activation='sigmoid')(c9)

    model = Model(inputs=[inputs], outputs=[c10])

    return model


# 获取模型
model = get_unet(X_train.shape[1:])

# 编译模型
model.compile(optimizer=adam_v2.Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])

# 设置模型检查点以保存训练中的最佳模型
model_checkpoint = ModelCheckpoint('unet_membrane.hdf5', monitor='loss', verbose=1, save_best_only=True)

# 训练模型
history = model.fit(train_generator, steps_per_epoch=len(X_train) // 16, epochs=100, verbose=1, callbacks=[model_checkpoint])

让GPT解读:

可能是用了4090,1分钟不到:

(c)各种性能指标打印和可视化

###################################误差曲线#######################################

import matplotlib.pyplot as plt

# 设置matplotlib支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 使用SimHei字体
plt.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题

# 绘制训练损失和准确率
plt.figure(figsize=(12, 5))

# 绘制损失
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='训练损失')
plt.title('损失随迭代次数的变化')
plt.xlabel('迭代次数')
plt.ylabel('损失')
plt.legend()

# 绘制准确率
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.title('准确率随迭代次数的变化')
plt.xlabel('迭代次数')
plt.ylabel('准确率')
plt.legend()

plt.tight_layout()
plt.show()

##############################评价指标#######################################
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, accuracy_score, recall_score, precision_score, f1_score

# 预测训练集
train_pred = model.predict(X_train)

# 确保y_train中的值是0或1
y_train[y_train == 255] = 1

def calc_iou(y_true, y_pred):
    intersection = np.logical_and(y_true, y_pred)
    union = np.logical_or(y_true, y_pred)
    return np.sum(intersection) / np.sum(union)

# 计算ROC曲线
fpr_train, tpr_train, _ = roc_curve(y_train.ravel(), train_pred.ravel())

# 计算AUC
auc_train = auc(fpr_train, tpr_train)

# 计算其他评估指标
pixel_accuracy_train = accuracy_score(y_train.ravel(), train_pred.ravel() > 0.5)
iou_train = calc_iou(y_train, train_pred > 0.5)
accuracy_train = accuracy_score(y_train.ravel(), train_pred.ravel() > 0.5)
recall_train = recall_score(y_train.ravel(), train_pred.ravel() > 0.5)
precision_train = precision_score(y_train.ravel(), train_pred.ravel() > 0.5)
f1_train = f1_score(y_train.ravel(), train_pred.ravel() > 0.5)

# 绘制ROC曲线
plt.figure()
plt.plot(fpr_train, tpr_train, color='blue', lw=2, label='Train ROC curve (area = %0.2f)' % auc_train)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')
plt.show()

# 定义指标列表
metrics = [
    ("Pixel Accuracy", pixel_accuracy_train),
    ("IoU", iou_train),
    ("Accuracy", accuracy_train),
    ("Recall", recall_train),
    ("Precision", precision_train),
    ("F1 Score", f1_train)
]

# 打印表格的头部
print("+-----------------+------------+")
print("| Metric          | Value      |")
print("+-----------------+------------+")

# 打印每个指标的值
for metric_name, metric_value in metrics:
    print(f"| {metric_name:15} | {metric_value:.6f} |")
    print("+-----------------+------------+")

直接看结果:

误差和准确率曲线,看起来模型收敛的不错。

ROC曲线:这里存疑,感觉没啥意义,而且这个曲线看起来有问题,是一个三点折线。

一些性能指标,稍微解释,主要是前两个:

A)Pixel Accuracy:

定义:它是所有正确分类的像素总数与图像中所有像素的总数的比率。

计算:(正确预测的像素数量) / (所有像素数量)。

说明:这个指标评估了模型在每个像素级别上的准确性。但在某些场景中(尤其是当类别非常不平衡时),这个指标可能并不完全反映模型的表现。

B)IoU (Intersection over Union):

定义:对于每个类别,IoU 是该类别的预测结果(预测为该类别的像素)与真实标签之间的交集与并集的比率。

计算:(预测与真实标签的交集) / (预测与真实标签的并集)。

说明:它是一个很好的指标,因为它同时考虑了假阳性和假阴性,尤其在类别不平衡的情况下。

C)Accuracy:

定义:是所有正确分类的像素与所有像素的比例,通常与 Pixel Accuracy 相似。

计算:(正确预测的像素数量) / (所有像素数量)。

D)Recall (or Sensitivity or True Positive Rate):

定义:是真实正样本被正确预测的比例。

计算:(真阳性) / (真阳性 + 假阴性)。

说明:高召回率表示少数阳性样本不容易被漏掉。

E)Precision:

定义:是被预测为正的样本中实际为正的比例。

计算:(真阳性) / (真阳性 + 假阳性)。

说明:高精度表示假阳性的数量很少。

F)F1 Score:

定义:是精度和召回率的调和平均值。它考虑了假阳性和假阴性,并试图找到两者之间的平衡。

计算:2 × (精度 × 召回率) / (精度 + 召回率)。

说明:在不平衡类别的场景中,F1 Score 通常比单一的精度或召回率更有用。

(d)查看验证集具体分割情况

#看具体分割的效果
import matplotlib.pyplot as plt

# 选择一张测试图片
img_index = 3
test_img = X_test[img_index]

# 扩展维度以匹配模型输入,因为模型需要四个维度的输入,然后进行预测
test_img = np.expand_dims(test_img, axis=0)
pred = model.predict(test_img)

# 移除添加的维度,以便显示图像
pred_img = np.squeeze(pred)

# 使用matplotlib来展示原始图像和预测的分割图像
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(np.squeeze(test_img), cmap='gray')
plt.subplot(1, 2, 2)
plt.title("Predicted Segmentation")
plt.imshow(pred_img, cmap='gray')
plt.show()

随意从验证集挑一张图片,查看分割效果:

总体来看,勉强过关,收工!

四、写在后面

以上,只是U-Net的最简单的应用了,不过对于硬件要求还是挺高的,训练起来显卡可以煮鸡蛋的感觉。

后面会单独开个专栏,深入研究各种五花八门的数据应用。

五、数据

链接:https://pan.baidu.com/s/1Cb78MwfSBfLwlpIT0X3q9Q?pwd=u1q1

提取码:u1q1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/136095.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

goroutine调度模型 调度策略

文章目录 背景 协程线程与协程的对比线程(Thread)协程(Coroutine) 运作线程模型 goroutine调度模型与演进过程G-M模型G-P-M模型抢占式调度器其他优化 调度策略队列轮转系统调用工作量窃取抢占式调度GOMAXPROCS 对性能的影响 Go在语…

multilinear多项式承诺方案benchmark对比

1. 引言 前序博客有: Lasso、Jolt 以及 Lookup Singularity——Part 1Lasso、Jolt 以及 Lookup Singularity——Part 2深入了解LassoJolt Lasso lookup中,multilinear多项式承诺方案的高效性至关重要。 本文重点关注4种multilinear多项式承诺方案的实…

【Python基础】try-finally语句和with语句

📢:如果你也对机器人、人工智能感兴趣,看来我们志同道合✨ 📢:不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 📢:文章若有幸对你有帮助,可点赞 👍…

Springboot+vue的毕业生实习与就业管理系统(有报告)。Javaee项目,springboot vue前后端分离项目。

演示视频: Springbootvue的毕业生实习与就业管理系统(有报告)。Javaee项目,springboot vue前后端分离项目 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点…

logback异步日志打印阻塞工作线程

前言 最新做项目,发现一些历史遗留问题,典型的是日志打印的配置问题,其实都是些简单问题,但是往往简单问题引起严重的事故,比如日志打印阻塞工作线程,以logback和log4j2为例。logback实际上是springboot的…

Smart Link 和 Monitor Link应用

定义 Smart Link常用于双上行链路组网,提高接入的可靠性。 Monitor Link通过监视上行接口,使下行接口同步上行接口状态,起到传递故障信息的作用。 Smart Link,又叫做备份链路。一个Smart Link由两个接口组成,其中一个…

2016年408计网

这一年,计算机网络部分的全部考题都围绕该网络拓扑图进行。 第33题 在 OSI 参考模型中, R1、Switch、Hub 实现的最高功能层分别是() A. 2、2、1 B. 2、2、2 C. 3、2、1 D. 3、2、2 本题考察路由器、以太网交换机、集线器各自实现的最高功能层是什么题目给定R1是…

链表OJ题【环形链表】(3)

目录 环形问题的思考 ❓Q1 ❓Q2 🙂Q2 ❓Q3 ❓Q4 8.环形链表 9.环形链表Ⅱ 今天接着链表的经典问题环形问题。大家一定要自己动手多写写。🙂 快慢指针(保持相对距离/保持相对速度)野指针考虑为NULL的情况带环链表&#x…

Java14新增特性

前言 前面的文章,我们对Java9、Java10、Java11、Java12 、Java13的特性进行了介绍,对应的文章如下 Java9新增特性 Java10新增特性 Java11新增特性 Java12新增特性 Java13新增特性 今天我们来一起看一下Java14这个版本的一些重要信息 版本介绍 Java 14…

No180.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…

【图像处理:OpenCV-Python基础操作】

【图像处理:OpenCV-Python基础操作】 1 读取图像2 显示图像3 保存图像4 图像二值化、灰度图、彩色图,像素替换5 通道处理(通道拆分、合并)6 调整尺寸大小7 提取感兴趣区域、掩膜8 乘法、逻辑运算9 HSV色彩空间,获取特定…

【算法每日一练]-单调队列,滑动窗口(保姆级教程 篇1) #滑动窗口 #求m区间的最小值 #理想的正方形 #切蛋糕

今天讲单调队列 目录 题目:滑动窗口 思路: 题目:求m区间的最小值​ 思路: 题目:理想的正方形 思路: 题目:切蛋糕 思路: 一共两种类型:一种是区间中的最值&…

游戏制作:猜数字(1~100),不会也可以先试着玩玩

目录 1 2代码如下:可以试着先玩玩 3运行结果:嘿嘿嘿 4程序分析:想学的看 5总结: 1 猜数范围为1~100,猜大输出猜大了,猜小输出猜小了,游戏可以无限玩。 首先先做一个简单的菜单界面&#xf…

RK3588平台 WIFI的基本概念

一.安卓WIFI框架 Android WIFI系统引入了wpa_supplicant,它的整个WIFI系统以wpa_supplicant为核心来定义上层接口和下层驱动接口。Android WIFI主要分为六大层,分别是WiFi Settings层,Wifi Framework层,Wifi JNI 层, W…

WorkPlus Meet:局域网内部使用的高效视频会议系统

随着全球化和远程办公的趋势,视频会议已成为现代企业和机构不可或缺的沟通工具。而现在,大多数政企单位或者涉密强的企业,都会使用局域网部署的音视频会议系统,提供更高的安全性和隐私保护。因为音视频会议中可能涉及到公司机密和…

Torch Hub 系列#2:VGG 和 ResNet

一、说明 在上一篇教程中,我们了解了 Torch Hub 背后的本质及其概念。然后,我们使用 Torch Hub 的复杂性发布了我们的模型,并通过相同的方式访问它。但是,当我们的工作要求我们利用 Torch Hub 上提供的众多全能模型之一时,会发生什么? 在本教程中,我们将学习如何利用称为…

自动泊车轨迹规划学习

1.基于6次多项式的自动泊车轨迹算法研究 针对常见的自动泊车系统无法躲避障碍物,以及轨迹的曲率不连续等问题进行了泊车轨迹算法的研究以及跟踪算法的设计。 针对低速自动泊车场景进行分析,建立符合对应场景下的车辆运动学模型以及能够泊车的最小车位大…

华为dns mapping配置案例

解决内网PC用公网的dns用域名方法访问公司内网的web服务器: 原理是用DNS mapping方式解决 配置过程:域名——出口公网IP地址——公网端口——协议类型 公司内网client用户填公网dns, 公网上的dns上面已注册有公司对外映射的web服务器的公网…

山西电力市场日前价格预测【2023-11-13】

日前价格预测 预测说明: 如上图所示,预测明日(2023-11-13)山西电力市场全天平均日前电价为428.16元/MWh。其中,最高日前电价为751.89元/MWh,预计出现在18: 30。最低日前电价为289.03元/MWh,预计…

【MySQL系列】 第一章 · MySQL概述

写在前面 Hello大家好, 我是【麟-小白】,一位软件工程专业的学生,喜好计算机知识。希望大家能够一起学习进步呀!本人是一名在读大学生,专业水平有限,如发现错误或不足之处,请多多指正&#xff0…