基于ResNet框架的CNN

数据准备

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

一、训练集和验证集的划分

#spile_data.py

import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla] #['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
mkfile('flower_data/train') #生成train文件夹
for cla in flower_class:
    mkfile('flower_data/train/'+cla) #在train文件夹下生成各个类别的文件夹

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate)) #在images中随机获取0.1的图片
    for index, image in enumerate(images):
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")

二、ResNet网络

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import mmd
from attention import ChannelAttention
from attention import SpatialAttention
import torch


__all__ = ['ResNet', 'resnet50']


model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}

def conv3x3(in_planes, out_planes, stride=1,groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


'''
Resnet中BasicBlock结构,ResNet中使用的网络结构。分2步走:3x3; 3x3
'''
class BasicBlock(nn.Module):
    expansion = 1  # 最后输出的通道数扩充的比例

    # BN层来加快网络模型的收敛速度/训练速度/解决梯度消失或者梯度爆炸的问题
    # 对batch中所有的同一个channel的数据元素进行标准化处理。一个batch共享一套参数
    # 即如果有C个通道,对N*H*W进行标准化处理,一共进行C次。

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # downsample是用一个1x1的卷积核处理,改变通道数,如果H/W尺度也不一样就设计stride
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


'''
Resnet中Bottleneck结构,ResNet中使用的网络结构。目的是为了降低参数量,分三步走:
1数据降维(1x1),2常规卷积核的卷积(3x3),3数据升维(1x1)
结果图片长宽不变,通道数扩大4倍
'''
class Bottleneck(nn.Module):
    expansion = 4  # 最后输出的通道数扩充的比例

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        # BN层来加快网络模型的收敛速度/训练速度/解决梯度消失或者梯度爆炸的问题
        # 对batch中所有的同一个channel的数据元素进行标准化处理。一个batch共享一套参数
        # 即如果有C个通道,对N*H*W进行标准化处理,一共进行C次。
        self.bn1 = nn.BatchNorm2d(planes)  # 卷积层后加BatchNorm2d,按照channel进行归一化
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dropout=nn.Dropout()
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # downsample是用一个1x1的卷积核处理,改变通道数,如果H/W尺度也不一样就设计stride
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


'''
ResNet由以下组成:
1.conv1、norm1、relu(当指定了deep_stem,这三个将被stem代替)
2.maxpool
3.layer1~layer4(定义为ResLayer类,分别由多个BasicBlock或Bottleneck组成)
'''
class ResNet(nn.Module):
    # 参数block指明残差块是两层或三层,参数layers指明每个卷积层需要的残差块数量,num_classes指明分类数,zero_init_residual是否初始化为0
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.inplanes = 64
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(12, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # 网络的第一层加入注意力机制
        # self.ca = ChannelAttention(self.inplanes)
        # self.sa = SpatialAttention()

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)

        # 网络的卷积层的最后一层加入注意力机制
        # self.ca1 = ChannelAttention(self.inplanes)
        # self.sa1 = SpatialAttention()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应平均池化,指定输出(H,W)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # 构造ResLayer类,layer1~layer4
    # block:BasicBlock/Bottleneck; planes:块的输入通道数; blocks:块的数目
    def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        downsample = None  # downSample的作用于在残差连接时 将输入的图像的通道数变成和卷积操作的尺寸一致
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 通道数恢复成一致/长宽恢复一致
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    '''
    卷积/池化后的tensor维度为(batchsize,channels,x,y),其中x.size(0)指batchsize的值,
    通过x.view(x.size(0), -1)将tensor的结构转换为了(batchsize, channels*x*y)
    即将(channels,x,y)拉直,然后就可以和fc层连接
    
    因为最后avgpool(1,1)指定输出长*宽为1*1,通道为512*4,所以channels*x*y=2048
    '''
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # x = self.ca(x) * x
        # x = self.sa(x) * x

        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # x = self.ca1(x) * x
        # x = self.sa1(x) * x

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        #x=self.fc(x)

        return x

class DANNet(nn.Module):

    def __init__(self, num_classes=2):
        super(DANNet, self).__init__()
        self.sharedNet = resnet50(False)
        self.cls_fc = nn.Linear(2048, num_classes)  # channels*x*y=2048*1*1,见上面的备注

    def forward(self, source, target):
        loss = 0
        source = self.sharedNet(source)
        if self.training == True:
            target = self.sharedNet(target)
            # loss += mmd.mmd_rbf_accelerate(source, target)
            loss += mmd.mmd_rbf_noaccelerate(source, target)

        source = self.cls_fc(source)
        #target = self.cls_fc(target)

        return source, loss

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model

三、训练模型

#train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from model import resnet34, resnet101
import torchvision.models.resnet


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

#数据增强操作,训练集:随机裁剪(RandomResizedCrop)、随机水平翻转(RandomHorizontalFlip)、转换为张量(ToTensor)以及归一化(Normalize)
#验证集:大小调整(Resize)、中心裁剪(CenterCrop)、转换为张量(ToTensor)以及归一化(Normalize)
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#来自官网参数
    "val": transforms.Compose([transforms.Resize(256),#将最小边长缩放到256
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}


data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset) #3306

flower_list = train_dataset.class_to_idx   #{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
cla_dict = dict((val, key) for key, val in flower_list.items())  #{0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)  #将cla_dict字典对象转换为JSON格式的字符串,并通过indent=4参数指定缩进为4个空格
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)  #364
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)
#net = resnet34()
net = resnet34(num_classes=5)
# load pretrain weights

