Tensorflow2.0笔记 - 自定义Layer和Model实现CIFAR10数据集的训练

       本笔记记录使用自定义Layer和Model来做CIFAR10数据集的训练。

        CIFAR10数据集下载:

        https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

        自定义的Layer和Model实现较为简单,参数量较少,并且没有卷积层和dropout等,最终准确率不高,仅做练习使用。

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__

def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255
    y = tf.cast(y, dtype=tf.int32)
    return x,y

batchsize = 128
#CIFAR10数据集下载,可以直接使用网络下载
(x,y), (x_val, y_val) = datasets.cifar10.load_data()
#CIFAR10的标签(训练集)数据维度是[50000, 1],通过squeeze消除掉里面1的维度,变成[50000]
print("y.shape:", y.shape)
y = tf.squeeze(y)
print("squeezed y.shape:", y.shape)
y_val = tf.squeeze(y_val)
#进行onehot编码
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print("Datasets: ", x.shape, " ", y.shape, " x.min():", x.min(), " x.max():", x.max())

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsize)

sample = next(iter(train_db))
print("Batch:", sample[0].shape, sample[1].shape)

#自定义Layer
class MyDense(layers.Layer):
    def __init__(self, input_dim, output_dim):
        super(MyDense, self).__init__()
        self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))
        self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))
        #self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])
        #self.bias = self.add_weight(name='b', shape=[output_dim])
        
    def call(self, inputs, training = None):
        x = inputs@self.kernel + self.bias
        return x

class MyNetwork(keras.Model):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = MyDense(32 * 32 * 3, 512)
        self.fc2 = MyDense(512, 512)
        self.fc3 = MyDense(512, 256)
        self.fc4 = MyDense(256, 256)
        self.fc5 = MyDense(256, 10)

    def call(self, inputs, training = None):
        x = tf.reshape(inputs, [-1, 32 * 32 * 3])
        x = self.fc1(x)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        x = tf.nn.relu(x)
        #返回logits
        return x

total_epoches = 35
learn_rate = 0.001
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),
             loss = tf.losses.CategoricalCrossentropy(from_logits=True),
             metrics=['Accuracy'])
network.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=1)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/515182.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java—抽象方法与接口

声明:以下内容是根据B站黑马程序员的Java课程+博主自己的理解整理而成,课程很好,适合初学者学习。 关于此类题目,重要的是识别出用什么来实现,到底是接口还是抽象方法,还是共有的属性等等&…

医用三维影像PACS系统源码 一套成熟的PACS系统应具备哪些核心要素?

