深度学习_4_实战_直线最优解

梯度
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

实战

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

代码:

# %matplotlib inline
import random
import torch
import matplotlib.pyplot as plt
# from d21 import torch as d21

def synthetic_data(w, b, num_examples):
    """生成 Y = XW + b + 噪声。"""
    X = torch.normal(0, 1, (num_examples, len(w)))# 均值为0,方差为1的随机数,n个样本,列数为w的长度
    y = torch.matmul(X, w) + b # y = x * w + b
    y += torch.normal(0, 0.01, y.shape) # 加入随机噪音,均值为0.。形状与y的一样
    return X, y.reshape((-1, 1))# x, y做成列向量返回


true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
#读取小批量,输出batch_size的小批量,随机选取
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))#转成list
    random.shuffle(indices)#打乱
    for i in range(0, num_examples, batch_size):#
        batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])#取
    yield features[batch_indices], labels[batch_indices]#不断返回



# #print(features)
# #print(labels)
#
#
batch_size = 10
#
# for x, y in data_iter(batch_size, features,labels):
#      print(x, '\n', y)
#      break
# # 提取第一列特征作为x轴,第二列特征作为y轴
# x = features[:, 1].detach().numpy() #将特征和标签转换为NumPy数组,以便能够在Matplotlib中使用。
# y = labels.detach().numpy()
#
# # 绘制散点图
# plt.scatter(x, y, 1)
# plt.xlabel('Feature 1')
# plt.ylabel('Feature 2')
# plt.title('Synthetic Data')
# plt.show()
#
# #定义初始化模型

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad = True)

def linreg(x, w, b):
    return torch.matmul(x, w) + b

#定义损失函数

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape))**2 / 2 #弄成一样的形状

# 定义优化算法
def sgd(params, lr, batch_size):
    """小批量随梯度下降"""
    with torch.no_grad():#节省内存和计算资源。
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()#用于清空张量param的梯度信息。

print("训练函数")

lr = 0.03 #学习率
num_ecopchs = 300 #数据扫描三遍
net = linreg #指定模型
loss = squared_loss #损失

for epoch in range(num_ecopchs):#扫描数据
    for x, y in data_iter(batch_size, features, labels): #拿出x, y
      l = loss(net(x, w, b), y)#求损失,预测net,真实y
      l.sum().backward()#算梯度
      sgd([w, b], lr, batch_size)#使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1},loss {float(train_l.mean()):f}')


运行效果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/104591.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux:firewalld防火墙-基础使用(2)

上一章 Linux:firewalld防火墙-介绍(1)-CSDN博客https://blog.csdn.net/w14768855/article/details/133960695?spm1001.2014.3001.5501 我使用的系统为centos7 firewalld启动停止等操作 systemctl start firewalld 开启防火墙 systemct…

华为OD机考算法题:高效的任务规划

题目部分 题目高效的任务规划难度难题目说明 你有 n 台机器编号为 1 ~ n,每台都需要完成一项工作, 机器经过配置后都能独立完成一项工作。 假设第 i 台机器你需要花 分钟进行设置, 然后开始运行, 分钟后完成任务。 现在&#x…

报错:SSL routines:ssl3_get_record:wrong version number

一、问题描述 前后端联调的时候,连接后端本地服务器,接口一直pending调不通,控制台还报以下错误: 立马随手搜索了一下解决方案,但是emmm,不符合前端的实际情况: 二、解决方法: 实际…

IT行业职场走向,哪些方向更有就业前景?——IT行业的发展现状及趋势探析

文章目录 每日一句正能量前言IT技术发展背景及历程IT行业的就业方向有哪些?分享在IT行业的就业经历后记 每日一句正能量 如果你认为你自己无法控制自己的情绪,这就是一种极为严重的不良暗示。 前言 在信息量浩如烟海、星罗棋布的大数据时代,…

服务器动态/静态/住宅/原生IP都是什么意思

​  在互联网的世界中,我们经常会听到关于IP地址的各种说法,比如服务器动态IP、静态IP、住宅IP和原生IP。那么这些术语究竟代表着什么意思呢?让我们一起来了解一下。 动态IP 动态IP(Dynamic IP)是指互联网服务提供商(ISP)在每次用户上网时&#xff0c…

【C++】list的介绍及使用 | 模拟实现list(万字详解)

目录 一、list的介绍及使用 什么是list? list的基本操作 增删查改 获取list元素 不常见操作的使用说明 ​编辑 接合splice ​编辑 移除remove 去重unique 二、模拟实现list 大框架 构造函数 尾插push_back 迭代器__list_iterator list的迭代器要如何…

