知识蒸馏实战代码教学二(代码实战部分)

一、上章原理回顾

具体过程:

        (1)首先我们要先训练出较大模型既teacher模型。(在图中没有出现)

        (2)再对teacher模型进行蒸馏,此时我们已经有一个训练好的teacher模型,所以我们能很容易知道teacher模型输入特征x之后,预测出来的结果teacher_preds标签。

        (3)此时,求到老师预测结果之后,我们需要求解学生在训练过程中的每一次结果student_preds标签。

        (4)先求hard_loss,也就是学生模型的预测student_preds与真实标签targets之间的损失。

        (5)再求soft_loss,也就是学生模型的预测student_preds与教师模型teacher_preds的预测之间的损失。

        (6)求出hard_loss与soft_loss之后,求和总loss=a*hard_loss + (1-a)soft_loss,a是一个自己设置的权重参数,我在代码中设置为a=0.3。

        (7)最后反向传播继续迭代。

二、代码实现

1、数据集

        数据集采用的是手写数字的数据集mnist数据集,如果没有下载,代码部分中会进行下载,只需要把download改成True,然后就会保存在当前目录中。该数据集将其分成80%的训练集,20%的测试集,最后返回train_dataset和test_datatset。

class MyDataset(Dataset):
    def __init__(self,opt):
        self.opt = opt

    def MyData(self):
        ## mnist数据集下载0
        mnist = datasets.MNIST(
            root='../datasets/', train=True, download=False, transform=transforms.Compose(
                [transforms.Resize(self.opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        )

        dataset_size = len(mnist)
        train_size = int(0.8 * dataset_size)
        test_size = dataset_size - train_size

        train_dataset, test_dataset = random_split(mnist, [train_size, test_size])

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.opt.batch_size,
            shuffle=True,
        )

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.opt.batch_size,
            shuffle=False,  # 在测试集上不需要打乱顺序
        )
        return train_dataloader,test_dataloader

2、teacher模型和训练实现

       (1) 首先是teacher模型构造,经过三次线性层。

import torch.nn as nn
import torch

img_area = 784

class TeacherModel(nn.Module):
    def __init__(self,in_channel=1,num_classes=10):
        super(TeacherModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area,1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, img_area)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

        (2)训练teacher模型

        老师模型训练完成后其权重参数会保存在teacher.pth当中,为以后调用。

import torch.nn as nn
import torch


## 创建文件夹
from tqdm import tqdm

from dist.TeacherModel import TeacherModel

weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/teacher.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速


class TeacherTrainer():
    def __init__(self,opt,train_dataloader,test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader

    def trainer(self):
        # 老师模型
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        teacher_model = TeacherModel()
        teacher_model = teacher_model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        for epoch in range(opt.n_epochs):  ## epoch:50
            teacher_model.train()

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                preds = teacher_model(data)
                loss = criterion(preds, targets)

                optimizer_teacher.zero_grad()
                loss = criterion(preds, targets)
                loss.backward()
                optimizer_teacher.step()

            teacher_model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = teacher_model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

            torch.save(teacher_model.state_dict(), weight_path)

        teacher_model.train()
        print('teacher: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

        (3)训练teacher模型

        模型参数都在paras()当中设置好了,直接调用teacher_model就行,然后将其权重参数会保存在teacher.pth当中。

import argparse

import torch

from dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def paras():
    ## 超参数配置
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # 训练Teacher模型
    teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)
    teacher_trainer.trainer()







 3、学生模型的构建

        学生模型也是经过了三次线性层,但是神经元没有teacher当中多。所以student模型会比teacher模型小很多。

import torch.nn as nn
import torch

img_area = 784

class StudentModel(nn.Module):
    def __init__(self,in_channel=1,num_classes=10):
        super(StudentModel,self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(img_area,20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)

    def forward(self, x):
        x = x.view(-1, img_area)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

4、知识蒸馏训练

(1)首先读取teacher模型。

        将teacher模型中的权重参数teacher.pth放入模型当中。

 #拿取训练好的模型
        teacher_model = TeacherModel()
        if os.path.exists(weights):
            teacher_model.load_state_dict(torch.load(weights))
            print('successfully')
        else:
            print('not loading')
        teacher_model = teacher_model.to(device)

(2)设置损失求解的函数

        hard_loss用的就是普通的交叉熵损失函数,而soft_loss就是用的KL散度。

        # hard_loss
        hard_loss = nn.CrossEntropyLoss()
        # hard_loss权重
        alpha = 0.3

        # soft_loss
        soft_loss = nn.KLDivLoss(reduction="batchmean")

(3)之后再进行蒸馏训练,温度为7

  •         先求得hard_loss就是用学生模型预测的标签和真实标签进行求得损失。
  •         再求soft_loss就是用学生模型预测的标签和老师模型预测的标签进行求得损失。使用softmax时候还需要进行除以温度temp。
  •         最后反向传播,求解模型
       for epoch in range(opt.n_epochs):  ## epoch:5

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                # 老师模型预测
                with torch.no_grad():
                    teacher_preds = teacher_model(data)

                # 学生模型预测
                student_preds = model(data)
                # 计算hard_loss
                student_loss = hard_loss(student_preds, targets)

                # 计算蒸馏后的预测损失
                ditillation_loss = soft_loss(
                    F.softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )

                loss = alpha * student_loss + (1 - alpha) * ditillation_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

        model.train()
        print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(4)整个蒸馏训练代码

import torch.nn as nn
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm

from dist.StudentModel import StudentModel
from dist.TeacherModel import TeacherModel

weights = 'C:/Users/26394/PycharmProjects/untitled1//dist/params/teacher.pth'

# D_weight_path = 'C:/Users/26394/PycharmProjects/untitled1/dist/params/distillation.pth'
## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.backends.cudnn.benchmark = True #使用卷积cuDNN加速


class DistillationTrainer():
    def __init__(self,opt,train_dataloader,test_dataloader):
        self.opt = opt
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader


    def trainer(self):
        opt = self.opt
        train_dataloader = self.train_dataloader
        test_dataloader = self.test_dataloader

        #拿取训练好的模型
        teacher_model = TeacherModel()
        if os.path.exists(weights):
            teacher_model.load_state_dict(torch.load(weights))
            print('successfully')
        else:
            print('not loading')
        teacher_model = teacher_model.to(device)
        teacher_model.eval()

        model = StudentModel()
        model = model.to(device)

        temp = 7

        # hard_loss
        hard_loss = nn.CrossEntropyLoss()
        # hard_loss权重
        alpha = 0.3

        # soft_loss
        soft_loss = nn.KLDivLoss(reduction="batchmean")

        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

        for epoch in range(opt.n_epochs):  ## epoch:5

            for data, targets in tqdm(train_dataloader):
                data = data.to(device)
                targets = targets.to(device)

                # 老师模型预测
                with torch.no_grad():
                    teacher_preds = teacher_model(data)

                # 学生模型预测
                student_preds = model(data)
                # 计算hard_loss
                student_loss = hard_loss(student_preds, targets)

                # 计算蒸馏后的预测损失
                ditillation_loss = soft_loss(
                    F.softmax(student_preds / temp, dim=1),
                    F.softmax(teacher_preds / temp, dim=1)
                )

                loss = alpha * student_loss + (1 - alpha) * ditillation_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            model.eval()
            num_correct = 0
            num_samples = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(device)
                    y = y.to(device)

                    preds = model(x)

                    predictions = preds.max(1).indices
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
                acc = (num_correct / num_samples).item()

        model.train()
        print('distillation: Epoch:{}\t Accuracy:{:.4f}'.format(epoch + 1, acc))

(5)蒸馏训练的主函数

        该部分大致与teacher模型训练类似,只是调用不同。

import argparse

import torch

from dist.DistillationTrainer import DistillationTrainer
from dist.MyDateLoader import MyDataset
from dist.TeacherTrainer import TeacherTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def paras():
    ## 超参数配置
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=5, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")
    opt = parser.parse_args()
    ## opt = parser.parse_args(args=[])                 ## 在colab中运行时,换为此行
    print(opt)
    return opt


if __name__ == '__main__':
    opt = paras()
    data = MyDataset(opt)
    train_dataloader, test_dataloader = data.MyData()

    # 训练Teacher模型
    # teacher_trainer = TeacherTrainer(opt,train_dataloader,test_dataloader)
    # teacher_trainer.trainer()

    distillation_trainer = DistillationTrainer(opt,train_dataloader,test_dataloader)
    distillation_trainer.trainer()





三、总结

        总的来说,知识蒸馏是一种有效的模型压缩技术,可以通过在模型训练过程中引入额外的监督信号来训练简化的模型,从而获得与大型复杂模型相近的性能,但具有更小的模型尺寸和计算开销。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/396801.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

每日一题——LeetCode1464.数组中两元素的最大乘积

这题就是找数组里的最大值和次大值 方法一 排序 var maxProduct function(nums) {nums.sort((a,b)>b-a)return (nums[0] - 1) * (nums[1] - 1); }; 消耗时间和内存情况: 方法二 一次遍历: var maxProduct function(nums) {let first-1,second-…

IO进程线程第三天

1.使用fread,fwrite完成两个文件之间的拷贝 #include <myhead.h> //使用fread,fwrite完成两个文件之间的拷贝 int main(int argc, const char *argv[]) {if(argc!3){printf("input file error\n");printf("usage:./a.out srcfile destfile\n");retu…

Eclipse - Makefile generation

Eclipse - Makefile generation References right mouse click on the project -> Properties -> C/C Build -> Generate Makefiles automatically 默认会在 Debug 目录下创建 Makefile 文件。 References [1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

智慧城市驿站:智慧公厕升级版,打造现代化城市生活的便捷配套

随着城市化进程的加速&#xff0c;人们对城市生活质量的要求也越来越高。作为智慧城市建设的一项重要组成部分&#xff0c;多功能城市智慧驿站应运而生。它集合了信息技术、设计美学、结构工艺、系统集成、环保节能等多个亮点&#xff0c;将现代科技与城市生活相融合&#xff0…

文献速递:GAN医学影像合成--联邦生成对抗网络基础医学图像合成中的后门攻击与防御

文献速递&#xff1a;GAN医学影像合成–联邦生成对抗网络基础医学图像合成中的后门攻击与防御 01 文献速递介绍 虽然深度学习在医疗保健研究中产生了显著影响&#xff0c;但其在医疗保健领域的影响无疑比在其他应用领域更慢、更有限。造成这种情况的一个重要原因是&#xff…

uniapp不同平台获取文件内容以及base64编码特征

前言 文件图片上传&#xff0c;客户端预览是很正常的需求&#xff0c;获取文件的md5特征码也是很正常的&#xff0c;那么&#xff0c;在uniapp中三种环境&#xff0c;h5, 小程序以及 app环境下&#xff0c;如何实现的&#xff1f; 参考&#xff1a; 如何在uniapp中读取文件Arr…

自动驾驶TPM技术杂谈 ———— RFID(Radio Frequency IDentification)技术

文章目录 介绍问题挑战新机遇基于RFID 的绑定式感知基于标签信号物理模型基于标签能量耦合变化基于信号变化模式匹配 基于RFID 的非绑定式感知基于标签电感耦合基于反射信号模型基于信号模式匹配 基于RFID 的混合式感知 介绍 马克维瑟&#xff08;Mark Weiser&#xff09;在199…

vue+element (el-progress)标签 隐藏百分比(%) ,反向显示 ,自定义颜色, demo 复制粘贴拿去用

1 效果: 2 页面代码: <el-row :gutter"10" ><el-col :span"12"><el-card ><div class"fourqu"><div><span slot"title">{{推送任务TOP5}}</span></div></div><div class&…

Maven(基础)、MyBatis

简介 Apache Maven是一个项目管理和构建工具&#xff0c;它基于项目对象模型 (POM)的概念&#xff0c;通过一小段描述信息来管理项目的构建、报告和文档 官网: http://maven.apache.org/ Maven作用 Maven是专门用于管理和构建Java项目的工具&#xff0c;它的主要功能有&#x…

Vue 实现当前页的刷新

Vue 在缓存的基础上实现当前页的刷新 前进刷新&#xff0c;后退不刷新 一、Bus的实现 Bus.js 二、利用Bus实现不同页面的事件传播 1.引入Bus.js&#xff08;传递&#xff09;例如&#xff1a;A页面 2.引入Bus.js&#xff08;接收&#xff09;例如&#xff1a;B页面 3.路由组件设…

rocketMQ-Dashboard安装与部署

1、下载最新版本rocketMQ-Dashboard 下载地址&#xff1a;https://github.com/apache/rocketmq-dashboard 2、下载后解压&#xff0c;并用idea打开 3、修改配置 ①、修改端口及rocketmq服务的ip&#xff1a;port ②、修改访问账号、密码 3、然后启动访问&#xff1a; 4、mav…

WEB基础及HTTP协议概念

目录 一、http概述 1、http的相关概念名词 2、访问浏览器的过程 3、http 协议通信过程 4、扩展网络通信 二、http相关技术 1、web开发语言 2、URL 和 URN 3、URL的组成 4、MIME ​5、网站访问量 6、http协议版本及区别 7、http请求访问的七大过程 8、HTTP工作机制…

【数据结构题目讲解】洛谷P4219 大融合

P4219 大融合 D e s c r i p t i o n \mathrm{Description} Description 给定 1 1 1 棵 n n n 个节点的树&#xff0c;树的边是在操作中加入的&#xff0c;接下来有 m m m 次操作&#xff1a; 将 x x x 与 y y y 之间连一条边查询 x x x 与 y y y 之间这条边有多少条经…

Redis 缓存(Cache)

什么是缓存 缓存(cache)是计算机中的一个经典的概念在很多场景中都会涉及到。 核心思路就是把一些常用的数据放到触手可及(访问速度更快)的地方&#xff0c;方便随时读取。 这里所说的“触手可及”是个相对的概念 我们知道&#xff0c;对于硬件的访问速度来说&#xff0c;通常…

C++-带你初步走进继承(1)

1.继承的概念及定义 1.1继承的概念 继承 (inheritance) 机制是面向对象程序设计 使代码可以复用 的最重要的手段&#xff0c;它允许程序员在 保 持原有类特性的基础上进行扩展 &#xff0c;增加功能&#xff0c;这样产生新的类&#xff0c;称派生类。继承 呈现了面向对象 …

前端新手Vue3+Vite+Ts+Pinia+Sass项目指北系列文章 —— 第十一章 基础界面开发 (组件封装和使用)

前言 Vue 是前端开发中非常常见的一种框架&#xff0c;它的易用性和灵活性使得它成为了很多开发者的首选。而在 Vue2 版本中&#xff0c;组件的开发也变得非常简单&#xff0c;但随着 Vue3 版本的发布&#xff0c;组件开发有了更多的特性和优化&#xff0c;为我们的业务开发带…

如何搭建一款论坛系统?简单介绍多功能论坛系统。

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 前言 论坛系统简单介绍就是&#xff1a;跟微博类似的app系统&#xff0c;粉丝用户可以很好…

Web服务器基础

Web服务器基础 【一】前端概述 【1】HTML HTML&#xff08;超文本标记语言&#xff09;是用于创建网页结构的标记语言。它定义了网页的骨架&#xff0c;包括标题、段落、列表、链接等元素&#xff0c;但没有样式。可以将HTML视为网页的结构和内容的描述。 【2】CSS css&…

ARM体系在linux中的中断抢占

上一篇说到系统调用等异常通过向量el1_sync做处理&#xff0c;中断通过向量el1_irq做处理&#xff0c;然后gic的工作都是为中断处理服务&#xff0c;在rtos中&#xff0c;我们一般都会有中断嵌套和优先级反转的概念&#xff0c;但是在linux中&#xff0c;中断是否会被其他中断抢…

MybatisPlus创建时间不想用默认值

我们知道&#xff0c;MybatisPlus可以给一些字段设置默认值&#xff0c;比如创建时间&#xff0c;更新时间&#xff0c;分为插入时设置&#xff0c;和更新时设置。 常见的例子&#xff1a; /*** 创建时间*/ JsonFormat(shape JsonFormat.Shape.STRING, pattern"yyyy-MM…