pytorch-ResNet18简单复现

目录

  • 1. ResNet block
  • 2. ResNet18网络结构
  • 3. 完整代码
    • 3.1 网络代码
    • 3.2 训练代码

1. ResNet block

ResNet block有两个convolution和一个short cut层,如下图:
在这里插入图片描述
代码:

class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self. bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_in != ch_out:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = self.extra(x) + out
        out = F.relu(out)

        return out

2. ResNet18网络结构

在这里插入图片描述
在这里插入图片描述
从上图可以看出,resnet18有1个卷积层,4个残差层和1一个线性输出层,其中每个残差层有2个resnet块,每个块有2个卷积层。
对于cifar10数据来说,输入层[b, 64, 32,32],输出是10分类
代码:

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # 初始卷积层
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # 四个残差层
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # 全连接层
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    # 创建一个残差层
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        #out = F.avg_pool2d(out, 4)
        out = F.adaptive_avg_pool2d(out, [1, 1])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

3. 完整代码

3.1 网络代码

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    expansion = 1

    def __init__(self, ch_in, ch_out, stride):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_in != ch_out:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = self.extra(x) + out
        out = F.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # 初始卷积层
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # 四个残差层
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # 全连接层
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    # 创建一个残差层
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        #out = F.avg_pool2d(out, 4)
        out = F.adaptive_avg_pool2d(out, [1, 1])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(ResBlk, [2, 2, 2, 2], 10)


if __name__ == '__main__':
    model = ResNet18()
    print(model)

3.2 训练代码

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import sys

sys.path.append('.')
#from Lenet5 import Lenet5
from resnet import ResNet18


