2024Datawhale AI夏令营---Inclusion・The Global Multimedia Deepfake Detection--学习笔记

赛题背景:

        其实总结起来就是一句话,这个项目是基于目前的深度伪装技术,就是通过大量人脸的原数据集进行模型训练之后,能够生成伪造的人脸视频。这项目就是教我们如何去实现这个DeepFake技术。

Task1:了解Deepfake和跑通baseline

代码架构如下:

  1. 模型定义:使用timm库创建一个预训练的resnet18模型。

  2. 训练/验证数据加载:使用torch.utils.data.DataLoader来加载训练集和验证集数据,并通过定义的transforms进行数据增强。

  3. 训练与验证过程

    1. 定义了train函数来执行模型在一个epoch上的训练过程,包括前向传播、损失计算、反向传播和参数更新。

    2. 定义了validate函数来评估模型在验证集上的性能,计算准确率。

  4. 性能评估:使用准确率(Accuracy)作为性能评估的主要指标,并在每个epoch后输出验证集上的准确率。

  5. 提交:最后,将预测结果保存到CSV文件中,准备提交到Kaggle比赛。

代码解释如下:

详见代码注释吧

这份代码后续还是要好好精读理解一下的吧,好好分析一下,顺便提升一下代码能力。--7.15

from PIL import Image
Image.open('/kaggle/input/deepfake/phase1/trainset/63fee8a89581307c0b4fd05a48e0ff79.jpg')

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import timm
import time

import pandas as pd
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm_notebook

train_label = pd.read_csv('/kaggle/input/deepfake/phase1/trainset_label.txt')
val_label = pd.read_csv('/kaggle/input/deepfake/phase1/valset_label.txt')

train_label['path'] = '/kaggle/input/deepfake/phase1/trainset/' + train_label['img_name']
val_label['path'] = '/kaggle/input/deepfake/phase1/valset/' + val_label['img_name']

train_label['target'].value_counts()

val_label['target'].value_counts()

train_label.head(10)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, *meters):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = ""


    def pr2int(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
            losses.update(loss.item(), input.size(0))
            top1.update(acc, input.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f}'
              .format(top1=top1))
        return top1

def predict(test_loader, model, tta=10):
    # switch to evaluate mode
    model.eval()
    
    test_pred_tta = None
    for _ in range(tta):
        test_pred = []
        with torch.no_grad():
            end = time.time()
            for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):
                input = input.cuda()
                target = target.cuda()

                # compute output
                output = model(input)
                output = F.softmax(output, dim=1)
                output = output.data.cpu().numpy()

                test_pred.append(output)
        test_pred = np.vstack(test_pred)
    
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred
    
    return test_pred_tta

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))

        acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100
        top1.update(acc, input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.pr2int(i)

class FFDIDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        
        if self.transform is not None:
            img = self.transform(img)
        
        return img, torch.from_numpy(np.array(self.img_label[index]))
    
    def __len__(self):
        return len(self.img_path)

import timm
model = timm.create_model('resnet18', pretrained=True, num_classes=2)
model = model.cuda()

train_loader = torch.utils.data.DataLoader(
    FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000), 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomVerticalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)

val_loader = torch.utils.data.DataLoader(
    FFDIDataset(val_label['path'].head(1000), val_label['target'].head(1000), 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)

criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), 0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
best_acc = 0.0
for epoch in range(2):
    scheduler.step()
    print('Epoch: ', epoch)

    train(train_loader, model, criterion, optimizer, epoch)
    val_acc = validate(val_loader, model, criterion)
    
    if val_acc.avg.item() > best_acc:
        best_acc = round(val_acc.avg.item(), 2)
        torch.save(model.state_dict(), f'./model_{best_acc}.pt')

