卷积神经网络| 猫狗系列【AlexNet】

首先,搭建网络:

AlexNet神经网络原理图:

net代码:【根据网络图来搭建网络,不会的看看相关视频会好理解一些】

import torchfrom torch import nnimport torch.nn.functional as Fclass MyAlexNet(nn.Module):    def __init__(self):        super(MyAlexNet, self).__init__()#继承        self.c1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=11, stride=4, padding=2)#搭建第一层网络,输入通道3层,输出通道48,核11        self.ReLU = nn.ReLU()#激活函数        self.c2 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=5, stride=1, padding=2)#上面输出48,下面输入也是48,输出125,卷积核5        self.s2 = nn.MaxPool2d(2)#池化层        self.c3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=1, padding=1)        self.s3 = nn.MaxPool2d(2)        self.c4 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=1)        self.c5 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, stride=1, padding=1)        self.s5 = nn.MaxPool2d(kernel_size=3, stride=2)        self.flatten = nn.Flatten()#平展层        self.f6 = nn.Linear(4608, 2048)        self.f7 = nn.Linear(2048, 2048)        self.f8 = nn.Linear(2048, 1000)        self.f9 = nn.Linear(1000, 2)#输出二分类网络    def forward(self, x):        x = self.ReLU(self.c1(x))        x = self.ReLU(self.c2(x))        x = self.s2(x)#池化层        x = self.ReLU(self.c3(x))        x = self.s3(x)        x = self.ReLU(self.c4(x))        x = self.ReLU(self.c5(x))        x = self.s5(x)        x = self.flatten(x)#平展层        x = self.f6(x)        x = F.dropout(x, p=0.5)#防止过拟合,有50%的网络随机失效        x = self.f7(x)        x = F.dropout(x, p=0.5)        x = self.f8(x)        x = F.dropout(x, p=0.5)        x = self.f9(x)        return xif __name__ == '__mian__':    x = torch.rand([1, 3, 224, 224])#张量形式数组    model = MyAlexNet()    y = model(x)

测试一下这个网络:

