P6打卡—Pytorch实现人脸识别

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.检查GPU

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision

device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

2.查看数据

import os,PIL,random,pathlib
data_dir = pathlib.Path('data/48-data')
data_dir=pathlib.Path(data_dir)
data_path=list(data_dir.glob("*"))
ClassNames=[str(path).split('\\')[2] for path in data_path]
ClassNames

​​​​​

3.划分数据集

train_trainsforms=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.486,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
]
)
total_data=datasets.ImageFolder("data/48-data",transform=train_trainsforms)
total_data

total_data.class_to_idx

train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_dataset,test_dataset=torch.utils.data.random_split(total_data,(train_size,test_size))
train_dataset,test_dataset

batch_size=32
train_dl=torch.utils.data.DataLoader(train_dataset,batch_size,shuffle=True,num_workers=1)
test_dl=torch.utils.data.DataLoader(test_dataset,batch_size,shuffle=True,num_workers=1)

for X,y in train_dl:
    print(X.shape)
    print(y.shape)
    break

​​

4.调用官方模型

from torchvision.models import vgg16
print("Using {} device".format(device))
model=vgg16(pretrained=True).to(device)
for param in model.parameters():
    param.requires_grad=False
model.classifier._modules['6']=nn.Linear(4096,len(ClassNames))
model.to(device)
model


​​​

5.动态调整学习率函数

#调用官方动态学习率接口
learning_rate = 1e-4
lambda1=lambda epoch:0.92**(epoch//4)
optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)
scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1)

6.编译及训练模型

def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset)
    num_batches=len(dataloader)
    train_loss,train_acc=0,0
    for X,y in dataloader:
        X,y =X.to(device),y.to(device)
        pred=model(X)
        loss=loss_fn(pred,y)
        #反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss+=loss.item()
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
    train_acc/=size
    train_loss/=num_batches
    return train_acc,train_loss

def test(dataloader,model,loss_fn):
    size=len(dataloader.dataset)
    num_batches=len(dataloader)
    test_loss,test_acc=0,0
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target=imgs.to(device),target.to(device)
            target_pred=model(imgs)
            loss=loss_fn(target_pred,target)
            test_loss+=loss.item()
            test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()
    test_acc/=size
    test_loss/=num_batches
    return test_acc,test_loss

import copy
loss_fn=nn.CrossEntropyLoss()
epochs=40
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0
for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
    #更新学习率
    scheduler.step()
    model.eval()
    epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)

    if epoch_test_acc>=best_acc:
        best_acc=epoch_test_acc
        best_model=copy.deepcopy(model)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    lr=optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
PATH='./best_model.pth'
torch.save(best_model.state_dict(),PATH)
print('Finished Training')

​​​​​​

7.结果可视化

import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
plt.rcParams['figure.dpi']=100

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

​​

8.预测本地图片

from PIL import Image
classes=list(total_data.class_to_idx)
def predict_one_image(image_path,model,transform,classes):
    test_img=Image.open(image_path).convert('RGB')
    plt.imshow(test_img)
    test_img=transform(test_img)
    img=test_img.to(device).unsqueeze(0)
    model=model.eval()
    output=model(img)
    _,pred=torch.max(output,1)
    pred_class=classes[pred]
    print('预测结果是:{pred_class}')

predict_one_image(image_path='data/48-data/Angelina Jolie/005_582c121a.jpg',
                  model=model,
                  transform=train_trainsforms,
                  classes=classes)
#查看最优损失及准确率
best_model.eval()
epoch_test_Acc,epoch_test_loss=test(test_dl,best_model,loss_fn)
epoch_test_Acc,epoch_test_loss

​​

​​总结:

1.VGG-16

VGG-16(Visual Geometry Group-16)是由牛津大学视觉几何组(Visual Geometry Group)提出的一种深度卷积神经网络架构,用于图像分类和对象识别任务。VGG-16在2014年被提出,是VGG系列中的一种。VGG-16之所以备受关注,是因为它在ImageNet图像识别竞赛中取得了很好的成绩,展示了其在大规模图像识别任务中的有效性。

