TensorFlow简单的线性回归任务

如何使用 TensorFlow 和 Keras 创建、训练并进行预测

1. 数据准备与预处理

2. 构建模型

3. 编译模型

4. 训练模型

5. 评估模型

6. 模型应用与预测

7. 保存与加载模型

8.完整代码


1. 数据准备与预处理

我们将使用一个简单的线性回归问题,其中输入特征 x 和标签 y 之间存在线性关系。我们创建一个训练数据集,并将标签设置为输入特征的两倍加上一些噪声。

import numpy as np
import tensorflow as tf

# 创建训练数据,x 是输入特征,y 是标签(y = 2 * x + 噪声)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=float)  # 输入数据
y = 2 * x + np.random.normal(0, 1, size=x.shape)  # 标签数据,加一些噪声

2. 构建模型

我们使用一个简单的神经网络来进行线性回归。这个网络只有一个全连接层,激活函数是线性的。

model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=1, input_dim=1, activation='linear')  # 线性激活函数
])

3. 编译模型

使用 SGD 优化器和均方误差损失函数,适合线性回归问题。

model.compile(optimizer='sgd', loss='mean_squared_error')

4. 训练模型

训练模型时,我们设置 1000 个训练周期,并传入数据 x 和标签 y

model.fit(x, y, epochs=1000)

5. 评估模型

训练结束后,我们评估模型的表现,使用 evaluate 函数来查看损失值。

loss = model.evaluate(x, y)
print(f"模型的损失值:{loss}")

6. 模型应用与预测

训练完成后,我们使用 model.predict() 来进行预测。你可以将新的输入数据传入模型,得到预测结果。

# 使用模型进行预测
new_x = np.array([11, 12, 13, 14, 15], dtype=float)
predictions = model.predict(new_x)

print("新的输入数据预测结果:")
print(predictions)

7. 保存与加载模型

你还可以保存和加载训练好的模型,以便在未来使用。\

# 保存模型
model.save('linear_model.keras')

# 加载模型
loaded_model = tf.keras.models.load_model('linear_model.keras')

# 使用加载的模型进行预测
loaded_predictions = loaded_model.predict(new_x)
print("加载的模型预测结果:")
print(loaded_predictions)

8.完整代码

import numpy as np
import tensorflow as tf

# 创建训练数据,x 是输入特征,y 是标签(y = 2 * x + 噪声)
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=float)
y = 2 * x + np.random.normal(0, 1, size=x.shape)

# 构建模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=1, input_dim=1, activation='linear')  # 线性激活函数
])

# 编译模型
model.compile(optimizer='sgd', loss='mean_squared_error')

# 训练模型
model.fit(x, y, epochs=1000)

# 评估模型
loss = model.evaluate(x, y)
print(f"模型的损失值:{loss}")

# 使用模型进行预测
new_x = np.array([11, 12, 13, 14, 15], dtype=float)
predictions = model.predict(new_x)

print("新的输入数据预测结果:")
print(predictions)

# 保存模型
model.save('linear_model.keras')

# 加载模型
loaded_model = tf.keras.models.load_model('linear_model.keras')

# 使用加载的模型进行预测
loaded_predictions = loaded_model.predict(new_x)
print("加载的模型预测结果:")
print(loaded_predictions)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963582.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

langchain基础(二)

一、输出解析器(Output Parser) 作用:(1)让模型按照指定的格式输出; (2)解析模型输出,提取所需的信息 1、逗号分隔列表 CommaSeparatedListOutputParser:…

docker安装MySQL8:docker离线安装MySQL、docker在线安装MySQL、MySQL镜像下载、MySQL配置、MySQL命令

一、镜像下载 1、在线下载 在一台能连外网的linux上执行docker镜像拉取命令 docker pull mysql:8.0.41 2、离线包下载 两种方式: 方式一: -)在一台能连外网的linux上安装docker执行第一步的命令下载镜像 -)导出 # 导出镜…

【MySQL】语言连接

语言连接 一、下载二、mysql_get_client_info1、函数2、介绍3、示例 三、其他函数1、mysql_init2、mysql_real_connect3、mysql_query4、mysql_store_result5、mysql_free_result6、mysql_num_fields7、mysql_num_rows8、mysql_fetch_fields9、mysql_fetch_row10、mysql_close …

建表注意事项(2):表约束,主键自增,序列[oracle]

没有明确写明数据库时,默认基于oracle 约束的分类 用于确保数据的完整性和一致性。约束可以分为 表级约束 和 列级约束,区别在于定义的位置和作用范围 复合主键约束: 主键约束中有2个或以上的字段 复合主键的列顺序会影响索引的使用,需谨慎设计 添加…

本地缓存~

前言 Caffeine是使用Java8对Guava缓存的重写版本,在Spring Boot 2.0中取而代之,基于LRU算法实现,支持多种缓存过期策略。 以下摘抄于https://github.com/ben-manes/caffeine/wiki/Benchmarks-zh-CN 基准测试通过使用Java microbenchmark ha…

视觉状态空间模型(VMamba)的解读

