深度学习神经网络实战:多层感知机,手写数字识别

目的

利用tensorflow.js训练模型,搭建神经网络模型,完成手写数字识别

设计

简单三层神经网络

  • 输入层
    28*28个神经原,代表每一张手写数字图片的灰度
  • 隐藏层
    100个神经原
  • 输出层
    -10个神经原,分别代表10个数字

代码

// 导入 TensorFlow.js 库
import tf from "@tensorflow/tfjs";
import * as tfjsnode from "@tensorflow/tfjs-node";
import * as tfvis from "@tensorflow/tfjs-vis";
import fs from "fs";
import plot from "nodeplotlib";
// 定义模型
const model = tf.sequential();

// 添加输入层
model.add(
  tf.layers.dense({ units: 64, inputShape: [784], activation: "relu" })
);

// 添加隐藏层
model.add(tf.layers.dense({ units: 100, activation: "relu" }));

// 添加输出层
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));

// 编译模型
model.compile({
  optimizer: "sgd",
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});
const trainDataLen = 3000;
const testDataLen = 2000;

// 加载 MNIST 数据集
import pkg from "mnist";
const { set: Dataset } = pkg;
const set = Dataset(trainDataLen, testDataLen);
const trainingSet = set.training;
const testSet = set.test;

const trainXs = [];
const testXs = [];

const trainLabels = [];
const testLabels = [];

for (let i = 0; i < trainingSet.length; i++) {
  trainXs.push(trainingSet[i].input);
  trainLabels.push(trainingSet[i].output.indexOf(1));
}

for (let i = 0; i < testSet.length; i++) {
  testXs.push(testSet[i].input);
  testLabels.push(testSet[i].output.indexOf(1));
}

// 准备数据
const trainXsTensor = tf.tensor(trainXs, [trainDataLen, 784]);
const trainYsOneHot = tf.oneHot(trainLabels, 10);

//记录每轮模型训练中的损失和精度,为了绘制曲线图
var accPlot = [];
var lossPlot = [];

