卷积神经网络-奥特曼识别

 数据集 

 四种奥特曼图片_数据集-飞桨AI Studio星河社区 (baidu.com)

 中间的隐藏层 已经使用参数的空间

Conv2D卷积层

ReLU激活层

MaxPool2D最大池化层

AdaptiveAvgPool2D自适应的平均池化

Linear全链接层

Dropout放置过拟合,随机丢弃神经元

--------------------------------------------------------------------------------
   Layer (type)          Input Shape          Output Shape         Param #    
================================================================================
     Conv2D-1        [[50, 3, 227, 227]]   [50, 64, 227, 227]       1,792     
      ReLU-1        [[50, 64, 227, 227]]   [50, 64, 227, 227]         0       
     Conv2D-2       [[50, 64, 227, 227]]   [50, 64, 227, 227]      36,928     
      ReLU-2        [[50, 64, 227, 227]]   [50, 64, 227, 227]         0       
    MaxPool2D-1     [[50, 64, 227, 227]]   [50, 64, 113, 113]         0       
     Conv2D-3       [[50, 64, 113, 113]]  [50, 128, 113, 113]      73,856     
      ReLU-3        [[50, 128, 113, 113]] [50, 128, 113, 113]         0       
     Conv2D-4       [[50, 128, 113, 113]] [50, 128, 113, 113]      147,584    
      ReLU-4        [[50, 128, 113, 113]] [50, 128, 113, 113]         0       
    MaxPool2D-2     [[50, 128, 113, 113]]  [50, 128, 56, 56]          0       
     Conv2D-5        [[50, 128, 56, 56]]   [50, 256, 56, 56]       295,168    
      ReLU-5         [[50, 256, 56, 56]]   [50, 256, 56, 56]          0       
     Conv2D-6        [[50, 256, 56, 56]]   [50, 256, 56, 56]       590,080    
      ReLU-6         [[50, 256, 56, 56]]   [50, 256, 56, 56]          0       
     Conv2D-7        [[50, 256, 56, 56]]   [50, 256, 56, 56]       590,080    
      ReLU-7         [[50, 256, 56, 56]]   [50, 256, 56, 56]          0       
    MaxPool2D-3      [[50, 256, 56, 56]]   [50, 256, 28, 28]          0       
     Conv2D-8        [[50, 256, 28, 28]]   [50, 512, 28, 28]      1,180,160   
      ReLU-8         [[50, 512, 28, 28]]   [50, 512, 28, 28]          0       
     Conv2D-9        [[50, 512, 28, 28]]   [50, 512, 28, 28]      2,359,808   
      ReLU-9         [[50, 512, 28, 28]]   [50, 512, 28, 28]          0       
     Conv2D-10       [[50, 512, 28, 28]]   [50, 512, 28, 28]      2,359,808   
      ReLU-10        [[50, 512, 28, 28]]   [50, 512, 28, 28]          0       
    MaxPool2D-4      [[50, 512, 28, 28]]   [50, 512, 14, 14]          0       
     Conv2D-11       [[50, 512, 14, 14]]   [50, 512, 14, 14]      2,359,808   
      ReLU-11        [[50, 512, 14, 14]]   [50, 512, 14, 14]          0       
     Conv2D-12       [[50, 512, 14, 14]]   [50, 512, 14, 14]      2,359,808   
      ReLU-12        [[50, 512, 14, 14]]   [50, 512, 14, 14]          0       
     Conv2D-13       [[50, 512, 14, 14]]   [50, 512, 14, 14]      2,359,808   
      ReLU-13        [[50, 512, 14, 14]]   [50, 512, 14, 14]          0       
    MaxPool2D-5      [[50, 512, 14, 14]]    [50, 512, 7, 7]           0       