在计算机视觉领域,设计计算高效的网络架构一直是研究的热点。今天,我想和大家分享一篇发表在 NIPS 2024 上的论文——VMamba:Visual State Space Model,这篇论文提出了一种新的视觉骨干网络,具有线性时间复杂度&#x…

Kanass基础教程-创建项目

Kanass是一款国产开源免费的项目管理工具,工具简洁易用,开源免费,之前介绍过kanass的一些产品简介及安装配置方法,本文就从如何创建第一个项目来开始kanass上手之旅吧。 1. 创建项目 点击项目->项目添加 按钮进入项目添加页面…

问题的价值 ( Value of Question ) 公式

一、什么是问题的价值 我们的人生、工作的期间、瞬息万变的商业环境中,我们必然会面对很多问题,也会提出很多问题。 但这些问题是否具有回答的 价值,应该如何 衡量 呢? 简单如,女朋友问今晚应该吃什么、世界如何才能…

Zemax 中带有体素探测器的激光谐振腔

激光谐振腔是激光系统的基本组成部分,在光的放大和相干激光辐射的产生中起着至关重要的作用。 激光腔由两个放置在光学谐振器两端的镜子组成。一个镜子反射率高(后镜),而另一个镜子部分透明(输出耦合器)。…

在GPIO控制器中,配置通用输入,读取IO口电平时,上拉和下拉起到什么作用

上下拉电阻作用 在通用输入的时候,也就是在读某个IO的电平的时候 一定要让IO口先保持一个电平状态,这样才能检测到不同电平状态。 如何保持电平状态? 1. 可以通过芯片内部的上下拉电阻,由于是弱上下拉一般不用 2. 硬件外界一个…

如何使用 DeepSeek 和 Dexscreener 构建免费的 AI 加密交易机器人?

我使用DeepSeek AI和Dexscreener API构建的一个简单的 AI 加密交易机器人实现了这一目标。在本文中,我将逐步指导您如何构建像我一样的机器人。 DeepSeek 最近发布了R1,这是一种先进的 AI 模型。您可以将其视为 ChatGPT 的免费开源版本,但增加…

SAP HCM insufficient authorization, no.skipped personnel 总结归纳

导读 权限:HCM模块中有普通权限和结构化权限。普通权限就是PFCG的权限,结构化权限就是按照部门ID授权,颗粒度更细,对分工明细化的单位尤其重要,今天遇到的问题就是结构化权限的问题。 作者:vivi,来源&…

python-leetcode-二叉树的右视图

199. 二叉树的右视图 - 力扣(LeetCode) # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right right class Solut…

冲刺一区!挑战7天完成一篇趋势性分析GBD DAY1-7

Day1. 公开数据库的挖掘太火热了,其中GBD数据库的挖掘又十分的火爆.那我就来挑战一篇GBD、一篇关于趋势性分析的GBD! GBD数据库挖掘是目前的四大刊常客,经常出现在顶级期刊上面。这个数据库亮点就是:可视化,统计学简单、而数据可…

Maven全解析:从基础到精通的实战指南

概念: Maven 是跨平台的项目管理工具。主要服务基于 Java 平台的构建,依赖管理和项目信息管理项目构建:高度自动化,跨平台,可重用的组件,标准化的流程 依赖管理: 对第三方依赖包的管理&#xf…

使用LLaMA-Factory对AI进行认知的微调

使用LLaMA-Factory对AI进行认知的微调 引言1. 安装LLaMA-Factory1.1. 克隆仓库1.2. 创建虚拟环境1.3. 安装LLaMA-Factory1.4. 验证 2. 准备数据2.1. 创建数据集2.2. 更新数据集信息 3. 启动LLaMA-Factory4. 进行微调4.1. 设置模型4.2. 预览数据集4.3. 设置学习率等参数4.4. 预览…

复制粘贴小工具——Ditto

在日常工作中,复制粘贴是常见的操作,但Windows系统自带的剪贴板功能较为有限,只能保存最近一次的复制记录,这对于需要频繁复制粘贴的用户来说不太方便。今天,我们介绍一款开源、免费且功能强大的剪贴板增强工具——Dit…

无人机图传模块 wfb-ng openipc-fpv,4G

openipc 的定位是为各种模块提供底层的驱动和linux最小系统,openipc 是采用buildroot系统编译而成,因此二次开发能力有点麻烦。为啥openipc 会用于无人机图传呢?因为openipc可以将现有的网络摄像头ip-camera模块直接利用起来,从而…

Redis代金卷(优惠卷)秒杀案例-多应用版

Redis代金卷(优惠卷)秒杀案例-单应用版-CSDN博客 上面这种方案,在多应用时候会出现问题,原因是你通过用户ID加锁 但是在多应用情况下,会出现两个应用的用户都有机会进去 让多个JVM使用同一把锁 这样就需要使用分布式锁 每个JVM都会有一个锁监视器,多个JVM就会有多个锁监视器…

国产之光DeepSeek架构理解与应用分析

目录 初步探索DeepSeek的设计 一、核心架构设计 二、核心原理与优化 三、关键创新点 四、典型应用场景 五、与同类模型的对比优势 六、未来演进方向 从投入行业生产的角度看 一、DeepSeek的核心功能扩展 二、机械电子工程产业中的具体案例 1. 预测性维护(Predictive…