【深度学习】2.单层感知机

目标:

实现一个简单的二分类模型的训练过程,通过模拟数据集进行训练和优化,训练目标是使模型能够根据输入特征正确分类数据。

演示:

1.通过PyTorch生成了一个模拟的二分类数据集,包括特征矩阵data_x和对应的标签数据data_y。标签数据通过基于特征的线性组合生成,并转换成独热编码的形式。

import torch
# 从torch库中导入神经网络模块nn,用于构建神经网络模型
from torch import nn
# 导入torch.nn模块中的functional子模块,可用于访问各种函数,例如激活函数
import torch.nn.functional as F

n_item = 1000
n_feature = 2
learning_rate = 0.01
epochs = 100


# 生成一个模拟的数据集,其中包括一个随机生成的特征矩阵data_x和相应生成的标签数据data_y。标签数据通过基于特征的线性组合生成,并且转换成独热编码的形式。


# 设置随机数生成器的种子为123,通过设置随机种子,我们可以确保在每次运行代码时生成的随机数相同,这对于结果的可重现性非常重要。
torch.manual_seed(123)
# 生成一个随机数矩阵data_x,其中包含n_item行和n_feature列。矩阵中的元素是从标准正态分布(均值为0,标准差为1)中随机采样的。
data_x = torch.randn(size=(n_item, n_feature)).float()
# torch.where(...): 根据条件返回两个张量中相应位置的值。如果条件成立,将为0,否则为1。  long(): 用于将张量转换为Long型数据类型。
data_y = torch.where(torch.subtract(data_x[:, 0]*0.5, data_x[:, 1]*1.5)+0.02 > 0, 0, 1).long()
# 将标签数据data_y转换为独热编码形式,即将每个标签转换为一个相应长度的独热向量
data_y = F.one_hot(data_y)

# print(data_x)
# print(data_y)

2.定义了一个简单的二分类模型BinaryClassificationModel,包含一个单层感知器(Single Perceptron)结构,其中使用了一个线性层和sigmoid激活函数,用于将输入特征映射到概率空间。

# 定义了一个简单的二分类模型,采用单层感知器的结构,包含一个线性层和sigmoid激活函数,用于将输入特征映射到概率空间。这样的模型可以用来对数据集进行二分类任务的预测。

# 定义了一个名为BinaryClassificationModel的类,其继承自nn.Module类,这意味着这个类是一个PyTorch模型。
class BinaryClassificationModel(nn.Module):
    def __init__(self, in_feature):
        # 调用了父类nn.Module的构造函数,确保正确初始化模型。
        super(BinaryClassificationModel, self).__init__()
        """single perception"""
        # 这行代码定义了模型的第一层,是一个线性层(Fully Connected Layer)。in_features参数指定输入特征的数量,out_features指定输出特征的数量,这里设置为2表示二分类问题。bias=True表示该层包含偏置项。
        self.layer_1 = nn.Linear(in_features=in_feature, out_features=2, bias=True)

# 定义模型前向传播的方法,即输入数据x通过模型前向计算得到输出。
    def forward(self, x):
        # 输入数据x首先通过定义的线性层self.layer_1进行线性变换,然后通过F.sigmoid()函数进行激活函数处理。
        return F.sigmoid(self.layer_1(x))

3.创建了该二分类模型的实例model、使用随机梯度下降(SGD)优化器opt、以及二分类问题常用的损失函数BCELoss(Binary Cross Entropy Loss)。

4.在训练过程中,通过多个epoch和每个样本的批处理(在这里是一次处理一个样本),计算模型预测输出和真实标签之间的损失值,进行反向传播计算梯度,并更新模型参数以最小化损失函数。

# 完成对模型的训练过程,每个epoch中通过优化器进行参数更新,计算损失,反向传播更新梯度。最终我们会得到训练过程中每个epoch的损失值,并可以观察损失的变化情况。

# 创建了一个二分类模型实例model,参数n_feature表示输入特征的数量。
model = BinaryClassificationModel(n_feature)
# 创建了一个随机梯度下降(SGD)优化器opt,用于根据计算出的梯度更新模型参数。
opt = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 创建了一个二分类问题常用的损失函数BCELoss(Binary Cross Entropy Loss),用于衡量模型输出与真实标签之间的差异。
criteria = nn.BCELoss()

