【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别

文章目录

      • 0. 前言
      • 1. 级联神经网络介绍
      • 2. MTCNN介绍
        • 2.1 MTCNN提出背景
        • 2.2 MTCNN结构
      • 3. MTCNN PyTorch实战
        • 3.1 facenet_pytorch库中的MTCNN
        • 3.2 识别图像数据
        • 3.3 人脸识别
        • 3.4 关键点定位

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文详细介绍MTCNN——多任务级联卷积神经网络的结构,并通过PyTorch实例说明MTCNN在人脸识别上的应用。

MTCNN的全称是Multi-Task Cascaded Convolutional Networks,它的缩写确实是MTCNN不是MTCCN.

1. 级联神经网络介绍

级联(cascaded)神经网络是一种人工神经网络的架构设计,它指的是多个神经网络层按照特定的方式连接起来,形成一个逐层处理信息的多层结构。在级联神经网络中,前一层次网络的输出作为后一层次网络的输入,这种结构允许在网络在深度方向上对复杂性和抽象层数进行增加

级联网络的重要特点是其动态构建特性,即可以从一个小规模的基本网络开始,并随着训练过程自动添加更多的隐藏单元或子网络,逐渐扩展成一个更深层次的结构,然后通过只针对新增部分数据进行训练来更新权重,即增量式学习(Incremental Learning)。这与传统的——构建完整模型后统一进行训练更新权重的思路非常不同。

使用传统的思路,如果发现我们的模型并不适用于待解决的任务,导致要调整模型结构时,通常会意味着之前的训练模型的工作全部白费了。

总结起来级联神经网络具有以下优点:

  1. 自适应结构:级联网络设计允许根据训练数据或学习过程动态调整网络结构,比如自动增加新的层或神经元,以适应更复杂的模式识别任务。

  2. 学习效率提升:可以通过增量学习或局部训练来加快学习速度,只针对新增加的部分进行训练优化。在某些情况下,级联网络可以采用非传统的权重更新机制,不需要在整个网络上执行全局误差反向传播算法。

  3. 鲁棒性和容错性:分层结构有助于提高系统的鲁棒性,单个层次的错误可能在后续层次中得到修正。

2. MTCNN介绍

2.1 MTCNN提出背景

MTCNN是Kaipeng Zhang等人在论文——Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Networks中提出的,其宗旨是通过多任务级联CNN解决两个问题:人脸检测(找出图像中人脸的位置和边界框)和人脸对齐(精确定位面部特征点)。

2.2 MTCNN结构

MTCNN的构建思路可以简单分为下面几个步骤:
在这里插入图片描述

  • 准备步骤:对图像进行缩放,建立图像金字塔;
  • 第一步Proposal-Net:快速选出若干候选框,为下一步准备;
  • 第二步Refine-Net:对第一步的众多候选框进行精选,留下置信度大的候选框;
  • 第三步Output-Net:输出最终bounding box、人脸关键特征定位和置信度。

详细来说P-Net、R-Net和O-Net的结构如下:在这里插入图片描述
通过Netron可以看到facenet_pytorch库中的MTCNN的结构及详细参数如下:
请添加图片描述

  1. P-Net (Proposal Network):

    • 输入是原始图像。
    • 首先通过一个卷积层(Conv2d)将3通道的输入图像转换为10通道特征图,使用3x3的卷积核(kernel_size=(3, 3))。
    • 紧接着使用PReLU激活函数(prelu1)进行非线性变换。
    • 使用最大池化层(MaxPool2d)下采样特征图(pool1),步长为2。
    • 再经过两个卷积层(Conv2d)提取更深层次的特征,并分别用PReLU激活函数(prelu2和prelu3)进行非线性处理。
    • 最后通过两个1x1卷积层(Conv2d)生成两个输出:一个是softmax4_1用于预测每个像素是否为人脸的概率分布,另一个是conv4_2用于回归bounding box的位置信息。
  2. R-Net (Refine Network):

    • 输入是P-Net的候选区域。
    • 类似于P-Net,R-Net也包含多个卷积层与激活函数,以及池化层进行特征提取和下采样。
    • 在最后,通过两个全连接层(Dense或Linear)生成两个输出:softmax5_1用于判断候选框内是否为人脸并给出置信度,dense5_2用于进一步细化人脸框的位置。
  3. O-Net (Output Network):

    • 输入同样是前一级网络(R-Net)筛选后的候选区域。
    • O-Net具有更多的卷积层以获取更精细的特征表达,同样在最后阶段通过三个全连接层生成三个输出:softmax6_1用于人脸分类,dense6_2用于人脸框回归精修,dense6_3用于估计关键点(如眼睛、嘴巴等)的位置。

