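Contents
1. Introduction to ResNet
2. The SelfAttention layer
3. ResNet34 + SelfAttention
4. Land-use classification from remote sensing satellite imagery
4.1 Land-use dataset
4.2 Training
4.3 Training results
4.4 Inference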
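1. Introduction to ResNet
ResNet (residual network) is a deep convolutional neural network proposed by Kaiming He et al. in 2015. It greatly eases the optimization difficulties of very deep networks, usually described in terms of vanishing and exploding gradients (the original paper calls the symptom the "degradation problem"), and makes deep networks easier and more effective to train.
In a deep network, gradients tend to shrink during backpropagation as the number of layers grows, which makes training difficult. In a conventional architecture the layers are simply stacked and information passes through them one after another; the more layers there are, the longer the path the gradient has to travel and the more it is attenuated. To address this, ResNet introduced the idea of "residual learning".
ResNet is built from residual blocks. Each block contains a skip connection that adds the block's input directly to its output. The skip connection lets gradients flow straight through the block, which mitigates the vanishing-gradient problem. By stacking residual blocks, ResNet can be made very deep, as in ResNet-50 and ResNet-101.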
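The effect of the skip connection on the gradient can be written out in one line. This is a sketch of the standard argument rather than a derivation from the original post; the symbols $x$ (block input), $F$ (the stacked layers inside the block), $y$ (block output), and $\mathcal{L}$ (loss) are notation introduced here.

```latex
y = x + F(x), \qquad
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}
    \left( I + \frac{\partial F}{\partial x} \right)
```

Because of the identity term $I$, part of the gradient reaches $x$ unchanged even when $\partial F / \partial x$ is very small.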
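ResNet greatly advanced the development of deep neural networks. It achieved very strong performance on many vision tasks and became an important baseline model for object detection, image classification, image segmentation, and related areas. Its ideas also influenced later network architectures and are widely used across deep learning tasks.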
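2. The SelfAttention layer
The self-attention mechanism comes from the Transformer architecture proposed by Vaswani et al. in 2017. For each input word it computes a weighted sum of the embeddings of all words in the sequence, where each weight reflects how relevant another word is to the current one. The weights are obtained from dot products between the embeddings, followed by a softmax that normalizes them.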
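The computation described above (dot products, softmax, weighted sum) can be sketched in a few lines of plain Python. This is a simplified illustration only: the function names and the toy embeddings are made up for this example, and the learned projections used in a real Transformer are left out.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(embeddings):
    """Dot-product self-attention without learned projections.

    embeddings: list of n vectors (lists of floats), one per word.
    Returns (outputs, weights); weights[i][j] is how much word i
    attends to word j.
    """
    dim = len(embeddings[0])
    outputs, weights = [], []
    for query in embeddings:
        # Relevance of every word to the current word: one dot product each.
        scores = [sum(q * k for q, k in zip(query, key)) for key in embeddings]
        row = softmax(scores)
        # Weighted sum of all embeddings, using the normalized weights.
        out = [sum(w * vec[d] for w, vec in zip(row, embeddings)) for d in range(dim)]
        weights.append(row)
        outputs.append(out)
    return outputs, weights


# Three toy "word" embeddings (made-up values).
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = self_attention(tokens)
for i, row in enumerate(weights):
    print(i, [round(w, 3) for w in row])
```

The Transformer additionally applies learned query, key, and value projections and scales the dot products by $1/\sqrt{d_k}$ before the softmax; both are omitted here to keep the sketch short.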
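Compared with traditional sequence models, self-attention has several advantages. It captures long-range dependencies more effectively, because information from any word in the sequence can directly influence the representation of any other word. It also parallelizes well, because the attention weights for each word can be computed independently. This makes self-attention models efficient and scalable.
The Python (PyTorch) implementation is as follows. It works on 2D feature maps, so each spatial position plays the role of a word: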
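import torch
import torch.nn as nn


# Self-attention layer for 2D feature maps
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # 1x1 convolutions project the input to query, key, and value maps
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Learnable scale, initialized to 0 so the layer starts as an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        # query: (B, H*W, C//8), key: (B, C//8, H*W)
        query = self.query_conv(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, height * width)
        # energy[b, i, j]: similarity between positions i and j, shape (B, H*W, H*W)
        energy = torch.bmm(query, key)
        attention = torch.softmax(energy, dim=-1)
        # value: (B, C, H*W)
        value = self.value_conv(x).view(batch_size, -1, height * width)
        # Attention-weighted sum of values for every position, shape (B, C, H*W)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        # Residual connection scaled by gamma
        out = self.gamma * out + x
        return out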
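In the layer above, `energy` has shape (B, H*W, H*W), so its size grows with the square of the number of spatial positions. The short calculation below shows what that means at each ResNet stage. It assumes a 224x224 input, which the post does not state, and the standard ResNet downsampling factors of 4, 8, 16, and 32 after layer1 to layer4; the helper name is made up for this example.

```python
def attention_matrix_entries(height, width):
    """Entries in the (H*W) x (H*W) attention matrix for one sample."""
    positions = height * width
    return positions * positions


INPUT_SIZE = 224  # assumed input resolution, not stated in the post
# Standard ResNet downsampling factors after layer1..layer4.
STRIDES = {"layer1": 4, "layer2": 8, "layer3": 16, "layer4": 32}

costs = {}
for name, stride in STRIDES.items():
    side = INPUT_SIZE // stride
    costs[name] = attention_matrix_entries(side, side)
    print(f"{name}: {side}x{side} feature map -> {costs[name]:,} attention entries")
```

Under these assumptions the attention matrix after layer1 has about 9.8 million entries per image, while after layer4 it has only 2,401. This is one practical reason to attach the layer to a late, low-resolution stage.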
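3. ResNet34 + SelfAttention
Only resnet34 is modified here. Other ResNet variants take the self-attention layer in the same way: just replace resnet34 with resnet50, resnet101, and so on, and set in_channels to the channel count of the stage the layer is attached to.
The key code is as follows:
The model after adding the layer looks as follows: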
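4. Land-use classification from remote sensing satellite imagery
The download link is below:
ResNet improvement in practice (adding the SelfAttention mechanism): land-use image classification from remote sensing satellites, resource on CSDN Wenku
After unzipping, the full directory looks as follows; data is the dataset and runs holds the training results.
4.1 Land-use dataset
There are 21 classes in total, each stored in its own directory. The training set has 1470 images and the validation set has 630 images.
The label classes are as follows: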
{
"0": "agricultural",
"1": "airplane",
"2": "baseballdiamond",
"3": "beach",
"4": "buildings",
"5": "chaparral",
"6": "denseresidential",
"7": "forest",
"8": "freeway",
"9": "golfcourse",
"10": "harbor",
"11": "intersection",
"12": "mediumresidential",
"13": "mobilehomepark",
"14": "overpass",
"15": "parkinglot",
"16": "river",
"17": "runway",
"18": "sparseresidential",
"19": "storagetanks",
"20": "tenniscourt"
}
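With 21 classes, 1470 training images and 630 validation images work out to 70 and 30 per class if the split is balanced, which is a 70/30 split. The sketch below shows one way such a per-class split could be produced. The function name `split_per_class`, the seed, and the toy file names are all hypothetical, and the project's own split may have been made differently.

```python
import random


def split_per_class(files_by_class, train_ratio=0.7, seed=0):
    """Split each class's file list into train and validation parts.

    files_by_class: dict mapping class name -> list of file names.
    Returns (train, valid), two dicts with the same keys.
    """
    rng = random.Random(seed)
    train, valid = {}, {}
    for name, files in files_by_class.items():
        shuffled = sorted(files)  # sort first so the result is reproducible
        rng.shuffle(shuffled)
        cut = round(len(shuffled) * train_ratio)
        train[name] = shuffled[:cut]
        valid[name] = shuffled[cut:]
    return train, valid


# Toy input: 100 made-up file names for two of the 21 classes.
toy = {
    "agricultural": [f"agricultural{i:02d}.tif" for i in range(100)],
    "airplane": [f"airplane{i:02d}.tif" for i in range(100)],
}
train, valid = split_per_class(toy)
print(len(train["agricultural"]), len(valid["agricultural"]))  # 70 30
```

Splitting inside each class, rather than over the pooled file list, keeps every class at the same train/validation ratio.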
Visualization:
4.2 Training
The model was trained for 30 epochs with the following parameters:
"train parameters": {
"model": "resnet34",
"pretrained": true,
"freeze_layers": true,
"batch_size": 8,
"epochs": 30,
"optim": "SGD",
"lr": 0.001,
"lrf": 0.0001
},
"Datasets": {
"trainSets number": 1470,
"validSets number": 630
},
"model": {
"total parameters": 21731845.0,
"train parameters": 621001,
"flops": 3742463488.0
},
To change the training hyperparameters, edit the train script.
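The configuration lists both `lr` (0.001) and `lrf` (0.0001) but does not say how the two are combined. A common convention in training scripts of this style treats `lrf` as the final learning-rate factor of a cosine decay. The sketch below assumes that convention, which may not match the project's actual train script; the function name `cosine_lr_factor` is hypothetical.

```python
import math


def cosine_lr_factor(epoch, epochs=30, lrf=0.0001):
    """Multiplier applied to the base lr at a given epoch (assumed cosine decay).

    Equals 1.0 at epoch 0 and decays to lrf at epoch == epochs.
    """
    return ((1 + math.cos(epoch * math.pi / epochs)) / 2) * (1 - lrf) + lrf


base_lr = 0.001  # "lr" in the configuration above
schedule = [base_lr * cosine_lr_factor(e) for e in range(30)]
print(f"epoch 0: {schedule[0]:.6f}, epoch 15: {schedule[15]:.6f}, epoch 29: {schedule[29]:.8f}")
```

In PyTorch, a function like this can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base learning rate by the returned factor.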
4.3 Training results
The metrics for the final epoch are as follows:
"epoch:29": {
"train info": {
"accuracy": 0.9836734693810635,
"agricultural": {
"Precision": 1.0,
"Recall": 0.9857,
"Specificity": 1.0,
"F1 score": 0.9928
},
"airplane": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"baseballdiamond": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"beach": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"buildings": {
"Precision": 0.9857,
"Recall": 0.9857,
"Specificity": 0.9993,
"F1 score": 0.9857
},
"chaparral": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"denseresidential": {
"Precision": 0.9286,
"Recall": 0.9286,
"Specificity": 0.9964,
"F1 score": 0.9286
},
"forest": {
"Precision": 0.9722,
"Recall": 1.0,
"Specificity": 0.9986,
"F1 score": 0.9859
},
"freeway": {
"Precision": 0.971,
"Recall": 0.9571,
"Specificity": 0.9986,
"F1 score": 0.964
},
"golfcourse": {
"Precision": 0.9853,
"Recall": 0.9571,
"Specificity": 0.9993,
"F1 score": 0.971
},
"harbor": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"intersection": {
"Precision": 1.0,
"Recall": 0.9857,
"Specificity": 1.0,
"F1 score": 0.9928
},
"mediumresidential": {
"Precision": 0.9559,
"Recall": 0.9286,
"Specificity": 0.9979,
"F1 score": 0.9421
},
"mobilehomepark": {
"Precision": 0.9718,
"Recall": 0.9857,
"Specificity": 0.9986,
"F1 score": 0.9787
},
"overpass": {
"Precision": 0.9577,
"Recall": 0.9714,
"Specificity": 0.9979,
"F1 score": 0.9645
},
"parkinglot": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"river": {
"Precision": 0.9718,
"Recall": 0.9857,
"Specificity": 0.9986,
"F1 score": 0.9787
},
"runway": {
"Precision": 0.9722,
"Recall": 1.0,
"Specificity": 0.9986,
"F1 score": 0.9859
},
"sparseresidential": {
"Precision": 0.9859,
"Recall": 1.0,
"Specificity": 0.9993,
"F1 score": 0.9929
},
"storagetanks": {
"Precision": 1.0,
"Recall": 0.9857,
"Specificity": 1.0,
"F1 score": 0.9928
},
"tenniscourt": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"mean precision": 0.9837190476190478,
"mean recall": 0.9836666666666668,
"mean specificity": 0.9991952380952381,
"mean f1 score": 0.9836380952380953
},
"valid info": {
"accuracy": 0.8571428571292516,
"agricultural": {
"Precision": 0.8437,
"Recall": 0.9,
"Specificity": 0.9917,
"F1 score": 0.8709
},
"airplane": {
"Precision": 1.0,
"Recall": 0.9667,
"Specificity": 1.0,
"F1 score": 0.9831
},
"baseballdiamond": {
"Precision": 0.8529,
"Recall": 0.9667,
"Specificity": 0.9917,
"F1 score": 0.9062
},
"beach": {
"Precision": 0.7692,
"Recall": 1.0,
"Specificity": 0.985,
"F1 score": 0.8695
},
"buildings": {
"Precision": 0.7714,
"Recall": 0.9,
"Specificity": 0.9867,
"F1 score": 0.8308
},
"chaparral": {
"Precision": 0.9062,
"Recall": 0.9667,
"Specificity": 0.995,
"F1 score": 0.9355
},
"denseresidential": {
"Precision": 0.72,
"Recall": 0.6,
"Specificity": 0.9883,
"F1 score": 0.6545
},
"forest": {
"Precision": 0.8788,
"Recall": 0.9667,
"Specificity": 0.9933,
"F1 score": 0.9207
},
"freeway": {
"Precision": 0.7241,
"Recall": 0.7,
"Specificity": 0.9867,
"F1 score": 0.7118
},
"golfcourse": {
"Precision": 0.8387,
"Recall": 0.8667,
"Specificity": 0.9917,
"F1 score": 0.8525
},
"harbor": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"intersection": {
"Precision": 0.8889,
"Recall": 0.8,
"Specificity": 0.995,
"F1 score": 0.8421
},
"mediumresidential": {
"Precision": 0.8077,
"Recall": 0.7,
"Specificity": 0.9917,
"F1 score": 0.75
},
"mobilehomepark": {
"Precision": 0.8437,
"Recall": 0.9,
"Specificity": 0.9917,
"F1 score": 0.8709
},
"overpass": {
"Precision": 0.6897,
"Recall": 0.6667,
"Specificity": 0.985,
"F1 score": 0.678
},
"parkinglot": {
"Precision": 0.9355,
"Recall": 0.9667,
"Specificity": 0.9967,
"F1 score": 0.9508
},
"river": {
"Precision": 0.9,
"Recall": 0.6,
"Specificity": 0.9967,
"F1 score": 0.72
},
"runway": {
"Precision": 0.8571,
"Recall": 1.0,
"Specificity": 0.9917,
"F1 score": 0.9231
},
"sparseresidential": {
"Precision": 0.9,
"Recall": 0.9,
"Specificity": 0.995,
"F1 score": 0.9
},
"storagetanks": {
"Precision": 0.92,
"Recall": 0.7667,
"Specificity": 0.9967,
"F1 score": 0.8364
},
"tenniscourt": {
"Precision": 1.0,
"Recall": 0.8667,
"Specificity": 1.0,
"F1 score": 0.9286
},
"mean precision": 0.8594095238095237,
"mean recall": 0.857157142857143,
"mean specificity": 0.9928714285714286,
"mean f1 score": 0.8540666666666668
}
}
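The per-class numbers in the log above follow from a confusion matrix. The sketch below shows the usual definitions; `per_class_metrics` is a name made up for this example, not a function from the project. The toy matrix uses counts inferred from the reported validation values for "beach", assuming 30 validation images per class (630 / 21): all 30 beach images are found, and 9 images of other classes are predicted as beach.

```python
def per_class_metrics(cm):
    """Precision, recall, specificity, and F1 for each class.

    cm[i][j] is the number of samples of true class i predicted as class j.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    results = []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                       # true c, predicted as another class
        fp = sum(cm[r][c] for r in range(n)) - tp  # predicted c, actually another class
        tn = total - tp - fn - fp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        specificity = tn / (tn + fp) if tn + fp else 0.0
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        results.append({"Precision": precision, "Recall": recall,
                        "Specificity": specificity, "F1 score": f1})
    return results


# Two-class toy matrix: class 0 stands for "beach", class 1 lumps together
# the other 600 validation images (inferred counts, see the text above).
cm = [[30, 0],
      [9, 591]]
beach = per_class_metrics(cm)[0]
print({k: round(v, 4) for k, v in beach.items()})
```

With these counts, precision is 30/39, recall is 1.0, and specificity is 591/600, which agree with the logged 0.7692, 1.0, and 0.985; F1 agrees with the logged value up to rounding in the last digit. When every class has the same number of samples, overall accuracy equals mean recall, which is why the two values in the log are nearly identical.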
Curves:
Confusion matrix:
4.4 Inference
The inference results are as follows:
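At inference time, the model's output scores have to be mapped back to class names through the label file from section 4.1. The sketch below shows that last step only, in plain Python. The function name `top_k_labels` and the logits are made up for illustration, and the project's own inference script may work differently.

```python
import json
import math

# Same format as the label file in section 4.1, truncated to three classes.
class_indices = json.loads('{"0": "agricultural", "1": "airplane", "2": "baseballdiamond"}')


def top_k_labels(logits, class_indices, k=3):
    """Return the k most likely (class name, probability) pairs."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]  # softmax over the raw scores
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # The label file uses string keys, so indices are converted with str().
    return [(class_indices[str(i)], probs[i]) for i in ranked[:k]]


# Made-up logits standing in for the model output on one image.
predictions = top_k_labels([0.2, 3.1, 1.0], class_indices, k=2)
for name, prob in predictions:
    print(f"{name}: {prob:.3f}")
```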
To train on a different dataset, refer to the readme file.