for epoch in range(epochs):
    # 对每个样本进行训练。
    for step in range(n_item):
        x = data_x[step]
        y = data_y[step]
        # 梯度清零,避免梯度累加影响优化结果。
        opt.zero_grad()
        # 将输入特征x通过模型前向传播得到预测输出y_hat。unsqueeze(0)是因为我们的模型期望输入是(batch_size, n_feature)的形式。
        y_hat = model(x.unsqueeze(0))
        # 计算预测输出y_hat和真实标签y之间的损失值。
        loss = criteria(y_hat, y.unsqueeze(0).float())
        # 反向传播计算梯度。
        loss.backward()
        # 根据计算出的梯度更新模型参数。
        opt.step()
        
    print("Epoch: %03d, Loss: %.3f" % (epoch, loss.item()))

5.打印出每个epoch的序号和损失值,用于监控训练过程中损失值的变化情况。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/645897.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

UniApp 2.0可视化开发工具:引领前端开发新纪元

一、引言 在移动互联网迅猛发展的今天,移动应用开发已经成为前端开发的重要方向之一。为了简化移动应用开发流程,提高开发效率,各大开发平台不断推出新的工具和框架。UniApp作为一款跨平台的移动应用开发框架,自诞生以来就备受开…

航运复兴?大摩不信!

大摩认为,从供需关系来看红海危机只是推迟了航运业下行周期的到来,一旦干扰消除,行业可能重回周期性低迷。 红海危机加剧运力紧张,航运市场价格飞涨。 大摩在24日的一份报告中指出,受红海危机干扰航运市场运力&#…

如何恢复未保存或丢失的Word文档?

许多用户会遇到Word文档未保存而关闭的问题。实际上,您不会立即丢失未保存的文档数据。请不要对文档进行进一步的更改,例如修改并再次保存。您仍然有机会恢复未保存的Word文档。有一些方法可以帮助您恢复未保存的 Word 文档。 如果您不幸遇到这样的问题…

基于单片机设计的多功能数字电压表开发

摘 要:在电路设计中我们时常会用到电压表,过去大部分电压表还是模拟的,虽然精度较高但模拟电压表采用用指针式,里面是磁电或电磁式结构,所以响应较慢。为适应许多高速信号领域目前已广泛使用数字电压表。数字电压表的诞…

各大排序算法

目录 插入排序 希尔排序(缩小增量排序) 冒泡排序 快速排序 选择排序 归并排序 插入排序 插入排序的基本思想是,将N个待排序元素分为一组有序表和一个无序表,一开始有序表只有一个元素,无序表中有N-1个元素,排序过程中每次取无序表的第一个元素依次与有序表的元素进行对比…

Leetcode260