# model_weight_path = "./resnet34-pre.pth"
# missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型参数

# for param in net.parameters():
#     param.requires_grad = False
# change fc layer structure

# inchannel = net.fc.in_features
# net.fc = nn.Linear(inchannel, 5)


net.to(device)  #将神经网络模型net移动到指定的设备上,这样模型就可以在GPU/CPU上计算

loss_function = nn.CrossEntropyLoss()  #损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0001)  #优化器

best_acc = 0.0
save_path = './resNet34.pth'
#一个epoch表示对整个训练数据集进行一次完整的迭代训练
for epoch in range(3):
    # train
    net.train()
    running_loss = 0.0
    #step表示当前的步数(或者称为批次数),data则表示从train_loader中加载的数据对象
    for step, data in enumerate(train_loader, start=0):
        images, labels = data  #images:(16,3,224,224) labels:(16,)
        optimizer.zero_grad()
        logits = net(images.to(device)) #logits:(16,5)将输入的图像数据images传入神经网络net
        loss = loss_function(logits, labels.to(device)) #1.6871 计算模型输出logits进行标准化(softmax),再计算每个样本预测标签和真实标签的交叉熵,对于整个批次的样本,计算平均交叉熵损失
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()  #累加每个批次的损失值
        # print train process
        rate = (step+1)/len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))  # eval model only have last output layer
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]  #torch.max包含两个维度信息,第一个维度是最大值,第二个维度是最大值对应的索引
            acc += (predict_y == val_labels.to(device)).sum().item() #每一次的validate_loader中预测正确的个数,.item() 方法转换为标量
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  #state_dict()方法返回模型的参数字典,save_path保存模型参数的文件路径
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

四、预测

#predict.py

