python零基础实现基于旋转特征的自监督学习(二)——在resnet18模型下应用自监督学习

系列文章目录

基于旋转特征的自监督学习(一)——算法思路解析以及数据集读取
基于旋转特征的自监督学习(二)——在resnet18模型下应用自监督学习


模型搭建与训练

  • 系列文章目录
  • 前言
  • resNet18
    • Residual
    • resnet_block
    • resNet18
    • select_model
  • 模型训练
    • 损失函数与精度
    • 批量训练
    • 训练参数
  • 效果对比
  • 总结


前言

在本系列的上一篇文章中,我们介绍了如何对数据加载器进行修改来构建适合预基于特征旋转的自监督学习使用的数据集,在本篇文章中,我们将构建一个简易的深度学习模型——resnet18作为测试模型作为案例,在resnet18上我们进行训练,以及效果的对比。

代码地址:https://github.com/AiXing-w/little-test-for-FeatureLearningRotNet

resNet18

为了能够尽快看到结果,我们将搭建resNet18网络模型进行测试,这里我们将resnet18的搭建分为三个部分
residualresnet_block、和resnet18

其中resnet18是我们需要构建的模型,在之前的文章中我们曾经介绍过resnet模型的结构
可参考:常用结构化CNN模型构建

由于resnet18的模型结构是结构化的,为了方便表达,我们将模型分成多个块,即resnet_block;resnet_block也可以分成多个残差块,即residual。

Residual

在这里插入图片描述

Residual的实际就是两个卷积网络再加一个残差边,经过残差边的特征与经过卷积的特征相加再进行输出。

class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)

        else:
            self.conv3 = None

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)

        Y += X
        return F.relu(Y)

resnet_block

使用resnet_block实现残差块(residual)的拼接,他需要传入三个参数:输入通道、输出通道、总共层数。一个可选参数first_block默认为False,用于判断是否为第一个resnet块。

def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))

    return blk

resNet18

使用resnet_block构建resnet18网络。

def resNet18(in_channels, num_classes):
    b1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
                       nn.BatchNorm2d(64), nn.ReLU(),
                       nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                       )

    b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))

    b3 = nn.Sequential(*resnet_block(64, 128, 2))

    b4 = nn.Sequential(*resnet_block(128, 256, 2))

    b5 = nn.Sequential(*resnet_block(256, 512, 2))

    net = nn.Sequential(b1, b2, b3, b4, b5,
                        nn.AdaptiveAvgPool2d((1, 1)),
                        nn.Flatten(),
                        nn.Linear(512, num_classes)
                        )

    return net

select_model

为了后边可以拓展新的模型,所以我们构建一个select_model函数来用于选择模型

def select_model(model_name=str, kwargs=dict):
    # 选择模型
    if model_name == 'resNet18':
        net = resNet18(kwargs['in_channels'], kwargs['num_classes'])


    if os.path.exists("./model_weights/{}.pth".format(model_name)):
        net.load_state_dict(torch.load("./model_weights/{}.pth".format(model_name)))
        print("model wieghts loaded")

    return net

模型训练

损失函数与精度

Focal loss
如果仅使用交叉熵(cross entropy)的话有可能某些较为容易被学会的类别越学越好,然而其他类别越学越差。或者会受到数据不均衡的影响。所以我们将使用focal loss,可以抑制当时学习较好的样本,促进当时学习较差的样本。

import torch
from torch import nn

def Focal_loss(pred, target,  alpha=0.5, gamma=2):
    logpt = -nn.CrossEntropyLoss(reduction='none')(pred, target)
    pt = torch.exp(logpt)
    if alpha is not None:
        logpt *= alpha
    loss = -((1 - pt) ** gamma) * logpt
    loss = loss.mean()
    return loss

accuracy
我们使用accuracy函数获取精度,他需要传入两个参数,y_hat和y。其中y_hat是预测结果,y是标签。
我们通过argmax获取y_hat中最大概率的预测,与y做比较。这个函数本质上就是用来评价有多少预测正确。