test_loader = torch.utils.data.DataLoader(
    FFDIDataset(val_label['path'], val_label['target'], 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)

val_label['y_pred'] = predict(test_loader, model, 1)[:, 1]
val_label[['img_name', 'y_pred']].to_csv('submit.csv', index=None)

        本来是想直接在本地运行的,但是这个数据集实在是太大了,受限于操作和设备,只能在kaggle云运行这个代码咯,结果如下:

        提交上kaggle进行评分:

Inclusion・The Global Multimedia Deepfake Detection | Kaggle

        结果挺差的,毕竟这就是个普通的原始代码,啥都没优化的,参数也没调,只能先这样咯,后续task再来优化调整咔咔上分吧。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/802129.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python项目部署到Linux生产环境(uwsgi+python+flask+nginx服务器)

1.安装python 我这里是3.9.5版本 安装依赖: yum install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gcc make -y 根据自己的需要下载对应的python版本: cd local wget https://www.python.org/ftp…

开发实战经验分享:互联网医院系统源码与在线问诊APP搭建

作为一名软件开发者,笔者有幸参与了多个互联网医院系统的开发项目,并在此过程中积累了丰富的实战经验。本文将结合我的开发经验,分享互联网医院系统源码的设计与在线问诊APP的搭建过程。 一、需求分析 在开发任何系统之前,首先要…

成像光谱遥感技术中的AI革命:ChatGPT

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境,是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型,在理解和生成人类语言方面表现出了非凡的能力,ChatGPT在遥感中的应用,人工智能在…

【STM32】RTT-Studio中HAL库开发教程三:IIC通信--AHT20

文章目录 一、I2C总线通信协议二、AHT20传感器介绍三、STM32CubeMX配置硬件IIC四、RTT中初始化配置五、具体实现代码六、实验现象 一、I2C总线通信协议 使用奥松的AHT20温湿度传感器,对环境温湿度进行采集。AHT20采用的是IIC进行通信,可以使用硬件IIC或…

2. KNN分类算法与鸢尾花分类任务

鸢尾花分类任务 1. 鸢尾花分类步骤1.1 分析问题,搞定输入和输出1.2 每个类别各采集50朵花1.3 选择一种算法,完成输入到输出的映射1.4 第四步:部署,集成 2. KNN算法原理2.1 基本概念2.2 核心理念2.3 训练2.4 推理流程 3. 使用 skle…

Word参考文献交叉引用

前言 Word自带交叉引用功能,可在正文位置引用文档内自动编号的段落,同时创建超链接,适用于参考文献的引用。使用此方法对参考文献进行引用后,当参考文献的编号发生变化时,只需要更新域即可与正文中的引用相对应。下文…

vue3+TS从0到1手撸后台管理系统

1.路由配置 1.1路由组件的雏形 src\views\home\index.vue(以home组件为例) 1.2路由配置 1.2.1路由index文件 src\router\index.ts //通过vue-router插件实现模板路由配置 import { createRouter, createWebHashHistory } from vue-router import …

【15】Android基础知识之Window(一)

概述 这篇文章纠结了很久,在想需要怎么写?因为window有关的篇幅,如果需要讲起来那可太多了。从层级,或是从关联,总之不是很好开口。这次也下定决心,决定从浅入深的讲讲window这个东西。 Window Window是…

鸿蒙特色物联网实训室

一、 引言 在当今这个万物皆可连网的时代,物联网(IoT)正以前所未有的速度改变着我们的生活和工作方式。它如同一座桥梁,将实体世界与虚拟空间紧密相连,让数据成为驱动决策和创新的关键力量。随着物联网技术的不断成熟…

Qt Creator的好用的功能

(1)ctrlf: 在当前文档进行查询操作 (2)f3: 找到后,按f3,查找下一个 (3)shiftf3: 查找上一个 右键菜单: (4)f4:在…

solidity实战练习3——荷兰拍卖

//SPDX-License-Identifier:MIT pragma solidity ^0.8.24; interface IERC721{function transFrom(address _from,address _to,uint nftid) external ; }contract DutchAuction { address payable immutable seller;//卖方uint immutable startTime;//拍卖开始时间uint immut…

钡铼Modbus TCP耦合器BL200实现现场设备与SCADA无缝对接

前言 深圳钡铼技术推出的Modbus TCP耦合器为SCADA系统与现场设备之间的连接提供了强大而灵活的解决方案,它不仅简化了设备接入的过程,还提升了数据传输的效率和可靠性,是工业自动化项目中不可或缺的关键设备。本文将从Modbus TC、SCADA的简要…

基于Ubuntu2204搭建openstack-Y版-手动搭建

openstack手搭Y版 基础环境配置离线环境时间同步(双节点)安装openstack客户端数据库服务消息队列服务缓存服务 keystone服务部署glance服务部署placement服务部署nova服务部署controllercompute neutron服务部署controller节点配置neutron.conf文件配置m…

leetcode-349.两个数组的交集

题源 349.两个数组的交集 题目描述 给定两个数组 nums1 和 nums2 ,返回 它们的 交集 。输出结果中的每个元素一定是 唯一 的。我们可以 不考虑输出结果的顺序 。 示例 1: 输入:nums1 [1,2,2,1], nums2 [2,2] 输出:[2] 示例…

权威认可 | 海云安开发者安全助手系统通过信通院支撑产品功能认证并荣获信通院2024年数据安全体系建设优秀案例

近日,2024全球数字经济大会——数字安全生态建设专题论坛(以下简称“论坛”)在京成功举办。由全球数字经济大会组委会主办,中国信息通信研究院及公安部第三研究所共同承办,论坛邀请多位专家和企业共同参与。 会上颁发…

android预置apk

在framework开发中,有一些需求是需要预装应用的,有些是预置应用源码,有些是预置apk。今天我们就分享下怎样预置apk 一般系统有自定义的目录,比如我的项目中根目录下有一个文件夹vendor,这里没都是自定义的一些功能。预…

Redis系列命令更新--Redis列表命令

Redis列表 1、Redis Blpop命令: (1)说明:Redis Blpop命令移出并获取列表的第一个元素;如果列表没有元素会阻塞列表直到等到超时或发现可弹出元素为止 (2)语法:redis 127.0.0.1:63…

leetcode-三数之和

视频:https://www.bilibili.com/video/BV1bP411c7oJ/?spm_id_from333.788&vd_sourcedd84879fcf1be72f360461b01ecab0d6 从两数之和开始,排序后的两数之和,利用好升序的性质,可以将时间复杂度从on2降到on; class Solution …

「AI得贤招聘官」通过首批“AI产业创新场景应用案例”评估

近日,上海近屿智能科技有限公司的「AI得贤招聘官」,经过工业和信息化部工业文化发展中心数字科技中心的严格评估,荣获首批“AI产业创新场景应用案例”。 据官方介绍,为积极推进通用人工智能产业高质量发展,围绕人工智能…

自适应简约大气科技数码产品公司网站源码系统 模版一键搭建 可自定义带源代码包以及搭建部署教程

系统概述 在当今这个数字化、信息化的时代,科技数码产品行业正处于高速发展的黄金时期。为了在这个竞争激烈的市场中脱颖而出,科技数码产品公司不仅需要拥有卓越的产品和技术,还需要一个能够完美展现其品牌形象和产品特色的网站。为此&#…