import torch
from model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# load image
img = Image.open("./roses.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/169020.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

YOLO目标检测——无人机航拍行人检测数据集下载分享【含对应voc、coc和yolo三种格式标签】

实际项目应用:智能交通管理、城市安防监控、公共安全救援等领域数据集说明:无人机航拍行人检测数据集,真实场景的高质量图片数据,数据场景丰富标签说明:使用lableimg标注软件标注,标注框质量高,…

whisper使用方法

看这个 github https://github.com/Purfview/whisper-standalone-win/tags下载 视频提取音频 ffmpeg -i 222.mp4 -vn -b:a 128k -c:a mp3 output.mp3截取4秒后的音频 ffmpeg -i output.mp3 -ss 4 -c copy output2.mp3使用 whisper-faster.exe 生成字幕 whisper-faster.exe …

基于水基湍流算法优化概率神经网络PNN的分类预测 - 附代码

基于水基湍流算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于水基湍流算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于水基湍流优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神…

游戏报错d3dcompiler_47.dll缺失怎么修复,总结多种修复方法

在使用这些软件和游戏的过程中,我们常常会遇到一些问题,其中之一就是d3dcompiler_47.dll丢失的问题。这个问题可能会导致软件或游戏无法正常运行,给用户带来困扰。本文将详细介绍解决软件游戏d3dcompiler_47.dll丢失的方法,帮助您…

读懂:“消费报销”模式新零售打法,适用连锁门店加盟的营销方案

读懂:“消费报销”模式新零售打法,适用连锁门店加盟的营销方案 引言:2023年的双十一已经落下帷幕,作为每年的经典电商促销节,今年已是第15个年头,但是今年各大电商平台却都是非常默契的,没有公布…

《数据:挖掘价值,洞察未来

大数据:挖掘价值,洞察未来 我们正身处一个数据驱动的时代,大数据已经成为企业和个人决策的重要依据。本文将深入探讨大数据的魅力,挖掘其价值,并洞察未来发展趋势,让我们一起领略大数据的无穷奥秘。 一、大…

《云计算:云端协同,智慧互联》

《云计算:云端协同,智慧互联》 云计算,这个科技领域中的热门词汇,正在逐渐改变我们的生活方式。它像一座座无形的桥梁,将世界各地的设备、数据、应用紧密连接在一起,实现了云端协同,智慧互联的愿…

比科奇推出5G小基站开放式RAN射频单元的高性能低功耗SoC

全新的PC805作为业界首款支持25Gbps速率eCPRI和CPRI前传接口的系统级芯片(SoC),消除了实现低成本开放式射频单元的障碍 中国北京,2023年11月 - 5G开放式RAN基带芯片和电信级软件提供商比科奇(Picocom)今日…

pip list 和 conda list的区别

PS : 网上说conda activate了之后就可以随意pip了 可以conda和pip混用 但是安全起见还是尽量用pip 这样就算activate了,进入base虚拟环境了 conda与pip的区别 来源 Conda和pip通常被认为几乎完全相同。虽然这两个工具的某些功能重叠,但它们设计用于不…

Linux文件目录以及文件类型

文章目录 Home根目录 //bin/sbin/etc/root/lib/dev/proc/sys/tmp/boot/mnt/media/usr 文件类型 Home 当尝试使用gedit等编辑器保存文件时,系统默认通常会先打开个人用户的“家”(home)目录, 建议在通常情况下个人相关的内容也是保…

在windows Server安装Let‘s Encrypt的SSL证书

1、到官网(https://certbot.eff.org/instructions?wswebproduct&oswindows)下载 certbot客户端。 2、安装客户端(全部默认安装即可) 3、暂停IIS中的网站 开始菜单中找到并运行“Certbot”,输入指令: …

【Web】Ctfshow Nodejs刷题记录

目录 ①web334 ②web335 ③web336 ④web337 ⑤web338 ⑥web339 ⑦web340 ⑧web341 ⑨web342-343 ⑩web344 ①web334 进来是一个登录界面 下载附件,简单代码审计 表单传ctfshow 123456即可 ②web335 进来提示 get上传eval参数执行nodejs代码 payload: …

Java贪吃蛇小游戏

Java贪吃蛇小游戏 import javax.swing.*; import java.awt.*; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.KeyEvent; import java.awt.event.KeyListener; import java.util.LinkedList; import java.util.Random;publi…

【开源】基于Vue.js的计算机机房作业管理系统的设计和实现

项目编号: S 017 ,文末获取源码。 \color{red}{项目编号:S017,文末获取源码。} 项目编号:S017,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 登录注册模块2.2 课程管理模块2.3 课…

前端本地存储数据库IndexedDB

前端本地存储数据库IndexedDB 1、前言2、什么是 indexedDB?3、什么是 localForage?4、localForage 的使用5、VUE 推荐使用 Pinia 管理 localForage 1、前言 前端本地化存储算是一个老生常谈的话题了,我们对于 cookies、Web Storage&#xff…

【C++ STL】string类-----迭代器(什么是迭代器?迭代器分哪几类?迭代器的接口如何使用?)

目录 一、前言 二、什么是迭代器 三、迭代器的分类与接口 💦迭代器的分类 💦迭代器的接口 💦迭代器与接口之间的关联 四、string类中迭代器的应用 💦 定义string类----迭代器 💦string类中迭代器进行遍历 ✨be…

庖丁解牛:NIO核心概念与机制详解 06 _ 连网和异步 I/O

文章目录 Pre概述异步 I/OSelectors打开一个 ServerSocketChannel选择键内部循环监听新连接接受新的连接删除处理过的 SelectionKey传入的 I/O回到主循环 Pre 庖丁解牛:NIO核心概念与机制详解 01 庖丁解牛:NIO核心概念与机制详解 02 _ 缓冲区的细节实现…

C# 监测 Windows 设备变动事件

本程序通过WPF窗口的 WindowProc 函数处理Windows的硬件或配置改变的事件。开发环境为VS 2022。 基础信息 硬件或配置改变的基础有以下内容: 消息: WM_DEVICECHANGE 要实现的WindowProc 函数参数: protected IntPtr WndProc(IntPtr hwnd, int msg, In…

React 中 react-i18next 切换语言( 项目国际化 )

背景 平时中会遇到需求,就是切换语言,语种等。其实总的来说都是用i18n来实现的 思路 首先在项目中安装i18n插件,然后将插件引入到项目,然后配置语言包(语言包需要你自己来进行配置,自己编写语言包&#xff…

C++初阶 | [四] 类和对象(下)

摘要:初始化列表,explicit关键字,匿名对象,static成员,友元,内部类,编译器优化 类是对某一类实体(对象)来进行描述的,描述该对象具有哪些属性、哪些方法,描述完成后就形成…