AdaptiveAvgPool2D-1   [[50, 512, 7, 7]]     [50, 512, 7, 7]           0       
     Linear-1           [[50, 25088]]          [50, 4096]        102,764,544  
      ReLU-14           [[50, 4096]]           [50, 4096]             0       
     Dropout-1          [[50, 4096]]           [50, 4096]             0       
     Linear-2           [[50, 4096]]           [50, 4096]        16,781,312   
      ReLU-15           [[50, 4096]]           [50, 4096]             0       
     Dropout-2          [[50, 4096]]           [50, 4096]             0       
     Linear-3           [[50, 4096]]            [50, 4]            16,388     
================================================================================
Total params: 134,276,932
Trainable params: 134,276,932
Non-trainable params: 0
--------------------------------------------------------------------------------
Input size (MB): 29.49
Forward/backward pass size (MB): 11120.24
Params size (MB): 512.23
Estimated Total Size (MB): 11661.95
--------------------------------------------------------------------------------

如果paddle还没配置的话建议去网上搜一下,这里就不给链接了 

 用于训练模型的代码

import paddle
from paddle.io import Dataset,DataLoader
import os
from PIL import Image
import numpy as np
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
from paddle.vision.datasets import DatasetFolder

transforms=T.Compose([T.Resize([227,227]),T.RandomRotation(degrees=[-10,10]),T.ColorJitter(0.4,0.4,0.4,0.4),T.ToTensor()])
dataset=DatasetFolder("aoteman",extensions=[".jpg"],transform=transforms)
#使用paddle.io.random_split切分训练集和测试集
from paddle.io import random_split
train_size=int(0.8*len(dataset))
test_size=len(dataset)-train_size
train_dataset,test_dataset=random_split(dataset=dataset,lengths=[train_size,test_size])
print(len(train_dataset),len(test_dataset))

# plt.figure(figsize=[3,3])
# for idx,data in enumerate(train_dataset):
#     plt.subplot(3,3,idx+1)
#     im=data[0];label=data[1]
#     im=im.reshape([224,224,3])
#     plt.imshow(im)
#     if idx+1>=9:
#         break
# plt.show()

print(dataset.class_to_idx)

net=paddle.vision.models.vgg16(pretrained=True, num_classes=4)
paddle.summary(net,(50,3,227,227))

#网络配置
lr=0.001
batch_size=50
#预训练模型优化器 Adam优化器
opt =paddle.optimizer.Adam(learning_rate=lr,parameters=net.classifier.parameters())
#损失函数
loss_fn=paddle.nn.CrossEntropyLoss()
#训练模式
net.train()
model=paddle.Model(net)
model.prepare(optimizer=opt,loss=loss_fn,metrics=paddle.metric.Accuracy())
import time
vsdl=paddle.callbacks.VisualDL(log_dir='vsdl/trainlog'+str(time.time()))
# model.load('mymodel/vgg_aoteman')
# res=model.predict()
model.fit(train_data=train_dataset,eval_data=test_dataset, batch_size=batch_size,
          epochs=1, verbose=1,shuffle=True,callbacks=vsdl)
model.save('mymodel/vgg_aoteman')

用于预测模型的代码

import math

import paddle
import paddle.vision.transforms as T

from PIL import Image
from paddle.vision.datasets import DatasetFolder
import numpy as np

transforms = T.Compose([T.Resize([227, 227]), T.ToTensor()])
# 使用paddle.io.random_split切分训练集和测试集

img = Image.open('aoteman/predict_demo.jpg')#输入图片
img.show()
img = transforms(img)
img = img.unsqueeze(0)

start_index = 0  # 开始切片的索引
end_index = 3    # 结束切片的索引
axes = [1]       # 要切片的轴(通道轴)
img = paddle.slice(img, axes=axes, starts=[start_index], ends=[end_index])