260. 只出现一次的数字 III - 力扣(LeetCode) class Solution {public int[] singleNumber(int[] nums) {//通过异或操作,使得最终结果为两个只出现一次的元素的异或值int filterResult 0;for(int num:nums){filterResult^num;}//计算首个1(从右侧开始)…

c++ 虚函数常见问题

1 虚函数,虚表基础 虚函数,虚表基础 2 虚函数表保存在哪里 ? 虚函数表在编译的时候确定。在 linux 下,保存在只读数据段的重定位段,这个段的名字是 .data.rel.ro。 如下代码,编译之后,使用 readelf -t a…

vue3 vite项目配置了proxy代理情况下查看真实的接口调用地址

vite配置了proxy代理情况下如何查看真实的接口调用地址? 使用vite进行代理 在vite.config.ts配置了代理 在浏览器查看请求头和响应头发现只有代理前的url,没有显示代理后的路径 然后发现一个bypass函数,但是此函数只能修改res响应头的数据&#xff0…

Visual Studio 的调试(一)

最近事儿很多昂,更新速度相较以往慢了许多,备考六月份的四级,还有学校的期末等等,事儿真的太多啦,所以后面的更新速度也会放慢一点,实在是抽不开身啊诸位,相当抱歉,还望诸君见谅 言…

原哥花了1个多月的时间终于开发了一款基于android studio的原生商城app

大概讲一下这个app实现的功能和前后端技术架构。 功能简介 广告展示商品展示跳转淘宝联盟优惠卷购买发布朋友圈宝妈知识资讯商品搜索朋友圈展示/点赞/评论登陆注册版本升级我的个人资料商品和资讯收藏我的朋友圈意见反馈 安卓端技术选型 Arouter组件化daggerrxjavaretrofit…

遇见问题-VMware虚拟机打开运行一段时间后卡死

1.问题原因 因为Windows自带的虚拟化软件Hyper-V与VMware存在冲突。 2.关闭Hyper-V 1.打开【控制面板】-【程序和功能】-【启用或关闭Windows功能】3.关闭HV主机服务 1.右击计算机-》管理-》服务和应用名称-》服务-》找到HV主机服务-》右击属性停止服务 -》启动类型设置为禁…

专业渗透测试 Phpsploit-Framework(PSF)框架软件小白入门教程(十三)

本系列课程,将重点讲解Phpsploit-Framework框架软件的基础使用! 本文章仅提供学习,切勿将其用于不法手段! 接上一篇文章内容,讲述如何进行Phpsploit-Framework软件的基础使用和二次开发。 我们,继续讲一…

乡村振兴的农业品牌建设:打造农业品牌,提升农产品附加值,增强乡村经济竞争力,实现美丽乡村经济繁荣

目录 一、引言 二、农业品牌建设的重要性 (一)提升农产品附加值 (二)增强乡村经济竞争力 (三)实现美丽乡村经济繁荣 三、农业品牌建设的现状及问题 (一)现状 (二…

python-10(爬虫)

1.网络爬虫 1.1.引言 我们平时都说Python爬虫,其实这里可能有个误解,爬虫并不是Python独有的,可以做爬虫的语言有很多例如:PHP、JAVA、C#、C、Python。 为什么Python的爬虫技术会异军突起呢? Python火并不是因为爬…

兵器室管控系统|DW-306是一套成熟系统

概述 智慧兵器室管理系统(DW-S306)是依托互3D技术、大数据、RFID技术、数据库技术、AI、视频分析技术对RFID智能仓库进行统一管理、分析的信息化、智能化、规范化的系统。 本解决方案利用现有内部网络,部署部队智能兵器室管理系统&#xff…

使用Django框架搭建Web应用

文章目录 简介安装Django创建一个Django项目创建一个Django应用编写视图配置URL运行开发服务器总结与拓展数据库集成管理后台表单处理模板引擎安全性 简介 Django 是一款基于 Python 语言的开源 Web 应用框架,采用了 MVC(模型-视图-控制器)设…

数据库系统原理实验报告6 | 视图

整理自博主本科《数据库系统原理》专业课自己完成的实验报告,以便各位学习数据库系统概论的小伙伴们参考、学习。 专业课本: ​ ———— 本次实验使用到的图形化工具:Heidisql ​ 目录 一、实验目的 二、实验内容 1.根据EDUC数…

Matlab中函数或变量 ‘eeglab‘ 无法识别

EEGLAB 没有安装或添加到 MATLAB 路径中: 确保已经安装了 EEGLAB,并且将其添加到 MATLAB 的路径中。您可以通过在 MATLAB 命令窗口中运行 which eeglab 来检查是否能够找到 EEGLAB。 EEGLAB 函数路径设置错误: 如果已经安装了 EEGLAB&#x…

信息系统项目管理师0131:输出(8项目整合管理—8.7监控项目工作—8.7.3输出)

点击查看专栏目录 文章目录 8.7.3 输出8.7.3 输出 工作绩效报告工作绩效信息可以用实体或电子形式加以合并、记录和分发。基于工作绩效信息,以实体或电子形式编制形成工作绩效报告,以制定决策、采取行动或引起关注。根据项目沟通管理计划,通过沟通过程向项目干系人发送工作绩…

Top期刊:针对论文Figure图片的7个改进建议

我是娜姐 迪娜学姐 ,一个SCI医学期刊编辑,探索用AI工具提效论文写作和发表。 通过对来自细胞生物学、生理学和植物学领域的580篇论文,进行检查和归纳总结,来自德国德累斯顿工业大学的Helena Jambor及合作者,在PLOS Bio…