整个MTCNN模型通过逐步筛选和优化候选区域,在不同尺度上定位和识别图像中的人脸,从而实现高效准确的人脸检测。

3. MTCNN PyTorch实战

3.1 facenet_pytorch库中的MTCNN

facenet_pytorch库中的MTCNN类是一个用于人脸检测的多任务级联卷积神经网络模型实现。直接使用MTCNN类的最大好处就是该模型已经训练好,可以拿来即用,其初始化时接受多个参数,以下是对这些参数的详细解释:

  1. image_size(默认值:160):输出图像的大小(像素),图像会调整为正方形。

  2. margin(默认值:0):在最终图像上添加到边界框的边距(以像素为单位)。需要注意的是,与davidsandberg/facenet库中的应用方式稍有不同,该库在调整原始图像大小之前就对原始图像应用了边距,导致边距与原始图像大小相关(这是davidsandberg/facenet的一个bug)。

  3. min_face_size(默认值:20):要搜索的人脸的最小尺寸。

  4. thresholds(默认值:[0.6, 0.7, 0.7]):MTCNN人脸检测阈值列表,分别对应P-Net、R-Net和O-Net三个阶段的阈值。

  5. factor(默认值:0.709):用于创建人脸大小缩放金字塔的比例因子。

  6. post_process(默认值:True):是否在返回前对图像张量进行后处理。

  7. select_largest(默认值:True):如果检测到多个人脸,是否选择面积最大的一个返回。若设为False,则选择概率最高的人脸返回。

  8. selection_method(默认值:None):指定使用哪种启发式方法进行选择,如果设置此参数将覆盖select_largest

    • "probability":选择概率最高的。
    • "largest":选择面积最大的框。
    • "largest_over_threshold":选择超过一定概率的最大框。
    • "center_weighted_size":基于框大小减去离图像中心加权距离平方后的结果进行选择。
  9. keep_all(默认值:False):如果设为True,则返回所有检测到的人脸,并按照select_largest参数设定的顺序排列。如果指定了保存路径,第一张人脸将被保存至该路径,其余人脸将依次保存为<save_path>1, <save_path>2等。

  10. device(默认值:None):运行神经网络前向传递时所使用的设备。图像张量和模型会在前向传递前复制到这个设备上。

3.2 识别图像数据

这块没有特殊要求,随便去网上下载,以下是我自己的识别对象数据:
在这里插入图片描述

3.3 人脸识别
  • 代码
from facenet_pytorch import MTCNN
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import os

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('the device is:{}'.format(device))

model = MTCNN(image_size=160, margin=0, min_face_size=10,
              thresholds=[0.7,0.7,0.7],factor=0.7, post_process=True, device=device)
path = os.path.abspath('face_img')  #在face_img文件夹下面还要再加一个class_folder文件夹
dataset = datasets.ImageFolder(path)
imgs_list = list(sorted(os.listdir(os.path.join(path,'class_folder'))))

def collate_fn(x):
    return x[0]

loader = DataLoader(dataset, collate_fn = collate_fn, num_workers=0)
index = 0
#detected_faces = []

for pic,_ in loader:
    aligned, confidence = model(pic , return_prob=True)
    if confidence is not None:
        print('Confidence of {} containing human face is {:.8f}'.format(imgs_list[index], confidence))
        detected_faces.append(aligned)
    else:
        print('No human face detected in {}'.format(imgs_list[index]))
    index += 1
    
# 以下是人脸对齐的还原实现
#face_numpy = (detected_faces[0] + 1) * 127.5  # 由于是 [-1, 1] 范围,将其映射到 [0, 255]
#face_numpy = face_numpy.numpy().astype(np.uint8)
#face_image = Image.fromarray(face_numpy.transpose(1, 2, 0)) # 将 Numpy 数组转为 PIL 图像格式,并注意调整通道顺序为 (H, W, C)

#plt.imshow(face_image)
#plt.show()
  • 输出
Confidence of art.png containing human face is 0.99512947
No human face detected in ironman.png
Confidence of man.png containing human face is 0.99643928
No human face detected in ogre.png
Confidence of thanos.png containing human face is 0.96726525
Confidence of woman.png containing human face is 0.99991846

可见MTCNN不认为钢铁侠和食人魔魔法师算“人脸”。MTCNN的输出有2部分:

  1. 对齐后的人脸张量:其范围是[-1, 1],可以将其线性还原到[0, 255]并输出对其后的人脸,例如下图:在这里插入图片描述
  2. 包含人脸的置信度:即上面的0.99512947等置信度数值。
3.4 关键点定位

也可以使用mtcnn.detect()得到人脸得关键点(眼睛、鼻子、嘴角)定位,代码如下:

from facenet_pytorch import MTCNN
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import os
import numpy
import matplotlib
import matplotlib.pyplot as plt

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('the device is:{}'.format(device))

model = MTCNN(image_size=160, margin=0, min_face_size=10,
              thresholds=[0.7,0.7,0.7],factor=0.7, post_process=True, device=device)

path = os.path.abspath('face_img')  #在face_img文件夹下面还要再加一个class_folder文件夹
dataset = datasets.ImageFolder(path)
imgs_list = list(sorted(os.listdir(os.path.join(path,'class_folder'))))

def collate_fn(x):
    return x[0]

loader = DataLoader(dataset, collate_fn = collate_fn, num_workers=0)
index = 0
detected_faces = []

for pic,_ in loader:
    aligned, confidence = model(pic , return_prob=True)
    if confidence is not None:
        print('Confidence of {} containing human face is {:.8f}'.format(imgs_list[index], confidence))
        detected_faces.append(aligned)
        boxes, probs, points = model.detect(pic, landmarks=True)
        points = points.squeeze(0)
        for x,y in points:
            plt.scatter(x,y,s=10,c='r')
        plt.imshow(pic)
        plt.savefig('{}_aligned.jpg'.format(imgs_list[index]))
        plt.close()
    else:
        print('No human face detected in {}'.format(imgs_list[index]))
    index += 1

最终保存的图像为:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

可以看出MTCNN的关键点定位也是很准确的。上面代码的boxs即为人脸边界框,这里不再画出效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/447565.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Qt学习笔记】(二)--第一个程序“Hello World”(学习Qt中程序的运行、发布、编译过程)

声明&#xff1a;本人水平有限&#xff0c;博客可能存在部分错误的地方&#xff0c;请广大读者谅解并向本人反馈错误。    因为我个人对Qt也是有一些需求&#xff0c;所以开设本专栏进行学习&#xff0c;希望大家可以一起学习&#xff0c;共同进步。   这篇博客将从一个 He…

算法刷题Day1 | 704.二分查找、27.移除元素

目录 0 引言1 二分查找1.1 我的解题1.2 修改后1.3 总结 2 移除元素2.1 暴力求解2.2 双指针法&#xff08;快慢指针&#xff09; &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏&#xff1a;算法专栏&#x1f4a5; 标题&#xff1a;代码随想录算法训练营第一天…

Vue.js+SpringBoot开发大病保险管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统配置维护2.2 系统参保管理2.3 大病保险管理2.4 大病登记管理2.5 保险审核管理 三、系统详细设计3.1 系统整体配置功能设计3.2 大病人员模块设计3.3 大病保险模块设计3.4 大病登记模块设计3.5 保险审核模块设计 四、…

MySQL三种日志

一、undo log&#xff08;回滚日志&#xff09; 1.作用&#xff1a; &#xff08;1&#xff09;保证了事物的原子性 &#xff08;2&#xff09;通过read view和undo log实现mvcc多版本并发控制 2.在事务提交前&#xff0c;记录更新前的数据到undo log里&#xff0c;回滚的时候读…

数据可视化助力林业智能管理

数据可视化是当下科技发展中的一项重要工具&#xff0c;它在各行各业都展现了强大的应用价值。在智慧林业领域&#xff0c;数据可视化更是发挥了独特的作用&#xff0c;为林业管理和生态保护提供了有效的支持和解决方案。下面我就以可视化从业者的角度&#xff0c;来简单聊聊这…

四节点/八节点四边形单元悬臂梁Matlab有限元编程 | 平面单元 | Matlab源码 | 理论文本

专栏导读 作者简介&#xff1a;工学博士&#xff0c;高级工程师&#xff0c;专注于工业软件算法研究本文已收录于专栏&#xff1a;《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现&#xff0c;并提供所有案例完整源码&#xff1b;2.单元…

关于yolov8的DFL模块(pytorch以及tensorrt)