net = paddle.vision.models.vgg16(pretrained=True, num_classes=4)
# 网络配置
lr = 0.001
batch_size = 50
# 预训练模型优化器 Adam优化器
opt = paddle.optimizer.Adam(learning_rate=lr, parameters=net.classifier.parameters())
# 损失函数
loss_fn = paddle.nn.CrossEntropyLoss()
# 训练模式
net.train()
model = paddle.Model(net)
model.prepare(optimizer=opt, loss=loss_fn, metrics=paddle.metric.Accuracy())
import time

vsdl = paddle.callbacks.VisualDL(log_dir='vsdl/trainlog' + str(time.time()))
model.load('mymodel/vgg_aoteman')

# print(img)
res = model.predict_batch(img)

sum=0
maxx=-1000000
idx=0
for i in range(4):
    # sum+=math.exp(res[0][0][i])
    if res[0][0][i]>maxx:
        maxx=res[0][0][i]
        idx=i
    # print(res[0][0][i])
# print(res)
# print(math.exp(res[0][0][idx])/sum*100,end='%:   ')
if idx==0:
    print("迪迦")
elif idx==1:
    print('杰克')
elif idx==2:
    print('赛文')
else:
    print('泰罗')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/654492.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

mac安装的VMware虚拟机进行桥接模式配置

1、先进行网络适配器选择,选择桥接模式 2、点击网络适配器 设置... 3、选择WiFi(我使用的是WiFi,所以选择这个),注意看右边的信息:IP和子网掩码,后续配置虚拟机的ifcfg-ens文件会用到 4、编辑if…

C语言学习笔记-- 3.4.2实型变量

1.实型数据在内存中的存放形式(了解) 实型数据一般占4个字节(32位)内存空间。按指数形式存储。 2.实型变量的分类(掌握) 实型变量分为:单精度(float型)、双精度&#…

AI视频智能分析技术赋能营业厅:智慧化管理与效率新突破

一、方案背景 随着信息技术的快速发展,图像和视频分析技术已广泛应用于各行各业,特别是在营业厅场景中,该技术能够有效提升服务质量、优化客户体验,并提高安全保障水平。TSINGSEE青犀智慧营业厅视频管理方案旨在探讨视频监控和视…

基于FPGA的VGA协议实现

目录 一、 内容概要二、 了解VGA2.1 概念 三、 VGA基础显示3.1 条纹显示3.2 显示字符3.2.1 准备工作3.2.2 提取文字3.2.3 编写代码3.2.4 编译烧录 3.3 显示图像3.3.1 准备工作3.3.2 实现例程3.3.3 编译烧录 四、参考链接 一、 内容概要 深入了解VGA协议,理解不同显示…

905. 按奇偶排序数组 - 力扣

1. 题目 给你一个整数数组 nums,将 nums 中的的所有偶数元素移动到数组的前面,后跟所有奇数元素。 返回满足此条件的 任一数组 作为答案。 2. 示例 3. 分析 开辟一个数组res用来保存操作过后的元素。第一次遍历数组只插入偶数,第二次遍历数组…

【ArcGISPro】CSMPlugins文件夹

在ArcGISPro软件的CSMPlugins文件夹含有以下一个应用程序的扩展 从文件的名称可以看出美国地质调查局的太空地质学与ESRI合作进行的一个软件扩展,而USGS主要是遥感影像方向的应该,所以估计该dll的主要功能是多遥感影像进行处理,支持软件的不同…

Steam游戏搬砖:靠谱吗,详细版说下搬砖中的核心内容!

可能大家也比较关注国外Steam游戏搬砖这个项目,最近单独找我了解的也比较多,其实也正常,因为现在市面上的项目很多都很鸡肋,而且很多都是一片红海,内卷太过严重,所以对于Steam的关注度也高很多,…

探秘网页内容提取:教你定位特定标签

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言 二、定位带有ID属性的第二个标签 三、定位具有特定属性值的标签 四、提取含有特…

【OpenCV】图形绘制与填充

介绍了绘制、填充图像的API。也介绍了RNG类用来生成随机数。相关API: line() rectangle() circle() ellipse() putText() 代码: #include "iostream" #include "opencv2/opencv.hpp"using namespace std; using namespace cv…