Vue 商场首页头部布局

封装基础网络请求,前后端联调请求后端接口 npm install axios -Ssrc/network/requestConfig.js import axios from axios; import store from "/store"; export function request(config){const instance axios.create({baseURL:"http://127.0.0.…

TensorFlow2从磁盘读取图片数据集的示例(tf.data.Dataset.list_files)

import os import warnings warnings.filterwarnings("ignore") import tensorflow as tf from tensorflow.keras.optimizers import Adam from tensorflow.keras.applications.resnet import ResNet50 from pathlib import Path import numpy as np#数据所在文件夹 …

外包干了3个月,技术退步明显。。。。。

先说一下自己的情况,本科生生,19年通过校招进入广州某软件公司,干了接近4年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测…

最简单一个跳板机实现

一、背景 生产环境有很多服务器需要让开发人员登录上去做一些基本的操作,比如查看日志什么的,一般小厂对安全要求不是很高,就会直接把ROOT账号给到所有开发人员,但这样当员工离职后,需要去修改ROOT密码是一件很麻烦的…

nodejs使用es-batis

使用方法 创建连接 因为它只支持非连接池所以每次都要创建连接 let dao new MySqlDaoContext({charset: "utf8",host: "localhost",user: "root",password: "root",database: "test",});await dao.initialize();dao in…

0022Java程序设计-ssm微信小程序社区互助平台

文章目录 **摘要**目录系统设计开发环境 摘要 首先,论文一开始便是清楚的论述了小程序的研究内容。其次剖析系统需求分析,弄明白“做什么”,分析包括业务分析、业务流程分析、用例分析,更进一步明确系统的需求。然后在明白了小程序的需求基础上,需要进一步地设计系…

微信批量添加好友,让你的人脉迅速增长

在这个数字化时代,微信作为中国最流行的社交平台之一,已经成为了人们生活中不可或缺的一部分。它的广泛使用为我们提供了无限的社交可能性。你是否曾为了扩大人脉圈子而犯愁?今天,我将向你揭示一个高效添加微信好友的秘密武器&…

(免费领源码)java#Springboot#mysql装修选购网站99192-计算机毕业设计项目选题推荐

摘 要 随着科学技术,计算机迅速的发展。在如今的社会中,市场上涌现出越来越多的新型的产品,人们有了不同种类的选择拥有产品的方式,而电子商务就是随着人们的需求和网络的发展涌动出的产物,电子商务网站是建立在企业与…

Oracle通过透明网关查询SQL Server 报错ORA-00904

Oracle通过透明网关查询SQL Server 报错ORA-00904 问题描述: 只有全表扫描SELECT * 时SQL语句可以正常执行 添加WHERE条件或指定列名查询,查询语句就报错 问题原因: 字段大小写和SQLSERVER中定义的不一致导致查询异常 解决办法: 给…

Java中JVM、JRE和JDK三者有什么区别和联系?

任何语言或者软件的运行都需要环境。就像人要生活在空气中,鱼要活在水中,喜阴植物就不能放在阳光下暴晒一样,任何对象个体的存在都离不开其所需要的环境,编程语言亦是一样的。 java 语言的开发运行,也离不开 Java 语言…

windows如何查看电脑IP地址

介绍两种查看电脑IP的方式 一、第一种方式 1、在电脑左下角搜索网络连接 2、鼠标右键你目前连接的网络(wifi就点wifi 网线就点以太网);选择里面的状态。 3、点击详细信息,这里的IPv4地址就是你电脑的IP。 二、第二种 1、win…

出差学小白知识No5:|Ubuntu上关联GitLab账号并下载项目(ssh key配置)

1 注冊自己的gitlab账户 有手就行 2 ubuntu安装git ,并查看版本 sudo apt-get install git git --version 3 vim ~/.ssh/config Host gitlab.example.com User your_username Port 22 IdentityFile ~/.ssh/id_rsa PreferredAuthentications publickey 替换gitl…

CC001:CC照片建模

摘要:CC照片建模原理是通过从图像中提取特征点和特征描述符,然后根据特征点的匹配来计算相机的位姿,从而生成三维点云数据。最后,借助网格重建和纹理映射的方法,将点云转换为带有纹理的三维网格模型。 实验数据&#x…

每日一题 2520. 统计能整除数字的位数(简单)

简单题频率好高,预测一波明天困难 class Solution:def countDigits(self, num: int) -> int:ans 0for i in str(num):if num % int(i) 0:ans 1return ans