【tensorflow】连续输入的线性回归模型训练代码

【tensorflow】连续输入的感知机模型训练

  • 全部代码 - 复制即用
  • 训练输出
  • 代码介绍

  查看本系列三种模型写法:
  【tensorflow】连续输入的线性回归模型训练代码
  【tensorflow】连续输入的神经网络模型训练代码
  【tensorflow】连续输入+离散输入的神经网络模型训练代码

全部代码 - 复制即用

from sklearn.model_selection import train_test_split
import tensorflow as tf
import numpy as np
from keras import Input, Model, Sequential
from keras.layers import Dense, concatenate, Embedding, LSTM
from sklearn.preprocessing import StandardScaler
from tensorflow import keras

def get_data():
    # 设置随机种子,以确保结果可复现(可选)
    np.random.seed(0)

    # 生成随机数据
    data = np.random.rand(10000, 10)
    
    # 正则化数据
    scaler = StandardScaler()
    data = scaler.fit_transform(data)

    # 生成随机数据
    target = np.random.rand(10000, 1)

    return train_test_split(data, target, test_size=0.1, random_state=42)

data_train, data_val, target_train, target_val = get_data()

# 迭代轮次
train_epochs = 10
# 学习率
learning_rate = 0.0001
# 批大小
batch_size = 200

# 定义模型
with tf.name_scope("Model"):
    x = tf.placeholder(tf.float32, [None,10]) # 10个特征数据(10列)
    y = tf.placeholder(tf.float32, [None, 1])

    d = tf.Variable(tf.random_normal([10,10], stddev=0.01))
    w = tf.Variable(tf.random_normal([10,1], stddev=0.01))

    # b 初始化值为 1.0
    a = tf.Variable(1.0)
    b = tf.Variable(1.0)

    k = tf.matmul(x, d) + a
    pred = tf.matmul(k, w) + b

#损失函数
loss_function = tf.reduce_mean(tf.square(y-pred))
# 创建优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

#3生成会话,训练STEPS轮
with tf.Session() as sess:
    # 初始化参数
    init = tf.global_variables_initializer()
    sess.run(init)


    # 训练模型
    STEPS = 1000 #3000
    for i in range(STEPS):
        start = (i*batch_size) % 4
        end =  start + batch_size
        #sess.run(optimizer, feed_dict={x: X[start:end], y_: Y_[start:end]})
    
        if i % 100 == 0:
            total_loss = sess.run(loss_function, feed_dict={x: data_train, y: target_train})
            print("After %d training step(s), loss_mse on all data is %g" % (i, total_loss))            
        sess.run(optimizer, feed_dict={x: data_val, y: target_val})

训练输出

  模型训练过程中的输出如下:
在这里插入图片描述

代码介绍

  get_data函数用于生成随机的训练和验证数据集。首先使用np.random.rand生成一个形状为(10000, 10)的随机数据集,来模拟10维的连续输入,然后使用StandardScaler对数据进行标准化。再生成一个(10000,1)的target,表示最终拟合的目标分数。最后使用train_test_split函数将数据集划分为训练集和验证集。

  由于target是浮点数,所以我们这个任务就是回归任务了。

  在定义模型时,使用tf.placeholder定义占位符,用于传入输入数据和目标数据。定义变量d、w、a和b作为模型的权重和偏置。使用tf.matmul进行矩阵乘法和加法操作,得到预测值pred。

可以尝试改改模型,换一换激活函数。

  定义损失函数为均方差(MSE),使用梯度下降优化器进行参数更新。

  在tf.Session中创建会话,通过tf.global_variables_initializer初始化模型的变量。然后进行训练,迭代STEPS轮。在每一轮训练中,通过sess.run运行优化器进行参数更新,并计算训练集上的损失。每训练100轮,打印出当前轮次的损失值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/29683.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

常用JVM命令

top 展示 进程运行的完整命令行的话可以用 top -c ,当命令行较长无法分辨是哪个程序,可使用键盘右键将窗口不断滑动至右侧查看。 uptime jps 查看当前正在运行的java进程 执行结果: pid 运行文件 [roottest1 ~]# jps 24001 rs-medical-rp…

DBeaver连接SQLite数据库

一、前言 SQLite小巧轻便的开源免费关系型数据库,适合嵌入单机应用随身携带。桌面版推荐使用DBeaver。 官网:SQLite Download Page github:GitHub - sqlite/sqlite: Official Git mirror of the SQLite source tree 类似的开源免费且小巧…

WebGL前言——WebGL相关介绍

第一讲内容主要介绍WebGL技术和相应的硬件基础部分,在初级课程和中级课程的基础上,将技术和硬件基础进行串联,能够对WebGL从产生到消亡有深刻全面的理解。同时还介绍WebGL大家在初级课程和中级课程中的一些常见错误以及错误调试的办法。 1.1…

Jmeter常用参数化技巧总结!

说起接口测试,相信大家在工作中用的最多的还是Jmeter。 JMeter是一个100%的纯Java桌面应用,由Apache组织的开放源代码项目,它是功能和性能测试的工具。具有高可扩展性、支持Web(HTTP/HTTPS)、SOAP、FTP、JAVA 等多种协议。 在做…

Shell脚本文本三剑客之sed编辑器

目录 一、sed编辑器简介 二、sed工作流程 三、sed命令 四、sed命令的使用 1.sed打印文件内容(p) (1)打印文件所有行 (2)打印文件指定行 2.sed增加、插入、替换行(a、i、c) …