def accuracy(y_hat, y):
    # 预测精度
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = torch.argmax(y_hat, axis=1)

    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum()) / len(y)

批量训练

这个函数参数比较复杂

参数含义
net模型
train_iter训练数据
test_iter测试数据
start起始轮次(开始训练的轮次)
num_epochs总轮次
lr学习率
device选择在cuda或cpu下训练
threshold阈值(满足阈值设置时提前结束训练)
save_checkpoint选择是否按照检查点保存模型权值
save_steps每隔几个轮次保存一次模型权值
model_name保存权值时模型的名称

先使用Xavier初始化权值然后训练和测试模型即可,这里使用history来保存训练损失、训练精度、测试损失、测试精度

def train(net, train_iter, test_iter, start, num_epochs, lr, device, threshold, save_checkpoint=False, save_steps=50, model_name="rotation"):
    # 训练模型
    def init_weights(m):
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    print("device in : ", device)
    net = net.to(device)

    loss = nn.CrossEntropyLoss()
    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    for epoch in range(start, num_epochs):
        net.train()
        train_loss = 0.0
        train_acc = 0.0
        data_num = 0

        with tqdm(range(len(train_iter)), ncols=100, colour='red',
                  desc="{} train epoch {}/{}".format(model_in_use, epoch + 1, num_epochs)) as pbar:
            for i, (X, y) in enumerate(train_iter):
                optimizer.zero_grad()
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                l.backward()
                optimizer.step()
                train_loss += l.detach()
                train_acc += accuracy(y_hat.detach(), y.detach())
                data_num += 1
                pbar.set_postfix({'loss': "{:.4f}".format(train_loss / data_num), 'acc': "{:.4f}".format(train_acc / data_num)})
                pbar.update(1)

        history['train_loss'].append(float(train_loss / data_num))
        history['train_acc'].append(float(train_acc / data_num))

        net.eval()
        test_loss = 0.0
        test_acc = 0.0
        data_num = 0
        with tqdm(range(len(test_iter)), ncols=100, colour='blue',
                  desc="{} test epoch {}/{}".format(model_in_use, epoch + 1, num_epochs)) as pbar:
            for X, y in test_iter:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                with torch.no_grad():
                    l = loss(y_hat, y)
                    test_loss += l.detach()
                    test_acc += accuracy(y_hat.detach(), y.detach())

                    data_num += 1
                    pbar.set_postfix({'loss': "{:.4f}".format(test_loss / data_num), 'acc': "{:.4f}".format(test_acc / data_num)})
                    pbar.update(1)

        history['test_loss'].append(float(test_loss / data_num))
        history['test_acc'].append(float(test_acc / data_num))
        if history['test_acc'][-1] > threshold:
            print("early stop")
            break
        if save_checkpoint and (epoch+1) % save_steps == 0:
            torch.save(net.state_dict(), "./model_weights/{}-ep{}-{}-acc-{:.4f}-loss-{:.4f}.pth".format(model_name, epoch+1, model_in_use, history['test_acc'][-1], history['test_loss'][-1]))

    torch.save(net.state_dict(), "./model_weights/{}-{}.pth".format(model_in_use, model_name))
    return history

训练参数

基于旋转特征的自监督学习的参数

参数含义
batch_size批量大小
in_channels输入通道数
num_classes预测类别
num_rotation_epochs自监督轮次数
num_supervise_epochs迁移轮次数
lr学习率
threshold提前停止的阈值,即测试精度超过这个阈值就停止训练
model_in_use使用的模型
model_kargs模型参数,即输入通道数和总类别
device测试cuda是否可用
img_size改变图像大小
is_freeze选择是否冻结自监督学习到的卷积层

这里可以看到,整个的训练过程需要加载两次数据并且训练两次,第一次是加载基于旋转特征的自监督学习的数据集,即判断是否经过旋转的数据集。另一次是加载cifar-10分类数据集(判断是否经过旋转的数据集也是cifar-10数据集,只不过我们在上一篇文章中在数据加载阶段对标签进行了修改)

