Tensorflow 2.0 cnn训练cifar10 准确率只有0.1 [已解决]

cifar10 准确率只有0.1

  • 问题描述
  • 踩坑
  • 解决办法

问题描述

如果你看的是北京大学曹健老师的tensorflow2.0,你在class5的部分可能会遇见这个问题

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout,MaxPooling2D,Flatten,Conv2D,BatchNormalization,Activation
from tensorflow.keras import Model
import os
import numpy as np

# np.set_printoptions(threshold=np.inf)


class Baseline(Model):
    def __init__(self):
        super(Baseline, self).__init__()
        self.conv1 = Conv2D(6, (5,5), activation='sigmoid')
        self.pool1 = MaxPooling2D(pool_size=(2,2),strides=2)
        self.conv2 = Conv2D(16, (5,5), activation='sigmoid')
        self.pool2 = MaxPooling2D(pool_size=(2,2),strides=2)

        self.flatten1 = Flatten()
        self.f1=Dense(120,activation='sigmoid')
        self.f2=Dense(84,activation='sigmoid')
        self.f3=Dense(10,activation='softmax')

    def call(self,x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)

        x = self.flatten1(x)
        x = self.f1(x)
        x = self.f2(x)
        y = self.f3(x)
        return y


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train,x_test = x_train/255.0,x_test/255.0


model = Baseline()
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
              ,metrics=['sparse_categorical_accuracy'])

checkpoint_save_path="lenet.ckpt"
if os.path.exists(checkpoint_save_path+'.index'):
    model.load_weights(checkpoint_save_path)
    print("---------------------Loaded model---------------")

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True
                                              ,save_best_only=True, verbose=1)


history=model.fit(x_train,y_train,batch_size=32, epochs=5, validation_data=(x_test, y_test)
          ,validation_freq=1,callbacks=[cp_callback])
model.summary()

file=open('weights_lenet.txt','w')
for v in model.trainable_variables:
    file.write(str(v.name)+'\n')
    file.write(str(v.shape)+'\n')
    file.write(str(v.numpy())+'\n')
file.close()

