RESNET的复现pytorch版本

RESNET的复现pytorch版本

使用的数据为Object_102_CaDataset,可以在网上下载,也可以在评论区问。

RESNET模型的亮点

1.提出了残差模块。

2.使用Batch Normalization加速训练

3.残差网络:易于收敛,很好的解决了退化问题,模型可以很深,准确率大大提高了。

残差结构如下所示:

image-20240314180554463

首先,是模型构建部分

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride_1=1, stride_2=1, padding=1, kernel_size=(3, 3), short_cut=None):
        super(ResBlock, self).__init__()
        self.short_cut = short_cut
        self.model = Sequential(
            # 1.1
            Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride_1,
                   padding=padding),
            BatchNorm2d(out_channels),
            ReLU(),
            Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride_2,
                   padding=padding),
            BatchNorm2d(out_channels),
            ReLU(),
        )
        self.short_layer = Sequential(
            Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1), stride=2, padding=0),
            BatchNorm2d(out_channels),
            ReLU(),
        )

        self.R = ReLU()

    def forward(self, x):
        f1 = x
        if self.short_cut is not None:
            f1 = self.short_layer(x)
        out = self.model(x)
        out = self.R(f1+out)
        return out

该部分为模型的残差块,使用了3*3的卷积,然后进行归一化。

对于整个模型的构建部分:

class Resnet_easier(nn.Module):
    def __init__(self, num_classes):
        super(Resnet_easier, self).__init__()
        self.model0 = Sequential(
            # 0
            # 输入3通道、输出64通道、卷积核大小、步长、补零、
            Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=2, padding=3),
            BatchNorm2d(64),
            ReLU(),
            MaxPool2d(kernel_size=(3, 3), stride=2, padding=1),
        )

        self.model1 = ResBlock(64, 64)

        self.model2 = ResBlock(64, 64)

        self.model3 = ResBlock(64, 128, stride_1=2, stride_2=1, short_cut=True)

        self.model4 = ResBlock(128, 128)

        self.model5 = ResBlock(128, 256, stride_1=2, stride_2=1, short_cut=True)

        self.model6 = ResBlock(256, 256)

        self.model7 = ResBlock(256, 512, stride_1=2, stride_2=1, short_cut=True)

        self.model8 = ResBlock(512, 512)

        # AAP 自适应平均池化
        self.aap = AdaptiveAvgPool2d((1, 1))
        # flatten 维度展平
        self.flatten = Flatten(start_dim=1)
        # FC 全连接层
        self.fc = Linear(512, num_classes)

    def forward(self, x):
        x = x.to(torch.float32)
        x = self.model0(x)
        x = self.model1(x)
        x = self.model2(x)
        x = self.model3(x)
        x = self.model4(x)
        x = self.model5(x)
        x = self.model6(x)
        x = self.model7(x)
        x = self.model8(x)
        # 最后3个
        x = self.aap(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

接下来是读入数据模块

class Object_102_CaDataset(Dataset):
    def __init__(self, folder):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.file_list = []
        label_names = [item for item in os.listdir(folder) if os.path.isdir(os.path.join(folder, item))]  # 获取文件夹下的所有标签
        label_to_index = dict((label, index) for index, label in enumerate(label_names))  # 将label转为数字
        self.all_picture_paths = self.get_all_picture(folder)  # 获取所有图片路径
        self.all_picture_labels = [label_to_index[os.path.split(os.path.dirname(os.path.abspath(path)))[1]] for path in
                                   self.file_list]
        self.mean = np.array(mean).reshape((1, 1, 3))
        self.std = np.array(std).reshape((1, 1, 3))

    def __getitem__(self, index):
        img = cv2.imread(self.all_picture_paths[index])
        if img is None:
            print(os.path.join("image", self.all_picture_paths[index]))
        img = cv2.resize(img, (224, 224))  #统一图片的尺寸
        img = img / 255
        img = (img - self.mean) / self.std
        img = np.transpose(img, [2, 0, 1])
        label = self.all_picture_labels[index]
        img = torch.tensor(img)
        label = torch.tensor(label)
        return img, label

    def __len__(self):
        return len(self.all_picture_paths)

    def get_all_picture(self, folder):
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)

            if os.path.isfile(file_path):
                self.file_list.append(file_path)
            elif os.path.isdir(file_path):
                self.file_list = self.get_all_picture(file_path)
        return self.file_list

