残差网络实现

代码中涉及的图片实验数据下载地址:https://download.csdn.net/download/m0_37567738/88235543?spm=1001.2014.3001.5501

代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
#from utils import load_data,get_accur,train
import time


import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import numpy as np


def load_data(path, batch_size):
    datasets = torchvision.datasets.ImageFolder(
        root = path,
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
    )

    dataloder = DataLoader(datasets, batch_size=batch_size, shuffle=True)
    return datasets,dataloder

def get_accur(preds, labels):
    preds = preds.argmax(dim=1)
    return torch.sum(preds == labels).item()

def train(model, epochs, learning_rate, dataloader, criterion, testdataloader):
    optimizer = optim.Adam(model.parameters(),lr=learning_rate)

    train_loss_list = []
    test_loss_list = []
    train_accur_list = []
    test_accur_list = []
    train_len = len(dataloader.dataset)
    test_len = len(testdataloader.dataset)

    for i in range(epochs):
        train_loss = 0.0
        train_accur = 0
        test_loss = 0.0
        test_accur = 0
        for batch in dataloader:
            imgs, labels = batch
            preds = model(imgs)
            optimizer.zero_grad()
            loss = criterion(preds, labels)

            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_accur += get_accur(preds,labels)

        train_loss_list.append(train_loss)
        train_accur_list.append(train_accur / train_len)

        for batch in testdataloader:
            imgs, labels = batch
            preds = model(imgs)
            loss = criterion(preds, labels)
            test_loss += loss.item()
            test_accur += get_accur(preds,labels)

        test_loss_list.append(test_loss)
        test_accur_list.append(test_accur / test_len)

        print("epoch {} : train_loss : {}; train_accur : {}".format(i + 1, train_loss, train_accur / train_len))

    return np.array(train_accur_list), np.array(train_loss_list), np.array(test_accur_list), np.array(test_loss_list)

class ResidualBlock(nn.Module):
    
    def __init__(self, inchannel, outchannel, stride=1):
        
        super().__init__()
        
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),

            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1,padding=1, bias=False),
            # 尺寸不发生变化 通道改变
            nn.BatchNorm2d(outchannel)
        )
        
        self.shortcut = nn.Sequential()
        # 注意shortcut是对输入X进行卷积,利用1×1卷积改变形状
        if inchannel != outchannel or stride != 1:
            self.shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel))

    def forward(self, X):
        h = self.left(X)
        # 先相加再激活
        h += self.shortcut(X)
        out = F.relu(h)
        return out


class ResidualNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.residual_block = nn.Sequential(
            ResidualBlock(3, 32),
            ResidualBlock(32, 64),
            ResidualBlock(64, 32),
            ResidualBlock(32, 3)
        )
        self.fc1 = nn.Linear(3 * 64 * 64, 1024)
        self.fc2 = nn.Linear(1024, 3)

    def forward(self, X):
        h = self.residual_block(X)
        h = h.view(-1, 3 * 64 * 64)
        h = self.fc1(h)
        out = self.fc2(h)
        return out

if __name__ == "__main__":
    train_path = "./cnn/train/"
    test_path = "./cnn/test/"
    _, train_dataloader = load_data(train_path, 32)
    _, test_dataloader = load_data(test_path, 32)
    model = ResidualNet()
    critic = nn.CrossEntropyLoss()
    epoch = 20
    lr = 0.01
    start = time.clock()
    print("Start training model.....")
    train_accur_list, train_loss_list, test_accur_list, test_loss_list = train(model, epoch, lr, train_dataloader,
                                                                               critic, test_dataloader)
    end = time.clock()
    print("Train cost: {} s".format(end - start))
    test_accur = 0
    for batch in test_dataloader:
        imgs, labels = batch
        preds = model(imgs)
        test_accur += get_accur(preds, labels)

    print("Accuracy on test datasets : {}".format(test_accur / len(test_dataloader.dataset)))

执行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/82581.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PHP自己的框架实现操作成功失败跳转(完善篇四)