以下是VGG-16的主要特点:

  1. 深度:VGG-16由16个卷积层和3个全连接层组成,因此具有相对较深的网络结构。这种深度有助于网络学习到更加抽象和复杂的特征。
  2. 卷积层的设计:VGG-16的卷积层全部采用3x3的卷积核和步长为1的卷积操作,同时在卷积层之后都接有ReLU激活函数。这种设计的好处在于,通过堆叠多个较小的卷积核,可以提高网络的非线性建模能力,同时减少了参数数量,从而降低了过拟合的风险。
  3. 池化层:在卷积层之后,VGG-16使用最大池化层来减少特征图的空间尺寸,帮助提取更加显著的特征并减少计算量。
  4. 全连接层:VGG-16在卷积层之后接有3个全连接层,最后一个全连接层输出与类别数相对应的向量,用于进行分类。

VGG-16结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示;
  • 3个全连接层(Fully connected Layer),用classifier表示;
  • 5个池化层(Pool layer)。

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

2.设置动态学习率

#非官方设置动态学习率
def adjust_learning_rate(optimizer, epoch, start_lr):
     # 每 2 个epoch衰减到原来的 0.98
     lr = start_lr * (0.92 ** (epoch // 2))
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr

learn_rate = 1e-4 # 初始学习率
optimizer  = torch.optim.SGD(model.parameters(), lr=learn_rate)

#官方设置动态学习率
# 调用官方动态学习率接口时使用
lambda1 = lambda epoch: 0.92 ** (epoch // 4)
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) #选定调整方法

model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/940350.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Electronjs+Vue如何开发PC桌面客户端(Windows,Mac,Linux)

electronjs官网 https://www.electronjs.org/zh/ Electron开发PC桌面客户端的技术选型非常适合已经有web前端开发人员的团队。能够很丝滑的过渡。 Electron是什么? Electron是一个使用 JavaScript、HTML 和 CSS 构建桌面应用程序的框架。 嵌入 Chromium 和 Node.…

内旋风铣削知识再学习

最近被有不少小伙伴们问到蜗杆加工的一种方式——内旋风铣削加工。关于旋风铣之前出过一篇《什么是旋风铣?》,简要介绍了旋风铣(Whilring)的一些基本内容。本期再重新仔细聊一聊内旋风这种加工方式,可加工的零件种类&a…

centos7下docker 容器实现redis主从同步

1.下载redis 镜像 docker pull bitnami/redis2. 文件夹授权 此文件夹是 你自己映射到宿主机上的挂载目录 chmod 777 /app/rd13.创建docker网络 docker network create mynet4.运行docker 镜像 安装redis的master -e 是设置环境变量值 docker run -d -p 6379:6379 \ -v /a…

基于Spring Boot的远程教育网站

一、系统背景与意义 随着互联网技术的飞速发展和普及,远程教育已成为现代教育体系中的重要组成部分。它打破了时间和空间的限制,让学习者可以随时随地进行学习。基于Spring Boot的远程教育网站正是为了满足这一需求而设计的,它利用互联网技术…

cf补题日记4

进度:6/40 我觉得我的思维还是太差了,多练思维题吧!!!!(燃起来 简直是思维题b题专题了,现在连b都做不出了吗(悲 原题1: Cats are attracted to pspspsps, …

WPF Binding 绑定

绑定是 wpf 开发中的精髓,有绑定才有所谓的数据驱动。 1 . 背景 目前 wpf 界面可视化的控件,继承关系如下, 控件的数据绑定,基本上都要借助于 FrameworkElement 的 DataContext 属性。 只有先设置了控件的 DataContext 属性&…

datasets笔记:两种数据集对象

Datasets 提供两种数据集对象:Dataset 和 ✨ IterableDataset ✨。 Dataset 提供快速随机访问数据集中的行,并支持内存映射,因此即使加载大型数据集也只需较少的内存。IterableDataset 适用于超大数据集,甚至无法完全下载到磁盘或…

【Python】【数据分析】深入探索 Python 数据可视化:Seaborn 可视化库详解

目录 引言一、Seaborn 简介二、安装 Seaborn三、Seaborn 的基本图形3.1 散点图(Scatter Plot)3.2 线图(Line Plot)3.3 条形图(Bar Plot)3.4 箱型图(Box Plot)3.5 小提琴图&#xff0…

如何构建机器学习数据集

1. 常见数据集网站 论文开源代码/数据集:Paperswithcodes 竞赛数据集:Kaggle Dataset 数据集搜索工具:Google Dataset Search HuggingFace:Hugging Face 魔塔:Model Scope 开源工具包自带:Pytorch, tensor…

EMQX V5 使用API 密钥将客户端踢下线

在我们选用开源的EMQX作为mqtt broker,我们可能会考虑先让客户端连接mqtt broker成功,再去校验客户端的有效性,当该客户端认证失败,再将其踢下线。例如:物联网设备连接云平台时,我们会将PK、PS提前烧录到设…

Ubuntu搭建ES8集群+加密通讯+https访问

目录 写在前面 一、前期准备 1. 创建用户和用户组 2. 修改limits.conf文件 3. 关闭操作系统swap功能 4. 调整mmap上限 二、安装ES 1.下载ES 2.配置集群间安全访问证书密钥 3.配置elasticsearch.yml 4.修改jvm.options 5.启动ES服务 6.修改密码 7.启用外部ht…

电子电气架构---基于PREEvision的线束设计工作流程优化

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 所谓鸡汤,要么蛊惑你认命,要么怂恿你拼命,但都是回避问题的根源,以现象替代逻辑,以情绪代替思考,把消极接受现实的懦弱,伪装成乐观面对不幸的…

【活动邀请·深圳】深圳COC社区 深圳 AWS UG 2024 re:Invent re:Cap

re:Invent 是全球云计算领域的顶级盛会,每年都会吸引来自世界各地的技术领袖、创新者和实践者汇聚一堂,分享最新的技术成果和创新实践,深圳 UG 作为亚马逊云科技技术社区的重要组成部分,将借助 re:Invent 的东风,举办此…

从零搭建纯前端飞机大战游戏(附源码)

目录 前言 一、游戏概览与技术选型 二、HTML 结构搭建和CSS样式美化 三、JavaScript 核心逻辑 1.变量声明与初始化 2.玩家飞机控制函数 3.射击与子弹管理函数 4.敌机生成与管理函数 5.碰撞检测与得分更新函数 6.游戏主循环与启动函数 四、完整代码 前言 在前端开发的…

【MAC】深入浅出 Homebrew 下 Nginx 的安装与配置指南

硬件:Apple M4 Pro 16寸 系统: macos Sonoma 15.1.1 Nginx 是一款高性能的 Web 服务器和反向代理服务器,广泛应用于全球各地的网站和企业应用中。本文将详细介绍如何在 macOS 环境下使用 Homebrew 安装、启动、管理以及优化配置 Nginx&#x…

简单了解图注意力机制

简单了解图注意力机制 如果对传统的图匹配的聚合方式进行创新的话,也就是对h这一个节点的聚合方式进行创新。 h i ( l 1 ) Norm ⁡ ( σ ( h i ( l ) α ∥ h i ( l ) ∥ m i ( l ) ∥ m i ( l ) ∥ ) ) , \mathbf{h}_{i}^{(l1)}\operatorname{Norm}\left(\sigm…

aosp15 - App冷启动

纸上得来终觉浅,绝知此事要躬行。 —— [宋]陆游 基于aosp_cf_x86_64_phone-trunk_staging-eng , 下面是具体断点位置。 第一部分,桌面launcher进程 com.android.launcher3.touch.ItemClickHandler onClickonClickAppShortcutstartAppShor…

arcgisPro相接多个面要素转出为完整独立线要素

1、使用【面转线】工具,并取消勾选“识别和存储面邻域信息”,如下: 2、得到的线要素,如下:

树莓派4B 搭建openwrt内置超多插件docker,nas等等使用教程

刷入固件 (想要固件的加我vx wyy7293) bleachwrt-plus-20241112-bcm27xx-bcm2711-rpi-4-squashfs-factory.img上电,并且把网线两头分别插在pi网口上和电脑的网口上(电脑必须断网) 等待网口灯亮,进入192.168.1.1 默认账密 root password 进入系统后更改openwrt的网关地址相关…

Java开发经验——数据库开发经验

摘要 本文主要介绍了Java开发中的数据库操作规范,包括数据库建表规范、索引规约、SQL规范和ORM规约。强调了在数据库设计和操作中应遵循的最佳实践,如字段命名、数据类型选择、索引创建、SQL语句编写和ORM映射,旨在提高数据库操作的性能和安…