划分数据集:(8:2)(spilit_data

import osfrom shutil import copyimport randomdef mkfile(file):    if not os.path.exists(file):        os.makedirs(file)# 获取data文件夹下所有文件夹名(即需要分类的类名)file_path = 'D:/Users/Twilight/PycharmProjects/AlexNet/data_name'flower_class = [cla for cla in os.listdir(file_path)]# 创建 训练集train 文件夹,并由类名在其目录下创建5个子目录mkfile('data/train')for cla in flower_class:    mkfile('data/train/' + cla)# 创建 验证集val 文件夹,并由类名在其目录下创建子目录mkfile('data/val')for cla in flower_class:    mkfile('data/val/' + cla)# 划分比例,训练集 : 验证集 = 8:2split_rate = 0.2# 遍历所有类别的全部图像并按比例分成训练集和验证集for cla in flower_class:    cla_path = file_path + '/' + cla + '/'  # 某一类别的子目录    images = os.listdir(cla_path)  # iamges 列表存储了该目录下所有图像的名称    num = len(images)    eval_index = random.sample(images, k=int(num * split_rate))  # 从images列表中随机抽取 k 个图像名称    for index, image in enumerate(images):        # eval_index 中保存验证集val的图像名称        if image in eval_index:            image_path = cla_path + image            new_path = 'data/val/' + cla            copy(image_path, new_path)  # 将选中的图像复制到新路径        # 其余的图像保存在训练集train中        else:            image_path = cla_path + image            new_path = 'data/train/' + cla            copy(image_path, new_path)        print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar    print()print("processing done!")

生成的新的文件夹:

由于数据集量太大,划分完了我还删了很多图片,我的data里面只有1000张,训练集猫狗分别400,测试集猫狗分别100。【跑不动根本跑不动】

训练代码:(train)

import torchfrom torch import nnfrom net import MyAlexNetimport numpy as npfrom torch.optim import lr_schedulerimport osfrom torchvision import transformsfrom torchvision.datasets import ImageFolderfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plt# 解决中文显示问题(乱码)plt.rcParams['font.sans-serif'] = ['SimHei']plt.rcParams['axes.unicode_minus'] = FalseROOT_TRAIN = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/train'#数据集路径训练集ROOT_TEST = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/val'# 将图像的像素值归一化到【-1, 1】之间normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])#train_transform = transforms.Compose([#训练集    transforms.Resize((224, 224)),#224*224    transforms.RandomVerticalFlip(),#随机垂直    transforms.ToTensor(),#转化为张量    normalize])#归一化val_transform = transforms.Compose([#验证集    transforms.Resize((224, 224)),    transforms.ToTensor(),    normalize])train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)#批次32,打乱val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'#数据导入显卡里面model = MyAlexNet().to(device)#把数据送到神经网络中,然后输到显卡里面# 定义一个损失函数loss_fn = nn.CrossEntropyLoss()# 定义一个优化器optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)#随机梯度下降法,把模型参数传给优化器,学习率0.01# 学习率每隔10轮变为原来的0.5lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)# 定义训练函数def train(dataloader, model, loss_fn, optimizer):#数据,模型,学习率,优化器传入    loss, current, n = 0.0, 0.0, 0#指示器    for batch, (x, y) in enumerate(dataloader):        image, y = x.to(device), y.to(device)        output = model(image)#进行训练        cur_loss = loss_fn(output, y)#看误差        _, pred = torch.max(output, axis=1)#_代表不关心返回的最大值是多少。只需要得到pred        cur_acc = torch.sum(y==pred) / output.shape[0]#算精确率        # 反向传播        optimizer.zero_grad()#梯度降为0        cur_loss.backward()#反向传播        optimizer.step()#更新梯度        loss += cur_loss.item()#Loss值累加起来(一轮很多批次)        current += cur_acc.item()#准确度加起来        n = n+1#轮    train_loss = loss / n#计算这一轮学习的学习率(每一批的)    train_acc = current / n    print('train_loss' + str(train_loss))    print('train_acc' + str(train_acc))#训练精确的    return train_loss, train_acc#返回后面可视化用# 定义一个验证函数def val(dataloader, model, loss_fn):    # 将模型转化为验证模型    model.eval()    loss, current, n = 0.0, 0.0, 0    with torch.no_grad():        for batch, (x, y) in enumerate(dataloader):            image, y = x.to(device), y.to(device)            output = model(image)            cur_loss = loss_fn(output, y)            _, pred = torch.max(output, axis=1)            cur_acc = torch.sum(y == pred) / output.shape[0]            loss += cur_loss.item()            current += cur_acc.item()            n = n + 1    val_loss = loss / n    val_acc = current / n    print('val_loss' + str(val_loss))    print('val_acc' + str(val_acc))    return val_loss, val_acc# 定义画图函数def matplot_loss(train_loss, val_loss):    plt.plot(train_loss, label='train_loss')    plt.plot(val_loss, label='val_loss')    plt.legend(loc='best')    plt.ylabel('loss')    plt.xlabel('epoch')    plt.title("训练集和验证集loss值对比图")    plt.show(block=True)def matplot_acc(train_acc, val_acc):    plt.plot(train_acc, label='train_acc')    plt.plot(val_acc, label='val_acc')    plt.legend(loc='best')    plt.ylabel('acc')    plt.xlabel('epoch')    plt.title("训练集和验证集acc值对比图")    plt.show(block=True)# 开始训练loss_train = []acc_train = []loss_val = []acc_val = []epoch = 20 #20轮min_acc = 0for t in range(epoch):    lr_scheduler.step()#每十步分析一下学习率    print(f"epoch{t+1}\n-----------")    train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)    val_loss, val_acc = val(val_dataloader, model, loss_fn)    loss_train.append(train_loss)#写到集合里头    acc_train.append(train_acc)    loss_val.append(val_loss)    acc_val.append(val_acc)    # 保存最好的模型权重    if val_acc >min_acc:#如果模型精确度大于0        folder = 'save_model'        if not os.path.exists(folder):#如果文件夹不存在            os.mkdir('save_model')#生成        min_acc = val_acc        print(f"save best model, 第{t+1}轮")        torch.save(model.state_dict(), 'save_model/best_model.pth')    # 保存最后一轮的权重文件    if t == epoch-1:        torch.save(model.state_dict(), 'save_model/last_model.pth')matplot_loss(loss_train, loss_val)matplot_acc(acc_train, acc_val)print('Done!')

最好的模型保存,嘻嘻。

生成的loss、acc图​:​

(效果其实是非常不好的,因为数据量太少了哈哈,然后参数某些地方也可以再调一下)

测试代码(test)

import torchfrom net import MyAlexNetfrom torch.autograd import Variablefrom torchvision import datasets, transformsfrom torchvision.transforms import ToTensorfrom torchvision.transforms import ToPILImagefrom torchvision.datasets import ImageFolderfrom torch.utils.data import DataLoaderROOT_TRAIN = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/train'#数据集路径训练集ROOT_TEST = r'D:/Users/Twilight/PycharmProjects/AlexNet/data/val'# 将图像的像素值归一化到【-1, 1】之间normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])train_transform = transforms.Compose([    transforms.Resize((224, 224)),    transforms.RandomVerticalFlip(),    transforms.ToTensor(),    normalize])val_transform = transforms.Compose([    transforms.Resize((224, 224)),    transforms.ToTensor(),    normalize    ])train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)#变张量val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'model = MyAlexNet().to(device)# 加载模型model.load_state_dict(torch.load("D:/Users/Twilight/PycharmProjects/AlexNet/save_model/best_model.pth"))# 获取预测结果classes = [    "cat",    "dog",]# 把张量转化为照片格式,后面可视化show = ToPILImage()# 进入到验证阶段model.eval()for i in range(10):#验证前十张    x, y = val_dataset[i][0], val_dataset[i][1]    show(x).show()    x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=True).to(device)#把值传入到显卡里面    x = torch.tensor(x).to(device)    with torch.no_grad():        pred = model(x)        predicted, actual = classes[torch.argmax(pred[0])], classes[y]        print(f'predicted:"{predicted}", Actual:"{actual}"')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/35038.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Flutter学习四:Flutter开发基础(六)调试Flutter应用

目录 0 引言 1 Flutter异常捕获 1.1 Dart单线程模型 1.2 Flutter异常捕获 1.2.1 Flutter框架异常捕获 1.2.1.1 Flutter默认异常捕获方式 1.2.1.2 自己捕获异常并上报 1.2.2 其他异常捕获与日志收集 1.2.3 最终的错误上报代码 0 引言 本文是对第二版序 | 《Flutter实…

《Lua程序设计》--学习2

表 Lua语言中的表本质上是一种辅助数组(associative array),这种数组不仅可以使用数值作为索引,也可以使用字符串或其他任意类型的值作为索引(nil除外)。 Lua语言中的表要么是值要么是变量,它…

防火墙基本原理详解

概要 防火墙是可信和不可信网络之间的一道屏障,通常用在LAN和WAN之间。它通常放置在转发路径中,目的是让所有数据包都必须由防火墙检查,然后根据策略来决定是丢弃或允许这些数据包通过。例如: 如上图,LAN有一台主机和一…

【macOS 系列】如何在mac 邮件客户端配置QQ邮箱和第二个账号

文章目录 一、配置QQ邮箱二、添加新的账户 一、配置QQ邮箱 需要在QQ邮箱账户设置中开启: 开启时,会让你发短信到指定号码,然后就会弹出一个验证码 也就是添加邮箱的密码不是QQ密码,而是这个验证码,这个可以生成多个&…

【OpenGL】读取视频并渲染

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍读取视频并渲染。 学其所用,用其所学。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下,下次更新不迷路&#…

ELK实验部署过程

ELK集群部署环境准备 配置ELK日志分析系统 192.168.1.51 elk-node1 es、logstash、kibana 192.168.1.52 elk-node2 es、logstash 192.168.1.53 apache logstash (我这里是把虚拟机的配置全部都改为2核3G的) 2台linux 第1台:elk-nod…

【数据库原理】MyShop 商城数据库设计(SQL server)

MyShop 商城数据库设计 项目背景定义课程设计要求概念结构设计逻辑结构设计数据结构的描述用户信息数据结构的描述地址信息数据结构的描述商品类别数据结构的描述商品数据结构的描述购物车数据结构的描述订单数据结构的描述订单项数据结构的描述 物理结构设计用户表结构地址表结…

STM32——GPIO配置

文章目录 一、GPIO八种模式1. 输入2. 输出3. 如何选择GPIO的模式 二、库函数GPIO配置1. 配置代码2.参数设置 一、GPIO八种模式 GPIO的输入输出是对于STM32单片机来说的。以下仅为个人粗略笔记,内部电路分析可参考博客https://blog.csdn.net/k666499436/article/det…

计算机网络_ 1.3 网络核心(数据交换_电路交换_多路复用)

计算机网络_数据交换_电路交换_多路复用 多路复用频分多路复用FDM时分多路复用TDM波分多路复用WDM码分多路复用CDM 多路复用 多路复用(Multiplexing),简称复用,是通信技术的基本概念。 链路/网络资源(如带宽&#x…

【K8S系列】如何高效查看 k8s日志

序言 你只管努力,其他交给时间,时间会证明一切。 文章标记颜色说明: 黄色:重要标题红色:用来标记结论绿色:用来标记一级论点蓝色:用来标记二级论点 Kubernetes (k8s) 是一个容器编排平台&#x…

docker安装失败 应用程序无法启动,因为应用程序的并行配置不正确

问题描述 报错“应用程序无法启动,因为应用程序的并行配置不正确”。 配置:windows10 解决过程 网上的解决方案有三种: 启动windows服务Windows Modules Installer。运行sxstrace.exe。安装visual c相关依赖。下载visual studio installer…

1.6 OSI 七层参考模型

OSI 参考模型 OSI参考模型解释的通信过程OSI参考模型数据封装与通信过程物理层功能数据链路层功能网络层的功能传输层功能会话层功能表示层功能应用层功能 开放系统互连 (OSI)参考模型是由国际标准化组织 (ISO) 于1984年提出的分层网络体系结构模型目的是支持异构网络系统的互联…

Selenium--做任何你想做的事情

大家好,今天为大家介绍Selenium自动化浏览器。就是这样!你可以通过这种力量做任何你想做的事情。 “getDevTools() 方法返回新的 Chrome DevTools 对象,允许您使用 send() 方法发送针对 CDP 的内置 Selenium 命令。这些命令是包装方法&#x…

k8s Label 2

在 k8s 中,我们会轻轻松松的部署几十上百个微服务,这些微服务的版本,副本数的不同进而会带出更多的 pod 这么多的 pod ,如何才能高效的将他们组织起来的,如果组织不好便会让管理微服务变得混乱不堪,杂乱无…

C#(四十九)之关于string的一些函数

1&#xff1a;startswith 字符串以。。。开头 // startswith 字符串以。。。开头string[] strArr { "asd","azx","qwe","aser","asdfgh"};for (int i 0; i < strArr.Length; i){if(strArr[i].StartsWith("as&qu…

LiangGaRy-学习笔记-Day28

1、回顾知识 1.1、docker启动MySQL 安装docker #准备好二进制的包 [rootNode2 ~]# ls docker-20.10.9.tgz docker-20.10.9.tgz [rootNode2 ~]# #解压docker的二进制包 [rootNode2 ~]# tar -xf docker-20.10.9.tgz #把它移动到/usr/local/下 [rootNode2 ~]# mv docker /usr/…

logback-spring.xml详解

本文来写说下logback-spring.xml相关的知识与概念 文章目录 概述configuration元素定义上下文名称定义变量appender组件RollingFileAppender配置logger配置root配置ELK的配置输出logback状态数据异步输出日志代码中的日志格式本文小结 概述 对于xml日志文件的配置&#xff0c;大…

语义分割大模型RSPrompter论文阅读

论文链接 RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model 开源代码链接 RSPrompter 论文阅读 摘要 Abstract—Leveraging vast training data (SA-1B), the foundation Segment Anything Model (SAM) propo…

遗传算法(GA)优化后RBF神经网络优化分析(Matlab代码实现)

目录 1 遗传算法 2 RBF神经网络 3 Matlab代码实现 4 结果 1 遗传算法 遗传算法是一种模拟自然界进化过程的优化算法。它通过模拟生物进化的遗传、交叉和变异等过程&#xff0c;来搜索最优解或近似最优解。 遗传算法的基本步骤如下&#xff1a; 初始化种群&#xff1a;随机生成…

SPSS读取纯文本文件

纯文本文件是通用的一种格式文件&#xff0c;根据纯文本文件中数据的排序方式&#xff0c;可以将其分为自由格式和固定格式。自由格式文本文件的数据项之间必须有分隔符&#xff0c;固定格式数据项之间不需要分隔符。 1.以自由格式读取数据 &#xff08;1&#xff09;选择“文…