使用上述dataloader可以方便的对数据进行读取操作。

接下来就是整个的训练模块

import torch
from torch import nn
from torch.utils.data import DataLoader

from ResNet.ResNet18 import Resnet18
from ResNet.ResNet18_easier import Resnet_easier
from ResNet.dataset import Object_102_CaDataset
from ResNet.res_net import ResNet, ResBlock

from torchsummary import summary
data_dir = 'E:\PostGraduate\Paper_review\computer_view_model/ResNet/data/101_ObjectCategories'
Object_102 = Object_102_CaDataset(data_dir)
train_size = int(len(Object_102) * 0.7)
# print(train_size)
test_size = len(Object_102) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(Object_102, [train_size, test_size])
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
#显示数据,此处的注释内容可以让我们看到读取的图片
# import random
# from matplotlib import pyplot as plt
# import matplotlib
# matplotlib.use('TkAgg')
# def denorm(img):
#     for i in range(img.shape[0]):
#         img[i] = img[i] * std[i] + mean[i]
#     img = torch.clamp(img, 0., 1.)
#     return img
# plt.figure(figsize=(8, 8))
# for i in range(9):
#     img, label = train_dataset[random.randint(0, len(train_dataset))]
#     img = denorm(img)
#     img = img.permute(1, 2, 0)
#     ax = plt.subplot(3, 3, i + 1)
#     ax.imshow(img.numpy()[:, :, ::-1])
#     ax.set_title("label = %d" % label)
#     ax.set_xticks([])
#     ax.set_yticks([])
# plt.show()

train_iter = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_iter = DataLoader(train_dataset, batch_size=64)
model = Resnet_easier(102)
# print(summary(model, (3, 224, 224)))
epoch = 50  # 训练轮次
optmizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# optmizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()#.cuda()  # 定义交叉熵损失函数
log_interval = 10

train_losses = []
train_counter = []
test_losses = []
test_counter = [i * len(train_iter.dataset) for i in range(epoch + 1)]




# test_loop(model,'cpu',test_iter)
def train_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    for epoch in range(1, n_epochs + 1):
        model.train()
        for i, data in enumerate(train_loader):
            correct = 0
            (images, label) = data
            images = images#.cuda()
            label = label#.cuda()
            # print(len(images))
            output = model(images)
            loss = loss_fn(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pred = output.data.max(1, keepdim=True)[1]
            pred = torch.tensor(pred, dtype=torch.float32)
            for index in range(0, len(pred)):
                if pred[index] == label[index]:
                    correct += 1
            # correct = torch.eq(pred, label).sum()
            # print(correct)
            if i % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\t accuracy:[{}/{} ({:.0f}%)] \tLoss: {:.6f}'.format(
                    epoch, i * len(images), len(train_loader.dataset),
                           100. * i / len(train_loader), correct, len(pred), 100. * correct / len(pred), loss.item()))
                train_losses.append(loss.item())
                train_counter.append(
                    (i * 64) + ((epoch - 1) * len(train_loader.dataset)))
                torch.save(model.state_dict(), 'model_paramter/test/model.pth')
                torch.save(optimizer.state_dict(), 'model_paramter/test/optimizer.pth')
                # test_loop(model, 'cpu', test_iter)

# PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
# dictionary = torch.load(PATH)
# model.load_state_dict(dictionary)
train_loop(epoch, optmizer, model, loss_fn, train_iter)

# PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
# dictionary = torch.load(PATH)
# model.load_state_dict(dictionary)
# test_loop(model, 'cpu', test_iter)

若要测试数据的准确度等内容可以参考之前的博文使用LSTm进行情感分析,对test部分进行修改即可。