if __name__ == '__main__':
    batch_size = 4086  # 批量大小
    in_channels = 3  # 输入通道数
    num_classes = 10  # 预测类别
    num_rotation_epochs = 100  # 自监督轮次
    num_supervise_epochs = 100  # 迁移
    lr = 2e-1
    threshold = 0.95  # 提前停止的阈值,即测试精度超过这个阈值就停止训练
    model_in_use = 'resNet18'  
    model_kargs = {'in_channels': in_channels, "num_classes": 4}  # 模型参数,即输入通道数和总类别
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # 测试cuda并使用
    img_size = None
    is_freeze = False
    trans = [transforms.ToTensor(), transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                                         std=[x/255.0 for x in [63.0, 62.1, 66.7]])]

    trans = transforms.Compose(trans)

    train_iter, test_iter = LoadRotationDataset(batch_size=batch_size, trans=trans)  # 加载数据集
    net = select_model(model_in_use, model_kargs)
    history = train(net, train_iter, test_iter, 0, num_rotation_epochs, lr, device, threshold, save_checkpoint=True)  # 训练
	
	if is_freeze:
	    for param in net.named_parameters():
	         param[1].requires_grad = False

    if model_in_use == 'resNet18':
        net = net[:-3]
        net.add_module("new Adapt", nn.AdaptiveAvgPool2d((1, 1)))
        net.add_module("new Flatten", nn.Flatten())
        net.add_module("new linear", nn.Linear(512, num_classes))

    lr=2e-1
    train_iter, test_iter = LoadSuperviseDataset(batch_size=batch_size, trans=trans)  # 加载数据集
    history_1 = train(net, train_iter, test_iter, num_rotation_epochs, num_rotation_epochs+num_supervise_epochs, lr, device, threshold, save_checkpoint=True)  # 训练

    for key in list(history.keys()):
        history[key] = history[key] + history_1[key]
    plot_history(model_in_use + "_rotation", history)

监督学习的参数
如果不使用基于特征旋转的自监督学习,那么我们只需要去掉下面几行代码即可,同时要注意将history修改到与plot_history中使用的名称保持一致。

train_iter, test_iter = LoadRotationDataset(batch_size=batch_size, trans=trans)  # 加载数据集
net = select_model(model_in_use, model_kargs)
history = train(net, train_iter, test_iter, 0, num_rotation_epochs, lr, device, threshold,save_checkpoint=True)  # 训练

效果对比

完成训练后,在项目所在的文件夹中会生成txt文件来存储训练过程中产生的训练精度、训练损失、测试精度、测试损失。我们写一个函数来对比这些数据

import matplotlib.pyplot as plt


def compare(rotationName, superviseName, titleName):
    rotation_acc = []
    with open(rotationName) as f:
        for acc in f.readlines():
            rotation_acc.append(float(acc.strip()))

    supervise_acc = []
    with open(superviseName) as f:
        for acc in f.readlines():
            supervise_acc.append(float(acc.strip()))
    plt.plot(range(len(rotation_acc)), rotation_acc, label='rotation')
    plt.plot(range(len(supervise_acc)), supervise_acc, label='supervise')
    plt.xlabel('epochs')
    plt.ylabel('scores')
    plt.title(titleName)
    plt.legend()
    plt.show()

if __name__ == "__main__":
    compare("resNet18_rotation_test_acc.txt", "resNet18_supervise_test_acc.txt", "test acc")
    compare("resNet18_rotation_test_loss.txt", "resNet18_supervise_test_loss.txt", "test loss")

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
可以看到虽然在其他指标中两者差距并不明显,但是在测试精度(test acc)中使用了基于旋转特征的自监督学习的精度上升十分明显。

总结

基于旋转特征的自监督学习实质上就是将原始图像进行旋转,旋转过后将他的标签设置成旋转的角度。然后传入模型进行训练,训练好的权值作为分类模型的预训练模型进行模型迁移。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/9542.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

apache配置与应用