// 模型训练
model
  .fit(trainXsTensor, trainYsOneHot, {
    batchSize: 64,
    epochs: 100,
    validationSplit: 0.2,
    callbacks: {
      onEpochBegin: (epoch) => console.log(`Epoch ${epoch + 1} started...`),
      onEpochEnd: async (epoch, logs) => {
        console.log(
          `Epoch ${epoch + 1} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        //记录loss和acc,绘制曲线图
        accPlot.push(logs.acc.toFixed(3));
        lossPlot.push(logs.loss.toFixed(3));

        await tf.nextFrame(); // 防止阻塞
      },
      onBatchEnd: async (batch, logs) => {
        console.log(
          `Batch ${batch} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        await tf.nextFrame(); // 防止阻塞
      },
    },
  })
  .then((history) => {
    console.log("Training completed!", history);
    //绘制模型训练过程中的损失函数和模型精度曲线变化
    const epochs = Array.from({ length: lossPlot.length }, (_, i) => i + 1);
    plot.plot(
      [
        { x: epochs, y: lossPlot, name: "Loss" },
        { x: epochs, y: accPlot, name: "Accuracy" },
      ],
      {
        filename: "loss_acc.png",
      }
    );

    //模型评估
    const testXsTensor = tf.tensor(testXs, [testDataLen, 784]);
    const testYsOneHot = tf.oneHot(testLabels, 10);

    const result = model.evaluate(testXsTensor, testYsOneHot);
    const testLoss = result[0].dataSync()[0];
    const testAccuracy = result[1].dataSync()[0];

    console.log(`Test loss: ${testLoss.toFixed(3)}`);
    console.log(`Test accuracy: ${testAccuracy.toFixed(3)}`);
    //保存模型
    model.save("file://./my-model").then(() => {
      console.log("Model saved!");
    });
  });

package.json

{
  "name": "neural_network",
  "version": "1.0.0",
  "description": "",
  "type": "module",
  "main": "mlpTest.js",
  "scripts": {
    "test": "echo \"Error: no test specified\" && exit 1",
  },
  "author": "",
  "license": "ISC",
  "dependencies": {
    "@tensorflow/tfjs": "^4.17.0",
    "@tensorflow/tfjs-node": "^4.17.0",
    "@tensorflow/tfjs-vis": "^1.0.0",
    "mnist": "^1.1.0",
    "nodeplotlib": "^0.7.7"
  },
  "devDependencies": {
    "@babel/core": "^7.0.0",
    "@babel/preset-env": "^7.0.0",
    "babel-loader": "^8.0.0",
    "webpack": "^5.0.0",
    "webpack-cli": "^4.0.0"
  }
}

模型结果

模型精度

损失函数与模型精度变化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/402431.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于FPGA的I2C接口控制器(包含单字节和多字节读写)

1、概括 前文对IIC的时序做了详细的讲解&#xff0c;还有不懂的可以获取TI的IIC数据手册查看原理。通过手册需要知道的是IIC读、写数据都是以字节为单位&#xff0c;每次操作后接收方都需要进行应答。主机向从机写入数据后&#xff0c;从机接收数据&#xff0c;需要把总线拉低来…

CSP-J 2023 T3 一元二次方程

文章目录 题目题目背景题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示 题目传送门题解思路总代码 提交结果尾声 题目 题目背景 众所周知&#xff0c;对一元二次方程 a x 2 b x c 0 , ( a ≠ 0 ) ax ^ 2 bx c 0, (a \neq 0) ax2bxc0,(a0)&#xff0c;可…

收单外包机构备案2023年回顾和2024年展望

孟凡富 本文原标题为聚合支付深度复盘与展望&#xff0c;首发于《支付百科》公众号&#xff01; 收单外包服务机构在我国支付收单市场中占据着举足轻重的地位&#xff0c;其规模在政策引导和市场需求驱动下不断扩大。同时&#xff0c;随着行业自律管理体系的持续发展和完善&a…

pycharm 远程运行报错 Failed to prepare environment

什么也没动的情况下&#xff0c;远程连接后运行是没问题的&#xff0c;突然在运行时就运行不了了&#xff0c;解决方案 清理缓存&#xff1a; 有时候 PyCharm 的内部缓存可能出现问题&#xff0c;可以尝试清除缓存&#xff08;File > Invalidate Caches / Restart&#xff0…

通俗理解Kotlin及其30大特性

通俗理解Kotlin及其30大特性 文章目录 通俗理解Kotlin及其30大特性前言背景编译&运行字节码对比 Java VS Kotlin变量/常量类型声明变量初始化空安全特性 函数函数声明函数参数函数可变参数局部函数函数/属性/操作符的扩展函数/属性的引用操作符重载Lambda 表达式数组/List/…

css中选择器的优先级

CSS 的优先级是由选择器的特指度&#xff08;Specificity&#xff09;和重要性&#xff08;Importance&#xff09;决定的&#xff0c;以下是优先级规则&#xff1a; 特指度&#xff1a; ID 选择器 (#id): 每个ID选择器计为100。 类选择器 (.class)、属性选择器 ([attr]) 和伪…

一个服务器实现本机服务互联网化

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 一个服务器实现本机服务互联网化 前言痛点关于中微子代理实战演练搭建服务端搭建客户端服务端配置代理实现 前言 在数字世界的网络战场上&#xff0c;中微子代理就像是一支潜伏在黑暗中的数字特工队&…

PacketSender-用于发送/接收 TCP、UDP、SSL、HTTP 的网络实用程序

PacketSender-用于发送/接收 TCP、UDP、SSL、HTTP 的网络实用程序 PacketSender是一款开源的用于发送/接收 TCP、UDP、SSL、HTTP 的网络实用程序&#xff0c;作者为dannagle。 其官网地址为&#xff1a;https://packetsender.com/&#xff0c;Github源代码地址&#xff1a;htt…

Java 事件处理机制

一、快速入门 import javax.swing.*; import java.awt.*; import java.awt.event.KeyEvent; import java.awt.event.KeyListener; import java.awt.event.MouseListener; import java.awt.event.WindowListener;public class BallMove extends JFrame { //窗口MyPanel mp null…

一款高输出电流 PWM 转换器

一、产品描述 TPS543x 是一款高输出电流 PWM 转换器&#xff0c;集成了低电阻、高侧 N 沟道 MOSFET。具有所列的特性的基板上还包括高性能电压误差放大器&#xff08;可在瞬态条件下提供高稳压精度&#xff09;、欠压锁定电路&#xff08;用于防止在输入电压达到 5.5V 前启动&…

Py之ydata-profilin:ydata-profiling的简介、安装、使用方法之详细攻略

Py之ydata-profilin&#xff1a;ydata-profiling的简介、安装、使用方法之详细攻略 目录 ydata-profiling的简介 1、主要特点 2、案例应用 (1)、比较数据集、对时序数据集进行分析、对大型数据集进行分析、处理敏感数据、数据集元数据和数据字典、自定义报告的外观、不同类型…

【MATLAB源码-第144期】基于matlab的蝴蝶优化算法(BOA)无人机三维路径规划,输出做短路径图和适应度曲线。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 ​蝴蝶优化算法&#xff08;Butterfly Optimization Algorithm, BOA&#xff09;是基于蝴蝶觅食行为的一种新颖的群体智能算法。它通过模拟蝴蝶个体在寻找食物过程中的嗅觉导向行为以及随机飞行行为&#xff0c;来探索解空间…

使用两个队列实现栈

在计算机科学中&#xff0c;栈是一种数据结构&#xff0c;它遵循后进先出&#xff08;LIFO&#xff09;的原则。这意味着最后一个被添加到栈的元素将是第一个被移除的元素。然而&#xff0c;Java的标准库并没有提供栈的实现&#xff0c;但我们可以使用两个队列来模拟一个栈的行…

十五、随机数和随机颜色

项目功能实现&#xff1a;在原图上进行每隔0.5s随机绘制不同长度不同颜色的线段(保存之前的线段)&#xff0c;在另一个画布上进行绘制随机不同长度不同颜色的线段(不保存之前的线段) 按照之前的博文结构来&#xff0c;这里就不在赘述了 一、头文件 random.h #pragma once#i…

Fiddler工具 — 19.Fiddler抓包HTTPS请求(二)

5、查看证书是否安装成功 方式一&#xff1a; 点击Tools菜单 —> Options... —> HTTPS —> Actions 选择第三项&#xff1a;Open Windows Certificate Manager打开Windows证书管理器。 打开Windows证书管理器&#xff0c;选择操作—>查看证书&#xff0c;在搜索…

【C++精简版回顾】6.构造函数

一。类的三种初始化方式 1.不使用构造函数初始化类 使用函数引用来初始化类 class MM { public:string& getname() {return name;}int& getage() {return age;}void print() {cout << "name: " << name << endl << "age: &quo…

跨境电商消息多发脚本制作需要用到的代码!

在跨境电商的运营中&#xff0c;为了更有效地推广产品、提升品牌知名度并增强与消费者的互动&#xff0c;消息群发成为了一个重要的营销手段。 为了实现这一目的&#xff0c;许多跨境电商团队会选择制作消息多发脚本&#xff0c;通过自动化发送消息来提高效率和覆盖面&#xf…

Postman接口测试之Mock快速入门

一、Mock简介 1.Mock定义 Mock是一种比较特殊的测试技巧&#xff0c;可以在没有依赖项的情况下进行接口或单元测试。通常情况下&#xff0c;Mock与其他方法的区别是&#xff0c;用于模拟代码依赖对象&#xff0c;并允许设置对应的期望值。简单一点来讲&#xff0c;就是Mock创建…

LabVIEW多通道压力传感器实时动态检测

LabVIEW多通道压力传感器实时动态检测 介绍了一种基于LabVIEW的多通道压力传感器实时动态检测系统&#xff0c;解决压阻式压力传感器温度补偿过程的复杂度&#xff0c;提高测量的准确性。通过自动轮询检测方法&#xff0c;结合硬件检测模型和多通道检测系统设计&#xff0c;本…

ADC--模拟量转换成数字量

目录 一、ADC硬件组成七大部分&#xff1a; 二、单次转换&#xff0c;连续转换&#xff0c;不连续采样模式&#xff0c;扫描模式区别 1、举例(5种组合情况) 2、模拟看门狗中断的作用&#xff1a; 三、MCU使用ADC步骤 一、ADC硬件组成七大部分&#xff1a; ①输入电压&#…