Git工作流(随笔)

目录 前言 一、工作流概述 1、概念 2、分类 二、集中式工作流 1、概述 2、介绍 3、操作过程 三、功能分支工作流 1、概述 2、介绍 3、操作过程 1)创建远程分支 2)删除远程分支 四、GitFlow工作流 1、概述 2、介绍 3、操作过程 五、Forki…

管理类联考——英语二——知识篇——写作——题目说明——A节

MBA,MPA,MPAcc管理类联考英语写作部分由A,B两节组成,主要考查考生的书面表达能力。共2题,25分。A节要求考生根据所给情景写出约100词(标点符号不计算在内)的应用文,包括私人和公务信函、通知、备忘录等。共…

Get请求参数过多导致请求失败

1. 问题 系统正常使用没有问题,但是有极个别的用户出现系统异常,通过日志发现某个get请求,传入的城市list太多,就会抛出异常 java.lang.ClassCastException: java.lang.String cannot be cast to java.util.Map。 2. 排查过程 …

Vue中如何进行游戏开发与游戏引擎集成?

Vue中如何进行游戏开发与游戏引擎集成? Vue.js是一款流行的JavaScript框架,它的MVVM模式和组件化开发思想非常适合构建Web应用程序。但是,如果我们想要开发Web游戏,Vue.js并不是最合适的选择。在本文中,我们将介绍如何…

【java】使用 BeanUtils.copyProperties 11个坑(注意事项)

文章目录 背景第1个坑: 类型不匹配第2个坑: BeanUtils.copyProperties是浅拷贝第3个坑:属性名称不一致第4个坑:Null 值覆盖第5个坑:注意引入的包第6个坑:Boolean类型数据is属性开头的坑第7个坑:查找不到字段…

【鲁棒优化】具有可再生能源和储能的区域微电网的最优运行:针对不确定性的鲁棒性和非预测性解决方案(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

景区旅游多商户版小程序v14.3.1+前端

🎈 限时活动领体验会员:可下载程序网创项目短视频素材 🎈 🎉 有需要的朋友记得关赞评,文章底部来交流!!! 🎉 ✨ 源码介绍 【新增】全新授权登录支持取消登录 【新增】商…

无需租云服务器,Linux本地搭建web服务,并内网穿透发布公网访问

文章目录 前言1. 本地搭建web站点2. 测试局域网访问3. 公开本地web网站3.1 安装cpolar内网穿透3.2 创建http隧道,指向本地80端口3.3 配置后台服务 4. 配置固定二级子域名5. 测试使用固定二级子域名访问本地web站点 转载自cpolar文章:Linux CentOS本地搭建…

saltstack草稿

salt [options] <target> <module.function> [arguments] salt的自建函数&#xff1a; salt * test.rand_sleep 120 salt/salt/modules/test.py 这个是salt自带的包 salt * disk.usage salt -G ipv4:192.168.50.12 cmd.run ls -l /home salt * grain…

C语言之动态内存分配(1)

目录 本章重点 为什么存在动态内存分配 动态内存函数的介绍 malloc free calloc realloc 常见的动态内存错误 几个经典的笔试题 柔性数组 动态内存管理—自己维护自己的内存空间的大小 首先我们申请一个变量&#xff0c;再申请一个数组 这是我们目前知道的向内存申请…

Spark大数据处理学习笔记1.4 掌握Scala运算符

文章目录 一、学习目标二、运算符等价于方法&#xff08;一&#xff09;运算符即方法&#xff08;二&#xff09;方法即运算符1、单参方法2、多参方法3、无参方法 三、Scala运算符&#xff08;一&#xff09;运算符分类表&#xff08;二&#xff09;Scala与Java运算符比较 四、…

基于prefix tuning + Bert的标题党分类器

文章目录 背景一、Prefix-Tuning介绍二、分类三、效果四、参阅 背景 近期, CSDN博客推荐流的标题党博客又多了起来, 先前的基于TextCNN版本的分类模型在语义理解上能力有限, 于是, 便使用的更大的模型来优化, 最终准确率达到了93.7%, 还不错吧. 一、Prefix-Tuning介绍 传统的…

搭建TiDB负载均衡环境-LVS+KeepAlived实践

作者&#xff1a; 我是咖啡哥 原文来源&#xff1a; https://tidb.net/blog/f614b200 昨天&#xff0c;发了一篇使用HAproxyKP搭建TiDB负载均衡环境的文章&#xff0c;今天我们再用LVSKP来做个实验。 环境信息 TiDB版本&#xff1a;V7.1.0 haproxy版本&#xff1a;2.6.2 …

Linux环境下的工具(yum,gdb,vim)

一&#xff0c;yum yum其实是linux环境下的一种应用商店&#xff0c;主要用centos等版本。它也有三板斧&#xff1a;yum list,yum remove,yum install。当然不是说他只有这三个命令&#xff0c;还有yum search等等。在这直说以上三个。 yum list其实是查看你所能安装的软件包…

13.常用类|Java学习笔记

文章目录 包装类包装类型和String类型的相互转换Integer类和Character类的常用方法Integer创建机制&面试题 String类创建String对象的两种方式和区别字符串的特性String类的常用方法 StringBuffer类String和StringBuffer相互转换StringBuffer常用方法 StringBuilder类Strin…