构建虚拟 Web 主机 apache虚拟web主机指在同一台服务器运行多个web站点,其中每一个站点实际上并不独立占用整个服务器,因此被称为web主机,可以充分利用服务器的硬件资源,大大降低网站的构建成本 http服务支持三种虚拟主机类型 …

Spring Cloud之Consul服务治理实战

目录 1、Consul是什么 1.1概念 1.2 Consul下载 1.3 Consul启动 2、Consul使用场景 3、Consul优势 4、Consul架构及原理 4.1 整体架构图 4.2 通讯机制 4.3 健康检测 4.4如何保证数据一致性 5、搭建Consul环境 5.1 本地Consul搭建 5.2 集群Consul搭建 5.2.1 安装C…

VUE使用el-ui的form表单输入框批量搜索<VUE专栏三>

针对form表单的输入框单号批量查询&#xff0c;这里用换行符进行分割&#xff0c;注意v-model不要使用.trim 前端代码&#xff1a; <el-form-item label"SKU编码:" prop"prodNumbers"><el-input type"textarea" :rows"4" pla…

Spring 之依赖注入底层原理

Spring 框架作为 Java 开发中最流行的框架之一&#xff0c;其核心特性之一就是依赖注入&#xff08;Dependency Injection&#xff0c;DI&#xff09;。在Spring中&#xff0c;依赖注入是通过 IOC 容器&#xff08;Inversion of Control&#xff0c;控制反转&#xff09;来实现…

电动牙刷语音芯片,音乐播放ic选型?

牙医认为每天刷牙两次对于口腔健康至关重要。但是有些人会因为不良的刷牙习惯而受害&#xff0c;比如刷牙时间不够长。而刷牙不当会导致牙齿出现问题&#xff0c;例如蛀牙、污渍和口臭。 为了保护牙齿健康&#xff0c;很多家庭开始尝试使用电动牙刷。电动牙刷通过嗡嗡地振动消…

springboot2.7及springboot3中自动配置的变化

大家在面试中问到springboot如何实现的自动配置 诶&#xff0c;不要傻背错八股了&#xff0c;以往 入口类上的SpringBootApplication注解引入EnableAutoConfiguration&#xff0c;再 Import({AutoConfigurationImportSelector.class}) &#xff0c;其中代码实现根据META-INF/sp…

实战案例|聚焦攻击面管理,腾讯安全威胁情报守护头部券商资产安全

金融“活水”润泽千行百业&#xff0c;对金融客户来说&#xff0c;由于业务场景存在特殊性和复杂性&#xff0c;网络安全必然是一场“持久战”。如何在事前做好安全部署&#xff0c;构建威胁情报分析的防护体系至为重要&#xff0c;实现更为精准、高效的动态防御。 客户名片 …

Java实现打印杨辉三角形,向左、右偏的平行四边形这三个图形代码程序

目录 前言 一、打印杨辉三角形 1.1运行流程&#xff08;思想&#xff09; 1.2代码段 1.3运行截图 二、向左偏的平行四边形 1.1运行流程&#xff08;思想&#xff09; 1.2代码段 1.3运行截图 三、向右偏的平行四边形 1.1运行流程&#xff08;思想&#xff09; 1.2代…

一个评测模型+10个问题,摸清盘古、通义千问、文心一言、ChatGPT的“家底”!...

‍数据智能产业创新服务媒体——聚焦数智 改变商业毫无疑问&#xff0c;全球已经在进行大模型的军备竞赛了&#xff0c;“有头有脸”的科技巨头都不会缺席。昨天阿里巴巴内测了通义千问&#xff0c;今天华为公布了盘古大模型的最新进展。不久前百度公布了文心一言、360也公布了…

网络系统集成综合实验(六)| 访问控制列表ACL配置

目录 一、前言 二、实验目的 三、实验需求 四、实验步骤与现象 &#xff08;一&#xff09;基本ACL实验 Step1&#xff1a;构建拓扑图如下&#xff1a; Step2&#xff1a;PC的IP地址分别配置如下&#xff1a; Step3&#xff1a;路由器的IP地址配置如下 Step4&#xff…

十、CNN卷积神经网络实战

