【tensorflow框架神经网络实现鸢尾花分类】

文章目录

  • 1、数据获取
  • 2、数据集构建
  • 3、模型的训练验证
  • 可视化训练过程

1、数据获取

  • 从sklearn中获取鸢尾花数据,并合并处理
from sklearn.datasets import load_iris
import pandas as pd

x_data = load_iris().data
y_data = load_iris().target

x_data = pd.DataFrame(x_data, columns=['花萼长度','花萼宽度','花瓣长度','花瓣宽度'])
pd.set_option('display.unicode.east_asian_width', True)

x_data['类别'] = y_data
x_data

在这里插入图片描述

2、数据集构建

  • 数据集构建包括:
    • 数据读取
    • 数据打乱
    • 数据划分
    • 小批量迭代器生成
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris

# 1、从sklearn包中datasets读取数据集
x_data = load_iris().data
y_data = load_iris().target

# 2、数据打乱
np.random.seed(1)   # 使用相同的seed,使输入特征/标签一一对应
np.random.shuffle(x_data)
np.random.seed(1)
np.random.shuffle(y_data)
tf.random.set_seed(1)

# 3、训练集、测试集划分
x_train, x_test = x_data[:-30], x_data[-30:]
y_train, y_test = y_data[:-30], y_data[-30:]

# 4、小批量数据
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
train_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

3、模型的训练验证

# 定义超参数,预设变量
lr = 0.1
loss_all = 0
Epoch = 500
train_loss_list = []
test_acc = []

# 定义神经网络的可训练参数
w = tf.Variable(tf.random.truncated_normal([4,3], stddev=0.1, seed=1))
b = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))

# 循环迭代,训练参数
for epoch in range(Epoch):
    for step, (x_, y_) in enumerate(train_db):
        with tf.GradientTape() as tape:
            x_ = tf.cast(x_, tf.float32)
            y_pre = tf.matmul(x_, w) + b
            y_pre = tf.nn.softmax(y_pre)

            y_lab = tf.one_hot(y_, depth=3)
            loss = tf.reduce_mean(tf.square(y_lab - y_pre))
            loss_all += loss.numpy()

        grads = tape.gradient(loss, [w,b])
        w.assign_sub(lr * grads[0])
        b.assign_sub(lr * grads[1])
    print(f'Epoch: {epoch}, loss: {loss_all/4}')
    train_loss_list.append(loss_all/4)
    loss_all = 0

    # 测试部分
    total_correct, total_number = 0, 0
    for x_,y_ in test_db:
        x_ = tf.cast(x_, tf.float32)
        y_pre = tf.matmul(x_, w) + b
        y_pre = tf.nn.softmax(y_pre)
        y_p = tf.argmax(y_pre, axis=1)
        y_p = tf.cast(y_p, dtype=y_.dtype)

        correct = tf.cast(tf.equal(y_p, y_), dtype=tf.int32)
        correct = tf.reduce_sum(correct)
        total_correct += int(correct) 
        total_number += x_.shape[0]
    acc = total_correct / total_number
    test_acc.append(acc)
    print("Test_acc:", acc)
    print("-"*30)

在这里插入图片描述

可视化训练过程

# 绘制测试Acc曲线和训练loss曲线
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(train_loss_list,'b-')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')

ax1 = ax.twinx()
ax1.plot(test_acc,'r-')
ax1.set_ylabel('Acc')

ax1.spines['left'].set_color('blue')
ax1.spines['right'].set_color('red')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/497497.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

kubernetes K8s的监控系统Prometheus升级Grafana,来一个酷炫的Node监控界面(二)

上一篇文章《kubernetes K8s的监控系统Prometheus安装使用(一)》中使用的监控界面总感觉监控的节点数据太少,不能快算精准的判断出数据节点运行的状况。 今天我找一款非常酷炫的多维度数据监控界面,能够非常有把握的了解到各节点的数据,以及运…

快速上手Spring Cloud 七:事件驱动架构与Spring Cloud

快速上手Spring Cloud 一:Spring Cloud 简介 快速上手Spring Cloud 二:核心组件解析 快速上手Spring Cloud 三:API网关深入探索与实战应用 快速上手Spring Cloud 四:微服务治理与安全 快速上手Spring Cloud 五:Spring …

AOP切入点表达式基本格式

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl 官方地址 https://docs.spring.io/spring-framework/reference/core/aop/ataspectj/pointcuts.html AOP切入点表达式基本格式如下: execution(modifiers-patte…

竞赛 python+opencv+深度学习实现二维码识别

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 pythonopencv深度学习实现二维码识别 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分工作量:3分创新点:3分 该项目较为新颖&…

ES6 学习(三)-- es特性

文章目录 1. Symbol1.1 使用Symbol 作为对象属性名1.2 使用Symbol 作为常量 2. Iterator 迭代器2.1 for...of循环2.2 原生默认具备Interator 接口的对象2.3 给对象添加Iterator 迭代器2.4 ... 解构赋值 3. Set 结构3.1 初识 Set3.2 Set 实例属性和方法3.3 遍历3.4 相关面试题 4…

RabbitMQ3.x之四_RabbitMQ角色说明及创建用户与授权