def main():
    batchz = 128
    cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)

    device = torch.device('cuda')
    #model = Lenet5().to(device)
    model = ResNet18().to(device)
    crition = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        model.train()
        for batch, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = crition(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/763707.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于C语言+控制台的学生信息管理系统

博主介绍: 大家好,本人精通Java、Python、Php、C#、C、C编程语言,同时也熟练掌握微信小程序和Android等技术,能够为大家提供全方位的技术支持和交流。 我有丰富的成品Java、Python、C#毕设项目经验,能够为学生提供各类…

绘唐3一键追爆款文刻创作聚星文社

聚星文社是一个中国的文学社交平台,提供了一个让作家和读者相互交流和分享作品的平台。 在聚星文社,作家可以在平台上发布自己的作品,获得读者的阅读和评论,同时也可以与其他作家进行交流与学习。 点击下载即可 读者可以在平台上…

第五届计算机、大数据与人工智能国际会议(ICCBD+AI 2024)

随着科技的飞速发展,计算机、大数据和人工智能等前沿技术已成为推动社会进步的重要力量。为了加强这一领域的学术交流与合作,促进技术创新与发展,第五届计算机、大数据与人工智能国际会议(ICCBDAI 2024)将于2024年11月…

Linux基础IO操作详解

C文件IO相关接口 fopen函数 pathname: 要打开的文件名字符串mode: 访问文件的模式 模式描述含义“r”读文件不存在失败返回null“r”读写文件不存在打开失败返回null,文件存在则从头开始覆盖现有的数据(不会清空数据)“w”写文件不存在创建…

蜜雪冰城小程序逆向

app和小程序算法一样 小程序是wasm

并发编程工具集——Lock和Condition(下)(十四)

如何利用两个条件变量快速实现阻塞队列呢? 入队与出队需要同步,用一个锁。一个阻塞队列,需要两个条件变量,一个是队列不空(空队列不允许出队),另一个是队列不满(队列已满不允许入队&…

Spring每日面试题(day1)

目录 JavaWeb三大组件依赖注入的方式Autowire和Resurce有什么区别?Spring Boot的优点Spring IoC是什么?说说Spring Aop的优点Component和Bean的区别自定义注解时使用的RetentionPolicy枚举类有哪些值?Spring Boot自动装配原理Spring MVC工作原…

CV每日论文--2024.6.28

1、On Scaling Up 3D Gaussian Splatting Training 中文标题:扩展 3D 高斯泼溅训练 简介:3D高斯点描(3DGS)由于其卓越的视觉质量和渲染速度,越来越受欢迎用于3D重建。然而,3DGS的训练目前仅在单个GPU上进行,由于内存限制,它的处理高分辨率和大规模3D重建…

链表--逆置

#include <stdio.h> #include <stdlib.h>struct ListNode {int data;struct ListNode *next; };struct ListNode *createlist(); /*裁判实现&#xff0c;细节不表*/ struct ListNode *reverse( struct ListNode *head ); void printlist( struct ListNode *head ) …

一些迷你型信息系统

只有一个表&#xff0c;比较简单易用&#xff1b; 1 博物馆信息查询系统 信息录入&#xff0c;浏览&#xff0c;添加&#xff0c;更新&#xff0c;删除&#xff1b; 下载&#xff0c; https://download.csdn.net/download/bcbobo21cn/89505217

基于LangChain-Chatchat实现的RAG-本地知识库的问答应用[6]-实现Milvus向量检索+实现自定义关键词调整Embedding模型

基于LangChain-Chatchat实现的RAG-本地知识库的问答应用[6]-实现Milvus向量检索+实现自定义关键词调整Embedding模型 0.Milvus与Faiss对比 Milvus相对于Faiss的优势主要体现在以下几个方面: 在线数据更新与实时搜索: Milvus支持在线的数据更新和实时的向量搜索,这意味着在数…

调度器APScheduler定时执行任务

APScheduler&#xff08;Advanced Python Scheduler&#xff09;是一个Python库&#xff0c;用于调度任务&#xff0c;使其在预定的时间间隔或特定时间点执行。它支持多种调度方式&#xff0c;包括定时&#xff08;interval&#xff09;、日期&#xff08;date&#xff09;和Cr…

网络安全等级保护2.0(等保2.0)全面解析

一、等保2.0的定义和背景 网络安全等级保护2.0&#xff08;简称“等保2.0”&#xff09;是我国网络安全领域的基本制度、基本策略、基本方法。它是在《中华人民共和国网络安全法》指导下&#xff0c;对我国网络安全等级保护制度进行的重大升级。等保2.0的发布与实施&#xff0c…

WEB01MySQL安装和数据库

第一天、WEB课程 web课程主要讲三部分内容 数据库 数据库介绍 什么是数据库 数据存储的仓库&#xff0c;其本质也是一个文件系统 数据库会按照特定的格式对数据进行存储&#xff0c;用户可以对数据库中的数据进行增加&#xff0c;修改&#xff0c;删除及查询操作。 数据库…

240701_昇思学习打卡-Day13-Vision Transformer图像分类

240701_昇思学习打卡-Day13-Vision Transformer图像分类 Transformer最开始是应用在NLP领域的&#xff0c;拿过来用到图像中取得了很好的效果&#xff0c;然后他就要摇身一变&#xff0c;就叫Vision Transformer。 该部分内容还是参考太阳花的小绿豆-CSDN博客大佬的视频11.1 …

JTracker IDEA 中最好的 MyBatis 日志格式化插件

前言 如果你使用 MyBatis ORM 框架&#xff0c;那么你应该用过 MyBatis Log 格式化插件&#xff0c;它可以让我们的程序输出的日志更人性化。 但是有一个问题&#xff0c;通常我们只能看到格式化后的效果&#xff0c;没办法知道这个 SQL 是谁执行的以及调用的链路。 如下图所…

python之列表

1.概述 线性的数据结构 有序的队列&#xff0c;可以使用下标进行索引 可变的序列 列表中的个体称为元素&#xff0c;多个元素组成列表 列表的语法是[],多个元素使用逗号分隔 列表中的元素类型可以不同 2.定义列表 使用【】方法&#xff0c;多个元素之间使用逗号进行分隔 使用li…

【单片机毕业设计选题24042】-基于无线传输的老人健康监护系统

系统功能: 系统操作说明&#xff1a; 上电后OLED显示 “欢迎使用健康监护系统请稍后”&#xff0c;两秒后显示Connecting...表示 正在连接阿里云&#xff0c;正常连接阿里云后显示第一页面&#xff0c;如长时间显示Connecting...请 检查WiFi网络是否正确。 第一页面第一行…

机器人入门路线及参考资料(机器人操作方向)

机器人&#xff08;操作方向&#xff09;入门路线及参考资料 前言1 数理基础和编程2 机器人学理论3 计算机视觉4 机器人实操5 专攻方向总结Reference: 前言 随着机器人和具身智能时代的到来&#xff0c;机器人越来越受到大家的重视&#xff0c;本文就介绍了机器人&#xff08;…

方正小标宋简体、仿宋GB2312、楷体GB2312字体

文章目录 下载地址所有的文件wps使用方正小标宋简体、仿宋GB2312、楷体GB2312 字体用途方正小标宋简体仿宋GB2312楷体GB2312 下载地址 【金山文档 | WPS云文档】 方正小标宋简体、仿宋GB2312、楷体GB2312 https://kdocs.cn/l/cksgHDLneqDk 所有的文件 wps使用 方正小标宋简体…