一、确定输入样本特征和输出特征 输入样本通道数4、期待输出样本通道数2、卷积核大小33 具体卷积层的构建可参考博文&#xff1a;八、卷积层 设定卷积层 torch.nn.Conv2d(in_channelsin_channel,out_channelsout_channel,kernel_sizekernel_size,padding1,stride1) 必要参数&a…

大数据五次作业回顾

文章目录1. 大数据作业11.本地运行模式部分2. 使用scp安全拷贝部分2. 大数据作业21、Rrsync远程同步工具部分2、xsync集群分发脚本部分3、集群部署部分3. 大数据作业31. 配置历史服务器及日志2. 日志部分3. 其他4. 大数据作业4编写本地wordcount案例一、源代码二、信息截图5. 大…

matlab流场可视化后处理

1流体中标量的可视化 流体力学中常见的标量为位置、速度绝对值、压强等。 1.1 云图 常用的云图绘制有pcolor、image、imagesc、imshow、contourf等函数。 这里利用matlab自带的wind数据作为演示案例&#xff0c;显示二维云图的速度场。 close all load wind x2x(:,:,5);y2y…

介绍MSYS2 在windows下与使用

系列文章目录 文章目录系列文章目录前言一、MSYS下载二、安装三、使用MSYS2安装CMake工具前言 MSYS的独立改写版本 MSYS2 &#xff08;Minimal SYStem 2&#xff09; 是一个MSYS的独立改写版本&#xff0c;主要用于 shell 命令行开发环境。同时它也是一个在Cygwin &#xff08…

闭关修炼(0.0 pytorch基础学习)1

基于官网pytorch.org pytorch 动态 比较优秀 py3.7支持是最多的啦 原来anaconda 是蟒蛇的意思 细思极恐 python 是蛇 yi Introduction to PyTorch Tensors — PyTorch Tutorials 2.0.0cu117 documentation omygaga 英语极差 哈哈哈 tensor 多维数组 矩阵二维数组 Tensor…

G761-3005B伺服阀放大器

G761-3005B伺服阀放大器&#xff0c;两级设计能够实现高水平设备性能、更快的周期时间和更高的准确性&#xff0c;最终为客户带来更高的生产效率 双线圈力矩马达高可靠性冗余设计 力矩马达配置双精度喷嘴精确流量控制和可预测性 干式力矩马达设计消除力矩马达气隙中可能导致…

数据结构-二叉树(前中后层序遍历-代码实现)

一、概要 二叉树的遍历方式包括前序遍历、中序遍历、后序遍历和层序遍历&#xff0c;具体定义如下&#xff1a; 前序遍历&#xff1a;先访问根节点&#xff0c;然后按照前序遍历的方式递归访问左子树和右子树。 中序遍历&#xff1a;先按照中序遍历的方式递归访问左子树&#…

Spring————java的反射机制,Spring的IOC和DI

一、认识Spring 1.1、Spring家族 SpringFramework&#xff1a; Spring框架&#xff1a;是Spring中最早核心的技术&#xff0c;也是所有其他技术及的基础。 SpringBoot:Spring是用来简化开发。而SpringBoot是来帮助Spring在简化的基础上能更快速进行开发。 SpringCloud&#xf…

v851s gpio 应用程序编写

1. 查看硬件电路图SCH_Schematic1_2022-11-23 &#xff0c;查找合适的gpio 作为使用pin 在这里我们选取 GPIOH14&#xff08;注意目前开发使用这个pin 作为触摸屏的pin脚&#xff0c;需要将触摸屏connect断开&#xff09; &#xff0c;因为 可以通过排插使用杜邦线将其引出&am…

Maven高级-属性多环境配置与应用

Maven高级-属性&多环境配置与应用4&#xff0c;属性4.1 属性4.1.1 问题分析4.1.2 解决步骤步骤1:父工程中定义属性步骤2:修改依赖的version4.2 配置文件加载属性步骤1:父工程定义属性步骤2:jdbc.properties文件中引用属性步骤3:设置maven过滤文件范围步骤4:测试是否生效4.3…