train_acc=history.history['sparse_categorical_accuracy']
val_acc=history.history['val_sparse_categorical_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(loss,label='train_loss')
plt.plot(val_loss,label='val_loss')
plt.title('model loss')
plt.legend()

plt.subplot(1,2,2)
plt.plot(train_acc,label='train_acc')
plt.plot(val_acc,label='val_acc')
plt.title('model acc')
plt.legend()
plt.show()


代码写的看起来没有问题,但是就是acc一直在0.1,总共10个类,也就是说网络根本没有训练效果,就是瞎蒙的。为什么会这样呢。想知道答案的直接跳到最后。下面是我踩的坑,

踩坑

我尝试升级tensorflow版本,但是我们知道升级tensorflow,对应的cudatoolkit 和cudnn 也要升级,官网版本对应

在这里插入图片描述
conda install cudatoolkit==11.2.0

但是我去安装的时候显示PackagesNotFoundError: The following packages are not available from current channels:
搜不到这个版本,conda search cudatoolkit查看可以安装的版本
在这里插入图片描述就是没有11.2,这就很烦人,
我电脑环境是

windows11
cuda 12.3
cudnn 8.9.7

我不能把电脑cuda卸载重新装,因为我pytorch要求的是上面的环境。我尝试去官网再安装一个cuda但是失败了(想试一下windows电脑能不能安装两个cuda)。总之折腾了一下午

解决办法

方法一

cudatoolkit 和cudnn保持不变,直接升级tensorflow
pip install tensorflow==2.4
但是这样就不能用gpu训练了,跑代码的时候用的是cpu,具体原因我也不是很清楚,

方法二

看我之前的文章,卸载电脑上的cuda安装,安装cuda11.2和对应的cudnn8.1
cuda下载地址
cudnn下载地址
然后安装tensoflow 2.10版本
conda install tensorflow_gpu==2.10.0


你windows电脑如果想同时可以跑tensorflow和pytorch,建议电脑的cuda环境就按照tensorflow的安装。
因为pytorch安装比较简单,一般会自带对应的cuda,而tensorflow对cuda要求比较严格,用指令(conda install cudatoolkit==11.2.0 )一般找不到对应的版本,只能去官网下载

windows要是想跑代码就用pytorch吧,tensorflow对windows真的很不友好,tensorflow2.10以上直接不支持了,可以用实验室的服务器跑tensorflow代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/885610.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Verilog学习日常】—牛客网刷题—Verilog企业真题—VL69

脉冲同步器(快到慢) 描述 sig_a 是 clka(300M)时钟域的一个单时钟脉冲信号(高电平持续一个时钟clka周期),请设计脉冲同步电路,将sig_a信号同步到时钟域 clkb(100M&…

长文本溢出,中间位置显示省略号

1.说明 Flutter支持在文本末尾显示溢出省略号。现在想要实现在文本中间位置显示省略号,这里使用的方法是通过TextPainter计算文本宽度。(我目前没有找到更好的方法,欢迎大家指教。) 2.效果 源码 1.MiddleEllipsisTextPainter …

全球IP归属地查询-IP地址查询-IP城市查询-IP地址归属地-IP地址解析-IP位置查询-IP地址查询API接口

IP地址城市版查询接口 API是指能够根据IP地址查询其所在城市等地理位置信息的API接口。这类接口在网络安全、数据分析、广告投放等多个领域有广泛应用。以下是一些可用的IP地址城市版查询接口API及其简要介绍 1. 快证 IP归属地查询API 特点:支持IPv4 提供高精版、…

TypeScript 算法手册 【数组基础知识】

文章目录 1. 数组简介1.1 数组定义1.2 数组特点 2. 数组的基本操作2.1 访问元素2.2 添加元素2.3 删除元素2.4 修改元素2.5 查找元素 3. 数组的常见方法3.1 数组的创建3.2 数组的遍历3.3 数组的映射3.4 数组的过滤3.5 数组的归约3.6 数组的查找3.7 数组的排序3.8 数组的反转3.9 …

深度学习常见术语介绍

文章目录 数据集(Dataset)特征(Feature)标签(Label)训练集(Training Set)测试集(Test Set)验证集(Validation Set)模型(Mo…

什么是文件完整性监控(FIM)

组织经常使用基于文件的系统来组织、存储和管理信息。文件完整性监控(FIM)是一种用于监控和验证文件和系统完整性的技术,识别用户并提醒用户对文件、文件夹和配置进行未经授权或意外的变更是 FIM 的主要目标,有助于保护关键数据和…

【NVIDIA】如何使用nvidia-smi命令管理和监控GPU

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…

Golang | Leetcode Golang题解之第436题寻找右区间

题目: 题解: func findRightInterval(intervals [][]int) []int {n : len(intervals)type pair struct{ x, i int }starts : make([]pair, n)ends : make([]pair, n)for i, p : range intervals {starts[i] pair{p[0], i}ends[i] pair{p[1], i}}sort.…

面向人工智能: 对红酒数据集进行分析 (实验四)

由于直接提供截图是不切实际的,我将详细解释如何使用scikit-learn(通常称为sk-learn)自带的红酒数据集进行葡萄酒数据的分析与处理。这包括实验要求的分析、数据的初步分析(完整性和重复性)以及特征之间的关联关系分析…

MATLAB绘图基础9:多变量图形绘制

参考书:《 M A T L A B {\rm MATLAB} MATLAB与学术图表绘制》(关东升)。 9.多变量图形绘制 9.1 气泡图 气泡图用于展示三个或更多变量变量之间的关系,气泡图的组成要素: 横轴( X {\rm X} X轴):表示数据集中的一个变量&#xff0c…

双端搭建个人博客

1. 准备工作 确保你的两个虚拟机都安装了以下软件: 虚拟机1(Web服务器): Apache2, PHP虚拟机2(数据库服务器): MariaDB2. 安装步骤 虚拟机1(Web服务器) 安装Apache2和PHP 更新系统包列表: sudo apt update安装Apache2: sudo apt install apache2 -y安装PHP及其Apac…

只写CURD后台管理的Java后端要如何提升自己

你是否工作3~5年后,发现日常只做了CURD的简单代码。 你是否每次面试就会头疼,自己写的代码,除了日常CURD简历上毫无亮点可写 抱怨过苦恼过也后悔过,但是站在现在的时间点回想以前,发现有很多事情我们是可以做的更好的。…

Spring之生成Bean

Bean的生命周期:实例化->属性填充->初始化->销毁 核心入口方法:finishBeanFactoryInitialization-->preInstantiateSingletons DefaultListableBeanFactory#preInstantiateSingletons用于实例化非懒加载的bean。 1.preInstantiateSinglet…

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-09-26

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-09-26 1. LLMs Still Can’t Plan; Can LRMs? A Preliminary Evaluation of OpenAI’s o1 on PlanBench Authors: Karthik Valmeekam, Kaya Stechly, Subbarao Kambhampati LLMs仍然无法规划;LRMs可以…

Mybatis的基本使用

什么是Mybatis? Mybatis是一个简化JDBC的持久层框架,MyBatis是一个半自动化框架,是因为它在SQL执行过程中只提供了基本的SQL执行功能,而没有像Hibernate那样将所有的ORM操作都自动化了。在MyBatis中,需要手动编写SQL语…

【Android】布局优化—include,merge,ViewStub的使用方法

引言 1.重要性 在Android应用开发中,布局是用户界面的基础。一个高效的布局不仅能提升用户体验,还能显著改善应用的性能。随着应用功能的复杂性增加,布局的优化变得尤为重要。优化布局能够减少渲染时间,提高响应速度&#xff0c…

Docker安装consul + go使用consul + consul知识

1. 什么是服务注册和发现 假如这个产品已经在线上运行,有一天运营想搞一场促销活动,那么我们相对应的【用户服务】可能就要新开启三个微服务实例来支撑这场促销活动。而与此同时,作为苦逼程序员的你就只有手动去 API gateway 中添加新增的这…

探索分布式IO模块的介质冗余:赋能工业自动化的稳健之心

在日新月异的工业自动化领域,每一个细微环节的稳定性都直接关系到生产线的效率与安全。随着智能制造的深入发展,分布式IO(Input/Output)模块作为连接现场设备与控制系统的关键桥梁,其重要性日益凸显。我们自主研发的带…

五子棋双人对战项目(3)——匹配模块

一、分析需求 二、约定前后端接口 三、实现游戏大厅页面(前端代码) 四、实现后端代码 五、线程安全问题 六、忙等问题 一、分析需求 需求:多个玩家,在游戏大厅进行匹配,系统会把实力相近的玩家匹配到一起。 要想实…

Redis 简单的消息队列

使用redis 进行简单的队列很容易,不需要使用较为复杂的MQ队列,直接使用redis 进行,不过唯一不足的需要自己构造生产者消费者,这里使用while True的方法进行消费者操作 目录 介绍数据类型StringHash 重要命令消息队列 介绍 key-v…