也可以参考下面的

PATH = 'E:\\PostGraduate\\Paper_review\\computer_view_model\\ResNet/model_paramter/model.pth'
dictionary = torch.load(PATH)
model.load_state_dict(dictionary)
def test_loop(model, device, test_iter):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_iter:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            output = output.data.max(1, keepdim=True)[1]
            output = torch.tensor(output, dtype=torch.float32)
            # loss_func = loss_fn(output, target)
            # test_loss += loss_func
            pred = output
            for index in range(0, len(pred)):
                if pred[index] == target[index]:
                    correct += 1
    test_loss /= len(test_iter.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_iter.dataset),
        100. * correct / len(test_iter.dataset)))

test_loop(model,'cpu',test_iter)

loss /= len(test_iter.dataset)
test_losses.append(test_loss)
print(‘\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n’.format(
test_loss, correct, len(test_iter.dataset),
100. * correct / len(test_iter.dataset)))

test_loop(model,‘cpu’,test_iter)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/476309.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

真实数据!一张切片实现101种蛋白的超多重空间单细胞原位成像

头颈鳞状细胞癌 (HNSCC) 是第七大常见癌症。免疫检查点抑制剂 (ICIs) 在治疗复发/转移病例方面显示出良好前景,约30%的患者可获得持久获益。但是目前反映HNSCC肿瘤微环境 (TME) 特征的生物标志物有限,需要更深入的组织表征分析。因此,需要新的…

linux查看cpu/内存/磁盘利用率

1、cpu 命令: top 2、内存 命令: free -h 3、磁盘 命令: df -h

《操作系统真相还原》读书笔记九:用c编写内核

用c语言先编写一个死循环 main.c int main(void) {while(1);return 0; }编译该文件 gcc -c -o main.o main.c-- Ttext参数表示起始虚拟地址为0xc0001500 -e参数表示程序入口地址 ld main.o -Ttext 0xc0001500 -e main -o kernel.bin-- 将kernel.bin写入第9个扇区 dd if/ho…

十九、网络编程

目录 一、什么是网络编程二、网络编程三要素2.1 IP2.2 InetAddress的使用2.3 端口号2.4 协议 三、UDP通信程序3.1 发送数据3.2 接收数据3.3 练习 四、UDP的三种通信方式五、TCP的通信程序六、三次握手和四次挥手七、练习7.1 TCP通信练习1——多发多收7.2 TCP通信练习2——接收和…

Cookie使用

文章目录 一、Cookie基本使用1、发送Cookie2、获取Cookie 二、Cookie原理三、Cookie使用细节 一、Cookie基本使用 1、发送Cookie package com.itheima.web.cookie;import javax.servlet.*; import javax.servlet.http.*; import javax.servlet.annotation.*; import java.io.I…

82.删除排序链表中的重复元素II

给定一个已排序的链表的头 head , 删除原始链表中所有重复数字的节点,只留下不同的数字 。返回 已排序的链表 。 示例 1: 输入:head [1,2,3,3,4,4,5] 输出:[1,2,5]示例 2: 输入:head [1,1,1,2…

【OJ比赛日历】快周末了,不来一场比赛吗? #03.23-03.29 #16场

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…)比赛。本账号会推送最新的比赛消息,欢迎关注! 以下信息仅供参考,以比赛官网为准 目录 2024-03-23(周六) #7场比赛2024-03-24…

高级数据结构 <AVL树>

本文已收录至《数据结构(C/C语言)》专栏! 作者:ARMCSKGT 目录 前言正文AVL树的性质AVL树的定义AVL树的插入函数左单旋右单旋右左双旋左右双旋 检验AVL树的合法性关于AVL树 最后 前言 前面我们学习了二叉树,普通的二叉树没有任何特殊性质&…

C语言易错知识点:二级指针、数组指针、函数指针

指针在C语言中非常关键,除开一些常见的指针用法,还有一些可能会比较生疏,但有时却也必不可少,本文章整理了一些易错知识点,希望能有所帮助! 1.二级指针: parr是一个指针数组,其中每…