先看代码 class DFL(nn.Module):"""Integral module of Distribution Focal Loss (DFL).Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391"""def __init__(self, c116):"""Initialize a convo…

嵌入式C语言(六)

对齐这个事情在内核中可不是个什么小事&#xff0c;内核中涉及到内存方面的都需要非常的谨慎。 上一篇我们知道了可以通过__attribute__来声明属性&#xff0c;也知道了section这个属性&#xff0c;这篇我们来看看关于内存对齐使用的两个属性–>aligned和packed 地址对齐&…

Altium Designer如何对走线模式进行切换

AD软件提供了比较智能的走线模式切换功能&#xff0c;可以根据个人习惯进行切换&#xff0c;能有效的提高了PCB设计效率。 点击界面右上角系统参数的图标 或者在pcb界面中使用快捷键OP进入到优选项界面&#xff0c;然后选中 PCB Editor-Interactive Routing&#xff0c;在布线…

ubuntu設定QGC獲取pixhawk Mini4(PX4 Mini 4) 的imu信息

ubuntu20.04 QGC使用v4.3.0的版本 飛控pixhawk Mini4 飛控上只使用一條micro USB連接電腦&#xff0c;沒有其他線 安裝命令 sudo apt-get remove modemmanager -y sudo apt install gstreamer1.0-plugins-bad gstreamer1.0-libav gstreamer1.0-gl -y sudo apt install libf…

Vue:纯前端实现文件拖拽上传

先看一下拖拽相关的事件&#xff1a;dragover、dragenter drop和dragleave 。 dragover事件&#xff1a;当被拖动的元素在一个可放置目标上方时&#xff0c;该事件会被触发。 通常&#xff0c;我们会使用event.preventDefault()方法来取消浏览器默认的拖放行为&#xff0c;以便…

amv是什么文件格式?如何播放amv视频?

AMV文件格式源自于中国公司Actions Semiconductor&#xff0c;最初作为其MP4播放器中使用的专有视频格式。产生于数码媒体发展的需求下&#xff0c;AMV格式为小屏幕便携设备提供了一种高度压缩的视频存储方案。 AMV文件格式的主要特性与使用场景 AMV格式以其独特的特性在小尺寸…

复合式统计图绘制方法(7)

复合式统计图绘制方法&#xff08;7&#xff09; 常用的统计图有条形图、柱形图、折线图、曲线图、饼图、环形图、扇形图。 前几类图比较容易绘制&#xff0c;饼图环形图绘制较难。 在统计图的应用方面&#xff0c;有时候有两个关联的统计学的样本值要用统计图来表达&#xff0…

运动想象 (MI) 迁移学习系列 (5) : SSMT

运动想象迁移学习系列:SSMT 0. 引言1. 主要贡献2. 网络结构3. 算法4. 补充4.1 为什么设置一种新的适配器&#xff1f;4.2 动态加权融合机制究竟是干啥的&#xff1f; 5. 实验结果6. 总结欢迎来稿 论文地址&#xff1a;https://link.springer.com/article/10.1007/s11517-024-0…

天府锋巢直播产业基地:直播带岗,成都直播基地奔向产业化

天府锋巢直播产业基地位于成都市天府新区科学城板块&#xff0c;是一座集直播带岗、电商孵化、产业培训、供应链整合等多功能于一体的现代化全域直播产业基地。近年来&#xff0c;随着成都直播产业的蓬勃发展&#xff0c;成都积极响应市场需求&#xff0c;致力于打造出西部地区…

linux进程间通信-共享内存

一、共享内存是什么 在Linux系统中&#xff0c;共享内存是一种IPC&#xff08;进程间通信&#xff09;方式&#xff0c;它可以让多个进程在物理内存中共享一段内存区域。 这种共享内存区域被映射到多个进程的虚拟地址空间中&#xff0c;使得多个进程可以直接访问同一段物理内存…

【Python可视化系列】一文教你绘制雷达图(源码)

这是我的第234篇原创文章。 一、引言 雷达图是以从同一点开始的轴上表示的三个或更多个定量变量的二维图表的形式显示多变量数据的图形方法&#xff0c;也称为蜘蛛图或星形图。雷达图通常用于综合分析多个指标&#xff0c;具有完整&#xff0c;清晰和直观的优点。通常由多个等…

Constrained Iterative LQR 自动驾驶中使用的经典控制算法

Motion planning 运动规划在自动驾驶领域是一个比较有挑战的部分。它既要接受来自上层的行为理解和决策的输出,也要考虑一个包含道路结构和感知所检测到的所有障碍物状态的动态世界模型。最终生成一个满足安全性和可行性约束并且具有理想驾驶体验的轨迹。 通常,motion plann…

遥感影像植被波谱特征总结

植被跟太阳辐射的相互关系有别于其他物质&#xff0c;如裸土、水体等&#xff0c;比如植被的“红边”现象&#xff0c;即在<700nm附近强吸收&#xff0c;>700nm高反射。很多因素影响植被对太阳辐射的吸收和反射&#xff0c;包括波长、水分含量、色素、养分、碳等。 研究…

Kubernetes--ingress实现七层负载

目录 一、传统方式&#xff1a;不借助ingress实现七层代理 二、nginx-ingress 三、使用ingress实现七层代理 四、部署ingrss-nginx及功能 五、样例 1.Ingress-nginx HTTP代理访问 2.Ingress HTTPS代理访问&#xff08;会话卸载层&#xff09; 3.Nginx进行BasicAuth&…