全局配置Maven

如果开着项目,就file->close project 如果创建有问题可以转到这篇rIDEA2024创建maven项目-CSDN博客https://blog.csdn.net/weixin_45588505/article/details/139271562?spm1001.2014.3001.5502

Unity SetParent第二个参数worldPositionStays的意义

初学Unity的小知识: 改变对象的父级有三种调用方式,如下: transMe.SetParent(transParent,true); transMe.SetParent(transParent,false); transMe.parent transParent;具体有什么区别呢,这里写一个测试例子来详细说明&#xff…

React18 apexcharts数据可视化之甜甜圈图

03 甜甜圈图 apexcharts数据可视化之甜甜圈图。 有完整配套的Python后端代码。 本教程主要会介绍如下图形绘制方式: 基本甜甜圈图个性图案的甜甜圈图渐变色的甜甜圈图 面包圈 import ApexChart from react-apexcharts;export function DonutUpdate() {// 数据…

在matlab里面计算一组给定参数的方程的解

如: k (1:1024); f (x)(1-x-k.*x.^2); 在这段代码给出了一组函数,若需要计算f0,可以通过自带的函数实现: x0 zeros(length(k),1); options optimoptions(fsolve,Display,none,TolX,tol,TolFun,tol); tic for ik 1:length…

基于OrangePi AIpro开发一个电子纸屏时钟

OrangePi AIpro 简介 OrangePi AIpro(8T)采用昇腾AI技术路线,具体为4核64位处理器AI处理器,集成图形处理器,支持8TOPS AI算力,拥有8GB/16GB LPDDR4X,可以外接32GB/64GB/128GB/256GB eMMC模块,支持双4K高清…

Web3革命:探索科技与物联网的无限可能

引言 Web3时代正在悄然而至,带来了对互联网的彻底颠覆和改变。作为互联网的下一代,Web3不仅是技术革新的延续,更是对传统互联网模式的重新构想。在这个新时代,科技与物联网的结合将迎来无限的可能性,将探索到一片全新…

如何在Python 中如何导入和引用外部文件(Colab VS Code)

1. 上传文件 在 Google Colab 中,从左侧界面的文件选项中使用 "Upload" 按钮上传文件。 在 VS Code 中,通过菜单栏中的 "File" -> "Open File/Folder" 选项上传文件(建议将所有文件放入一个文件夹中&#…

【paper】基于分布式采样的多机器人编队导航信念传播模型预测控制

Distributed Sampling-Based Model Predictive Control via Belief Propagation for Multi-Robot Formation NavigationRAL 2024.4Chao Jiang 美国 University of Wyoming 预备知识 马尔可夫随机场(Markov Random Field, MRF) 马尔可夫随机场&#xff…

如何解决SEO排名上升后遭遇的攻击问题

随着搜索引擎优化(SEO)策略的成功实施,网站排名的提升往往会引来更多的流量与关注,但同时也可能成为恶意攻击的目标,包括DDoS攻击、SQL注入、XSS攻击等。这些攻击不仅影响用户体验,还可能导致网站降权甚至被…

目标检测数据集 - 铁路工人安全检测数据集下载「包含VOC、COCO、YOLO三种格式」

数据集介绍:铁路工人安全检测数据集,真实铁路监控场景高质量图片数据,涉及场景丰富,比如铁路工地工人作业数据、铁路巡检工人作业数据、铁路搬运工人作业数据、铁路场景货车上工人作业数据、铁路旁堆料区工人作业数据等。数据标签…

【图书推荐】《机器学习实战(视频教学版)》

本书用处 快速入门Python机器学习基础算法。 最后3个综合实战项目(包括新闻内容分类实战、泰坦尼克号获救预测实战、中药数据分析项目实战)可以作为研究可以的素材。 内容简介 本书基于Python语言详细讲解机器学习算法及其应用,用于读者快…