GEE遥感云大数据林业应用典型案例及GPT模型应用

近年来遥感技术得到了突飞猛进的发展,航天、航空、临近空间等多遥感平台不断增加,数据的空间、时间、光谱分辨率不断提高,数据量猛增,遥感数据已经越来越具有大数据特征。遥感大数据的出现为相关研究提供了前所未有的机遇&#xf…

数据结构:初识树和二叉树

目前主流的方式是左孩子右兄弟表示法 我们的文件系统就是一个树 以上就是树的概念,我们今天还要来学习一种从树演变的重要的结构:二叉树 顾名思义二叉树就是一个结点最多有两个子树。 其中我们还要了解满二叉树和完全二叉树的概念 注意我们的完全二叉…

【一起学Rust | 基础篇】rust线程与并发

文章目录 前言一、创建线程二、mpsc多生产者单消费者模型1.创建一个简单的模型2.分批发送数据3. 使用clone来产生多个生产者 三、共享状态:互斥锁1. 创建一个简单的锁2. 使用互斥锁解决引用问题 前言 并发编程(Concurrent programming)&#…

网络: 传输层

功能: 将数据从发送到传给接收端 UDP 无连接状态: 知道对端的IP和端口号就直接进行传输, 不需要建立连接不可靠: 没有确认机制, 没有重传机制. 出错不会管面向数据包: 不能够灵活的控制读写数据的次数和数量 发送速度快: 立即发送 报文结构 TCP 面向连接可靠 校验和序列号(按…

Java项目基于Docker打包发布

1.打包应用 mvn clean package -DskipTests 或者 2.新建dockerfile FROM openjdk:8 #设置工作目录 WORKDIR /opt#COPY wms-app-0.0.1-SNAPSHOT.jar /wms-app/app.jar ADD wms-app-0.0.1-SNAPSHOT.jar app.jar #配置容器暴露的端口 EXPOSE 8080 #查看是否已经copy进去 R…

YOLOv1学习

YOLO系列学习笔记 YOLOv1评价指标PrecisionRecallAPmAP 置信度分数统一检测框架网络结构训练损失函数 测试YOLOv1的不足实验结论 YOLOv1 优点: 快全图推理,背景错误率低泛化能力强 每个图像固定大小 448*448,系统将输入图像分成S S网格。…

视频素材库哪里找?推荐几个高质量的无水印视频素材网

在寻找创意优质素材的道路上,拥有一个好的导航仪至关重要。这不仅仅是关于找到一张图片或一个视频,而是关于发现那些能让你的项目闪耀的宝藏。今天,我将混合介绍国内外的素材网站,旨在为你提供一个全面的视角,同时尽量…

Python之Web开发中级教程----Django站点管理

Python之Web开发中级教程----Django站点管理 网站的开发分为两部分:内容发布和公共访问 内容发布是由网站的管理员负责查看、添加、修改、删除数据 Django能够根据定义的模型类自动地生成管理模块 使用Django的管理模块, 需要按照如下步骤操作 : 1.管理界面本地…

Python 安装目录及虚拟环境详解

Python 安装目录 原文链接:https://blog.csdn.net/xhyue_0209/article/details/106661191 Python 虚拟环境 python 虚拟环境图解 python 虚拟环境配置与详情 原文链接:https://www.cnblogs.com/hhaostudy/p/17321646.html

C++进阶02 多态性

听课笔记简单整理,供小伙伴们参考~🥝🥝 第1版:听课的记录代码~🧩🧩 编辑:梅头脑🌸 审核:文心一言 目录 🐳课程来源 🐳前言 🐋运…

LeetCode困难题----84.柱状图中的最大矩形

今天刷LeetCode时遇到了一个很有意思的题: 看了半天题解还是没理解他的代码想要表达的是什么意思,在思考了很久之后,终于,我理解了这道题,接下来让我带你们走进这道题。 这道题的大概意思是,给你一个heights[]数组,(宽为1)让你求出他们可以组合出的最大面积 首先,我们先用暴力法…