1、实现效果,操作成功后失败成功自动跳转 2、创建操作成功失败跳转方法CrlBase.php /**成功后跳转*跳转地址$url* 跳转显示信息$msg* 等待时间$wait* 是否自动跳转$jump*/protected function ok($urlNULL,$msg操作成功,$wait3,$jump1){$code1;include KJ_CORE./tp…

【前缀和】算法实战

文章目录 一、算法原理1. 一维前缀和2. 二维前缀和 二、算法实战1. leetcode560 和为K的子数组2. leetcode974 和可被K整除的子数组3. leetcode525 连续数组4. leetcode1314 矩阵区域和5. leetcode724 寻找数组的中心下标6. leetcode238 除自身以外数组的乘积 三、总结 一、算法…

Linux/Ubuntu 的日常升级和安全更新,如何操作?

我安装的是Ubuntu 20.04.6 LTS的Windows上Linux子系统版本,启动完成后显示: Welcome to Ubuntu 20.04.6 LTS (GNU/Linux 5.15.90.4-microsoft-standard-WSL2 x86_64) * Documentation: https://help.ubuntu.com * Management: https://landscape.c…

每天一道leetcode:934. 最短的桥(图论中等广度优先遍历)

今日份题目: 给你一个大小为 n x n 的二元矩阵 grid ,其中 1 表示陆地,0 表示水域。 岛 是由四面相连的 1 形成的一个最大组,即不会与非组内的任何其他 1 相连。grid 中 恰好存在两座岛 。 你可以将任意数量的 0 变为 1 &#…

Go 安装配置

介绍Ubuntu20.04 安装和配置Go 可以参考官网的这个为 Go 开发配置Visual Studio Code - Go on Azure | Microsoft Learn 1.安装Go 去这个地方下载Go https://go.dev/doc/install 如果之前安装过,可以参考这个(没有可以忽略) 下载完成后执…

Linux Shell如果ping失败就重启网卡(详解)

直接上脚本 -------------------------------------------------------------------------- #vi /tmp/ping_check.sh #!/bin/bash IP="1.1.1.1" PacketLoss=`ping -c 4 -w 4 1.1.1.1 | grep packet loss | awk -F packet loss {print $1} | awk {print $NF}|se…

vsCode使用cuda

一、vsCode使用cuda 前情提要:配置好mingw: 1.安装cuda 参考: **CUDA Toolkit安装教程(Windows):**https://blog.csdn.net/qq_42951560/article/details/116131410 2.在vscode中添加includePath c_cp…

Midjourney API 申请及使用

在人工智能绘图领域,想必大家听说过 Midjourney 的大名吧! Midjourney 以其出色的绘图能力在业界独树一帜。无需过多复杂的操作,只要简单输入绘图指令,这个神奇的工具就能在瞬间为我们呈现出对应的图像。无论是任何物体还是任何风…

Android内存泄漏总结和性能优化技巧

我们在开发安卓应用时,性能优化是非常重要的方面。一方面,优化可以提高应用的响应速度、降低卡顿率和提升应用流畅度,从而提升用户体验;另一方面,优化也可以减少应用的资源占用,提高应用的稳定性和安全性&a…

如何在window下cmd窗口执行linux指令?

1.Git:https://git-scm.com/downloads(官网地址) 2.根据自己的实际路径,添加两个环境变量 3.重启电脑

【Selenium学习】环境搭建 API学习

目录 一、javaSelenium的环境搭建? 二、认识Selenium 1、什么是自动化? 2、什么是Selenium? (重点) 3、selenium的工作原理?(重点) 三、Selenium操作元素API(重点) 第一部分&#…

使用 HTML、CSS 和 JavaScript 创建实时 Web 编辑器

使用 HTML、CSS 和 JavaScript 创建实时 Web 编辑器 在本文中,我们将创建一个实时网页编辑器。这是一个 Web 应用程序,允许我们在网页上编写 HTML、CSS 和 JavaScript 代码并实时查看结果。这是学习 Web 开发和测试代码片段的绝佳工具。我们将使用ifram…

excel中有哪些通配符、excel配置问题,数学函数篇1之sum系列

学习excel前需要明确的是事:   在学习excel函数之前,大家需要明确一件事,excel现在设计到了一些新函数,这些新函数只能存在于office365、office2019及更 新版本之中,所以建议大家在学习时安装较新的版本,…

高阶数据结构-图

高阶数据结构-图 图的表示 图由顶点和边构成,可分为有向图和无向图 邻接表法 图的表示方法有邻接表法和邻接矩阵法,以上图中的有向图为例,邻接表法可以表示为 A->[(B,5),(C,10)] B->[(D,100)] C->[(B,3)] D->[(E,7)] E->[…

Ribbon:负载均衡及Ribbon

什么是负载均衡&#xff1f; 第一种轮询算法&#xff0c;依次遍历去执行&#xff0c;达到负载均衡 集成Ribbon 导入pom&#xff0c;在消费者服务里的pom文件导入 <!-- Ribbon 集成 --><!-- https://mvnrepository.com/artifact/org.springframework.cloud/spr…

-Webkit-Box 在 Safari 中出现的兼容性问题

一、问题背景&#xff1a; UI要求要实现这样的效果&#xff0c;使用 display:-webket-box在chrome浏览器下完美解决 但是马上啪啪打脸&#xff0c;在safari浏览器下显示空白 &#xff0c;不能不说浏览器之间的兼容性简直就是天坑 二、解决办法 通过浏览器调试发现原本float的…

[机器学习]特征工程:主成分分析

目录 主成分分析 1、简介 2、帮助理解 3、API调用 4、案例 本文介绍主成分分析的概述以及python如何实现算法&#xff0c;关于主成分分析算法数学原理讲解的文章&#xff0c;请看这一篇&#xff1a; 探究主成分分析方法数学原理_逐梦苍穹的博客-CSDN博客https://blog.csdn.…

YOLOX算法调试记录

YOLOX是在YOLOv3基础上改进而来&#xff0c;具有与YOLOv5相媲美的性能&#xff0c;其模型结构如下&#xff1a; 由于博主只是要用YOLOX做对比试验&#xff0c;因此并不需要对模型的结构太过了解。 先前博主调试过YOLOv5,YOLOv7&#xff0c;YOLOv8,相比而言&#xff0c;YOLOX的环…

RS232、RS422、RS485硬件及RS指令、RS2指令应用知识学习

RS232、RS422、RS485硬件及RS指令、RS2指令应用知识学习 一、串行&#xff08;异步/同步)通讯、并行通讯、以太网通讯 二、单工通讯/半双工通讯/双工通讯 三、常用硬件接口&#xff08;工业上基本是RS485两线制的接线&#xff09; 常用硬件接口RS232/RS422/RS485&#xff0c;…

C#与西门子PLC1500的ModbusTcp服务器通信2--ModbusTcp协议

Modbus TCP是近年来越来越流行的工业控制系统通信协议之一&#xff0c;与其他通信协议相比&#xff0c;Modbus TCP通信速度快、可靠性高、兼容性强、适用于模拟或数字量信号的传输&#xff0c;阅读本文前你必须比较熟悉Modbus协议&#xff0c;了解tcp网络。 一、什么是Modbus …