RabbitMQ3.x之四_角色说明及创建用户与授权 文章目录 RabbitMQ3.x之四_角色说明及创建用户与授权1. 访问和授权1. Tags说明2. 命令行示例 2. 管理界面新建用户及访问授权1. 管理界面新建用户2. 管理界面中的授权说明3. guest用户不能远程登录提示 3. 创建用户1. 基本命令2. 实际…

新网站收录时间是多久,新建网站多久被百度收录

对于新建的网站而言,被搜索引擎收录是非常重要的一步,它标志着网站的正式上线和对外开放。然而,新网站被搜索引擎收录需要一定的时间,而且时间长短受多种因素影响。本文将探讨新网站收录需要多长时间,以及新建网站多久…

微信小程序更换头像的功能

微信小程序开发&#xff0c;个人中心中更换头像的更能使用频率很高&#xff0c;这里记录下实现方式&#xff1a; <view class"setting-list avatar-container"><text>头像</text><view class"avatar"><button hover-class"…

华为云使用指南02

5.​​使用GitLab进行团队及项目管理​​ GitLab旨在帮助团队进行项目开发协作&#xff0c;为软件开发和运营生命周期提供了一个完整的DevOps方案。GitLab功能包括&#xff1a;项目源码的管理、计划、创建、验证、集成、发布、配置、监视和保护应用程序等。该镜像基于CentOS操…

ZK友好代数哈希函数安全倡议

1. 引言 前序博客&#xff1a; ZKP中的哈希函数如何选择ZK-friendly 哈希函数&#xff1f;snark/stark-friendly hash函数Anemoi Permutation和Jive Compression模式&#xff1a;高效的ZK友好的哈希函数Tip5&#xff1a;针对Recursive STARK的哈希函数 随着Incrementally Ve…

STM32 字符数组结束符 “\0”

STM32 字符数组结束符 “\0” 使用字符数组使用printf&#xff0c;string参考 使用字符数组 使用STM32的串口发送数据&#xff0c;核心代码如下&#xff1a; char str[] "hello world!\n\r";while(1) {HAL_UART_Transmit(&huart2, str, sizeof (str), 10);HAL…

构建一个基础的大型语言模型(LLM)应用程序

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

Radio Silence for mac 好用的防火墙软件

Radio Silence for Mac是一款功能强大的网络防火墙软件&#xff0c;专为Mac用户设计&#xff0c;旨在保护用户的隐私和网络安全。它具备实时网络监视和控制功能&#xff0c;可以精确显示每个网络连接的状态&#xff0c;让用户轻松掌握网络活动情况。 软件下载&#xff1a;Radio…

图扑数字孪生智慧城市,综合治理一屏统览

现代城市作为一个复杂系统&#xff0c;牵一发而动全身&#xff0c;城市化进程中产生新的矛盾和社会问题都会影响整个城市系统的正常运转。智慧城市是应对这些问题的策略之一。领导曾在中央城市工作会议上指出&#xff0c;城市工作要树立系统思维&#xff0c;从构成城市诸多要素…

探索数据库mysql--------------mysql主从复制和读写分离

目录 前言 为什么要主从复制&#xff1f; 主从复制谁复制谁&#xff1f; 数据放在什么地方&#xff1f; 一、mysql支持的复制类型 1.1STATEMENT&#xff1a;基于语句的复制 1.2ROW&#xff1a;基于行的复制 1.3MIXED&#xff1a;混合类型的复制 二、主从复制的工作过程 三个重…

HarmonyOS 应用开发之UIAbility组件与UI的数据同步

基于当前的应用模型&#xff0c;可以通过以下几种方式来实现UIAbility组件与UI之间的数据同步。 使用EventHub进行数据通信&#xff1a;在基类Context中提供了EventHub对象&#xff0c;可以通过发布订阅方式来实现事件的传递。在事件传递前&#xff0c;订阅者需要先进行订阅&a…

I/O模型的一些理解

I/O模型的一些理解 一些基本的概念同步阻塞、同步非阻塞、异步阻塞、异步非阻塞总结概念 I/O模型一些例子 从源头解释从TCP发送数据的流程说起阻塞I/O | 非阻塞I/OI/O多路复用信号驱动I/O异步I/O再谈IO模型里面的同步异步 参考连接 参考链接 参考链接 一些基本的概念 阻塞(b…

Postman传对象失败解决

文章目录 情景复现解决方案总结 情景复现 postman中调用 debug发现pId传入失败 分析解释&#xff1a; 实体类中存在pId、uid和num字段 controller层将GoodsCar作为请求体传入 解决方案 当时觉得很奇怪&#xff0c;因为uid和num可以被接收&#xff0c;而pId和num的数据类型相…

pyecharts操作三

pyecharts操作三 pyecharts 是一个用于生成Echarts图表的Python库。Echarts是百度开源的一个数据可视化JS库&#xff0c;可以生成一些非常酷炫的图表。 环境安装 pip install pyecharts 检查版本 import pyecharts print(pyecharts.version) 2.0.3 GL关系图 import rando…

uniApp使用XR-Frame创建3D场景(3)光源投影的运用。

上一篇讲解了如何在uniApp中创建xr-frame子组件并创建简单的3D场景。 这篇我们讲解光源在场景中的运用以及相关属性。 在子组件 xr-start的index.wxml文件中我们加入如下代码 <xr-scene render-system"alpha:true" bind:ready"handleReady"><xr…