医用三维影像PACS系统源码 一套成熟的PACS系统应具备哪些核心要素? PACS及影像存取与传输系统”( Picture Archiving and Communication System),为以实现医学影像数字化存储、诊断为核心任务,从医学影像设备(如CT、CR、DR、MR、…

ZZS-7/1G212分合闸电源综合控制装置 220VAC 板前接线 JOSEF约瑟

系列型号: ZZS-7G/1分闸、合闸、电源监视综合控制装置; ZZS-7G/11分闸、合闸、电源监视综合控制装置; ZZS-7G/23分闸、合闸、电源监视综合控制装置; ZZS-7G/24分闸、合闸、电源监视综合控制装置; ZZS-7/1G11分闸、合闸…

21.兼容性测试

考试频率低; 一般考兼容性测试会结合web测试;(兼容性矩阵) 主要议题: 1.兼容性测试概述 2.硬件兼容性测试 最低配置不讲究工作负载,意思是软件能够运行的最低要求环境; 推荐配置&#xff0c…

修复503 Service Unavailable Error问题

近期我们网网站经常出现503 Service Unavailable Error,在此之前我们的网站从未出现过这种问题,我们向虚拟主机提供商Hostease咨询后,了解到503 Service Unavailable错误是指服务器暂时无法处理请求,通常是由于服务器过载、维护、…

Python数据结构与算法——数据结构(链表、哈希表、树)

目录 链表 链表介绍 创建和遍历链表 链表节点插入和删除 双链表 链表总结——复杂度分析 哈希表(散列表) 哈希表介绍 哈希冲突 哈希表实现 哈希表应用 树 树 树的示例——模拟文件系统 二叉树 二叉树的链式存储 二叉树的遍历 二叉搜索树 插入 查询 删除 AVL树 …

路由Vue-Router使用

Vue Router 是 Vue.js 的官方路由。它与 Vue.js 核心深度集成,让用 Vue.js 构建单页应用变得轻而易举。 介绍 | Vue Router (vuejs.org) 1. 安装 npm install vue-router4 查看安装好的vue-router 2. 添加路由 新建views文件夹用来存放所有的页面,在…

初入职,如何用好 git 快速上手项目开发

前言 介绍在工作中使用 git 工具 文章目录 前言一、git 简介1、是什么作用操作3、用途 二、基本概念1、工作区2、暂存区3、版本库4、操作过程 三、基本命令操作 一、git 简介 1、是什么 git 是一个方便管理代码版本的工具,用一个树结构来维护和管理所有的历史版本…

数据结构记录

之前记录的数据结构笔记,不过图片显示不了了 数据结构与算法(C版) 1、绪论 1.1、数据结构的研究内容 一般应用步骤:分析问题,提取操作对象,分析操作对象之间的关系,建立数学模型。 1.2、基本概念和术语 数据&…

Finite Element Procedures K.J.Bathe 【教材pdf+部分源码】|有限元经典教材 | 有限元编程

专栏导读 作者简介:工学博士,高级工程师,专注于工业软件算法研究本文已收录于专栏:《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现,并提供所有案例完整源码;2.单元…

flask的使用学习笔记1

跟着b站学的1-06 用户编辑示例_哔哩哔哩_bilibili flask是一个轻量级,短小精悍,扩展性强,可以扩展很多组件,django大而全 编程语言它们的区别: (这些语言都很了解,java和python是高级语言,都…

动手做一个最小Agent——TinyAgent!

Datawhale干货 作者:宋志学,Datawhale成员 前 言 大家好,我是不要葱姜蒜。在ChatGPT横空出世,夺走Bert的桂冠之后,大模型愈发地火热,国内各种模型层出不穷,史称“百模大战”。大模型的能力是毋…

UE4几个常用节点链接

UE4几个常用节点链接 2017-12-02 12:54 1. 流光材质(及uv平铺次数) 2. 跑九宫格 3.闪光3。1 粒子闪烁效果 4.图案重复5.平移扭曲 6.溶解 刀光的uv滚动图片源或采样节点属性里改成clamp无后期发光光晕anistropic 各向异性高光法线图 法线图叠加 blendangle orrectedNo…

探索设计模式的魅力:揭秘B/S模式在AI大模型时代的蜕变与进化

​🌈 个人主页:danci_ 🔥 系列专栏:《设计模式》 💪🏻 制定明确可量化的目标,坚持默默的做事。 揭秘B/S模式在AI大模型时代的蜕变与进化 🚀在AI的波澜壮阔中,B/S模式&…

为 AI 而生的编程语言「GitHub 热点速览」

Mojo 是一种面向 AI 开发者的新型编程语言。它致力于将 Python 的简洁语法和 C 语言的高性能相结合,以填补研究和生产应用之间的差距。Mojo 自去年 5 月发布后,终于又有动作了。最近,Mojo 的标准库核心模块已在 GitHub 上开源,采用…

面试题:JVM 调优

一、JVM 参数设置 1. tomcat 的设置 vm 参数 修改 TOMCAT_HOME/bin/catalina.sh 文件,如下图 JAVA_OPTS"-Xms512m -Xmx1024m" 2. springboot 项目 jar 文件启动 通常在linux系统下直接加参数启动springboot项目 nohup java -Xms512m -Xmx1024m -jar…

前端html+css+js常用总结快速入门

🔥博客主页: A_SHOWY🎥系列专栏:力扣刷题总结录 数据结构 云计算 数字图像处理 力扣每日一题_ 学习前端全套所有技术性价比低下且容易忘记,先入门学会所有基础的语法(cssjsheml)&#xff…

Valkey是一个新兴的开源项目,旨在成为Redis的替代品,背后得到了AWS、Google、Oracle支持

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Windows 禁用 Defender

原文:https://blog.iyatt.com/?p8078 2024.4.4 Windows 11 专业版 23H2 Beta 预览版 进入安全中心,关闭所有,特别是篡改防护选项 打开注册表 地址栏粘粘路径 计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Policies\Microsoft\Windows Defende…

Rust线程间通信通讯channel的理解和使用

Channel允许在Rust中创建一个消息传递渠道,它返回一个元组结构体,其中包含发送和接收端。发送端用于向通道发送数据,而接收端则用于从通道接收数据。不能使用可变变量的方式,线程外面修改了可变变量的值,线程里面是拿不…