Pytorch深度学习实践笔记6(b站刘二大人)

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili​

目录

1 Logistic Regression

2 整体流程

3 Softmax 多分类

4 交叉熵crossEntropy

5 Logistic Regression 代码


1 Logistic Regression


区别于Linear Regression,加入了激活函数,引入非线性。


一个二分类问题:
给定学习时长x,y为考试是否可以通过,通过为1,不通过为0


使用二分类激活函数sigmod,softmax多分类在只有两个分类时,也可以变成sigmoid


2 整体流程

  1. 数据准备:

  • 模型建立:

  • Loss建立



 

  • mini-Batch Loss for Binary Classification


3 Softmax 多分类

一文彻底搞懂 Softmax 函数,数学原理分析和 PyTorch 验证​



softmax就是将一些大的数字拉伸到0~1之间,而且使得大的数所占的比例更大,小的数所占的比例更小,这样如果每一个原始的数据代表score的话,将其总分控制在0~1之间,可以进一步使用交叉熵函数来计算loss。
这里Softmax会带来一些数据上溢和下溢问题,上溢问题可以减去max(y),下溢问题可以使用log(y)等手段来解决。


一个简单的softmax的例子:

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>

#define SIZE (4)

void softMax(double *src,double *dst,const int N)
{
    double exp_score[N];
    double sum = 0.f;
    
    for (int i = 0;i < N;i++)
    {
        exp_score[i] = exp(src[i]);
        sum += exp_score[i];
    }
    for (int i = 0;i < N;i++)
    {
        dst[i] = exp_score[i] / sum; 
    }

}

int main(int argc, char *argv[])
{

    double score[SIZE] = {2.1,2.0,0.1,3.7};

    double probability[SIZE] = {0.f};

    softMax(score,probability,SIZE);

    printf("Probability:[");
    for (int i = 0;i < SIZE;i++)
    {
        printf(" %f ",probability[i]);
    }
    printf(" ]\n");

    return 0;
}


4 交叉熵crossEntropy


计算两个概率分布之间的差异的,由于Softmax和Sigmod都将score转化为了概率分布,因此可以将交叉熵作为损失函数来计算y_true和y_pred之间的差异。
 

softmax分类器和交叉熵损失函数




一个多分类的交叉熵例子:

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>

#define M (3)
#define N (4)

double crossEntropy(double *y_true,double *y_pred,int m,int n)
{
    // 防止log(0)的情况,对预测概率分布进行微小的修正
    for (int i = 0;i < m;i++)
    {
        for (int j = 0;j < n;j++)
        {
            y_pred[i*n+j] = y_pred[i*n+j] <= 0 ? (1e-7) : (y_pred[i*n+j] > 1 ? (1) : y_pred[i*n+j]);
        }
        
    }

    double sum = 0.f;
    for (int i = 0;i < m;i++)
    {
        for (int j = 0;j < n;j++)
        {
            sum += -y_true[i*n+j] * log2(y_pred[i*n+j]);
        }
        
    }

    return (sum / (double)m);
}

int main(int argc, char *argv[])
{
    // one-hot 编码 代表标签分类 例如: 猫 狗 鸟 猪
    // 真实概率分布
    double probability[M][N] = {{1,0,0,0},{0,1,0,0},{0,0,1,0}};

    // double y_pred[M][N] = {{0.7,0.2,0.05,0.05},{0,1,0,0},{0.1,0.1,0.8,0}};

    double y_pred[M][N] = {{0.98,0.02,0,0},{0,1,0,0},{0,0,1,0}};

    double loss = crossEntropy(&probability[0][0],&y_pred[0][0],M,N);
    printf("loss: % .6lf \n",loss);

    return 0;
}



5 Logistic Regression 代码

pytorch nn.BCELoss()详解​

torch.empty()和torch.Tensor.random_()的使用举例-CSDN博客​

import torch

# prepare dataset
# 0 1 为分类标签
# x_data 输入 学习的时间
# y_data 输出 考试通过与否(0 or 1)
x_data = torch.Tensor([[1.0], [2.0], [3.0],[1.1],[2.1],[1.5],[2.3],[4.1]])
y_data = torch.Tensor([[0], [0], [1],[0],[0],[0],[0],[1]])
 
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1,1)
 
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
model = LogisticRegressionModel()
 
# construct loss and optimizer
criterion = torch.nn.BCELoss(size_average = True) 
optimizer = torch.optim.Adam(model.parameters(), lr = 0.05)
 
# training cycle forward, backward, update
for epoch in range(10000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
 
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (loss < 1e-7):
        break
 
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# test dataset
x_test = torch.Tensor([[4.0],[5.0],[1.5],[2.5],[2.9],[3.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/645213.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Spring Security + OAuth2】OAuth2

Spring Security OAuth2 第一章 Spring Security 快速入门 第二章 Spring Security 自定义配置 第三章 Spring Security 前后端分离配置 第四章 Spring Security 身份认证 第五章 Spring Security 授权 第六章 OAuth2 文章目录 Spring Security OAuth21、OAuth2简介1.1、OAu…

绘唐科技绘唐ai工具邀请码

绘唐科技绘唐ai工具邀请码 绘唐AI工具 https://qvfbz6lhqnd.feishu.cn/wiki/QBr4wOAz2ilF4NknrqbcoKRhn2c TensorFlow是一个开源的机器学习框架,由Google开发并维护。它提供了一个灵活且高效的接口,用于构建和训练各种机器学习模型。 TensorFlow的基本概念包括: 1. 张量(…

基于python的网页自动刷新工具

1.下载webdriver https://msedgewebdriverstorage.z22.web.core.windows.net/?prefix122.0.2365.59/下载Edge的浏览器驱动 2.安装selenium pip install selenium4.11.1 3.写代码 # -*- coding: utf-8 -*- import tkinter as tk from tkinter import messagebox import thr…

当标签中出现输入了字母或者数字直接在一行上,没有换行的 情况时怎么办

当标签块中输入的是包含字母或者数字的时候&#xff0c;他不会换行&#xff0c;在一行上显示滚动条的形式&#xff0c;而我们想让他走正常文档流&#xff0c;该换行的时候换行 想要的如下效果 给相应的元素块添加该代码即可 word-break: break-all; .card-content { …

uniapp使用uni.chooseImage选择图片后对其是否符合所需的图片大小和类型进行校验

uni.chooseImage的返回值在H5平台和其他平台的返回值有所差异&#xff0c;具体差异看下图 根据图片可以看出要想判断上传的文件类型是不能直接使用type进行判断的&#xff0c;所以我使用截取字符串的形式来判断&#xff0c;当前上传图片的后缀名是否符合所需要求。 要求&#…

牛客网刷题 | BC97 回文对称数

目前主要分为三个专栏&#xff0c;后续还会添加&#xff1a; 专栏如下&#xff1a; C语言刷题解析 C语言系列文章 我的成长经历 感谢阅读&#xff01; 初来乍到&#xff0c;如有错误请指出&#xff0c;感谢&#xff01; 描述 今天牛牛学到了回文…

电子电器架构 - AUTOSAR软件架构介绍

电子电器架构 - AUTOSAR软件架构介绍 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明自己…

Nginx代理配置(专业版)

写在前面提醒&#xff1a;使用代理&#xff0c;如果可以&#xff0c;请尽量支持双协议&#xff0c;http、https均要支持哈。 注意&#xff1a;监控系统只是运行代码&#xff0c;是否支持https&#xff0c;需要运维同学在你们的服务器上配置https证书&#xff0c;配置好证书&…

关于构建生成式AI产品的思考

在过去的六个月里&#xff0c;我们 LinkedIn 的团队一直在努力开发一种新的人工智能体验。我们希望重新构想我们的会员如何进行求职和浏览专业内容。 生成式人工智能的爆炸式增长让我们停下来思考一年前不可能实现的事情。我们尝试了许多想法&#xff0c;但都没有真正实现&…

OpenAI模型GPT-4o、GPT-4、Gemini 1.5性能比较

大家好&#xff0c;OpenAI最新推出的GPT-4o&#xff0c;标志着人工智能语言模型和交互方式迈入了新纪元。最引人注目的是&#xff0c;GPT-4o支持实时互动和流畅的对话切换&#xff0c;让交流更加自然。 本文将对比分析GPT-4o、GPT 4以及谷歌的Gemini和Unicorn模型&#xff0c;…

LabelMe下载及关键点检测数据标注

本文关键点数据集链接,提取码:x1pk 1.LabelMe下载 这部分内容和YOLOv8_seg的标注软件是一样的,使用anaconda创建虚拟环境安装LabelMe,指令如下: conda create -n labelme python=3.6 -y conda activate labelme conda install pyqt conda install pillow pip install la…

第六节:带你全面理解vue3 浅层响应式API: shallowRef, shallowReactive, shallowReadonly

前言 前面两章,给大家讲解了vue3中ref, reactive,readonly创建响应式数据的API, 以及常用的计算属性computed, 侦听器watch,watchEffect的使用 其中reactive, ref, readonly创建的响应式数据都是深层响应. 而本章主要给大家讲解以上三个API 对应的创建浅层响应式数据的 API,…

使用Java和XxlCrawler获取各城市月度天气情况实践

目录 前言 一、历史数据获取 1、关于天气后报 2、信息界面分析 二、数据的提取开发 1、PageVo的定义 2、属性定义 3、实际信息抓取 三、信息抓取调试以及可能的问题 1、信息获取成果 2、关于超时的问题 四、总结 前言 这篇文章主要来源于一个我们家小朋友的一个作业…

数据库管理-第194期 网络加速RDMA初探(20240526)

数据库管理194期 2024-05-26 数据库管理-第194期 网络加速RDMA初探&#xff08;20240526&#xff09;1 概念2 发展3 使用总结 数据库管理-第194期 网络加速RDMA初探&#xff08;20240526&#xff09; 作者&#xff1a;胖头鱼的鱼缸&#xff08;尹海文&#xff09; Oracle ACE A…

GCANet去雾算法

目录 1. 引言 2. 门控上下文注意机制&#xff08;GCA&#xff09; 3. 去雾流程 4. 模型代码 5. GCANet的优势 6. 去雾效果 1. 引言 GCANet(Gate-Controlled Attention Network)是一种用于图像去雾的深度学习算法&#xff0c;通过引入注意力机制来改进传统的去雾方法&…

使OpenCV可以读取中文路径图片的方法

一&#xff0e;问题复现 1.代码 #! /usr/bin/env python # -*- coding: utf-8 -*-# File: show_img.pyimport cv2# 读取图片 img cv2.imread("车牌素材/冀A.png")# 显示图片 cv2.imshow("img", img) cv2.waitKey(0)2.报错截图 3.报错内容 [ WARN:00.05…

简单好用的文本识别方法--付费的好用,免费的更有性价比

文章目录 先说付费的进入真题&#xff0c;免费的来喏&#xff01;PixPin微信 先说付费的 直达网址!!! 进入真题&#xff0c;免费的来喏&#xff01; PixPin 商店里就有 使用示例&#xff1a; 可以看到&#xff1a;贴在桌面上的图片可以复制图片中的文字&#xff0c;真的很…

第一篇【传奇开心果系列】Python的跨平台开发工具beeware技术点案例示例:使用beeware实现跨平台开发,从hello world开始

传奇开心果博文系列 系列博文目录Python的跨平台开发工具beeware技术点案例示例系列 博文目录前言一、BeeWare套件主要功能介绍二、Toga相对于其他Python UI库具有的优势介绍三、使用toga开发安卓手机应用hello world步骤和示例代码四、使用toga写一个iOS 苹果手机应用hello wo…

红蓝对抗-HW红蓝队基本知识(网络安全学习路线笔记)

第一, 什么是蓝队 蓝队&#xff0c;一般是指网络实战攻防演习中的攻击一方。 蓝队一般会采用针对目标单位的从业人员&#xff0c;以及目标系统所在网络内的软件、硬件设备同时执行多角度、全方位、对抗性的混合式模拟攻击手段&#xff1b;通过技术手段实现系统提权、控制业务、…

高速数据采集与传输(一):ADC08D500调研

前言&#xff1a;高速ADC数据采集的应用和开发&#xff0c;涉及的技术面非常的广泛&#xff0c;后续阶段博主将尝试以纯项目开发的形式做一次专题技术分享&#xff0c;将基于高速数据采集的相关内容进行一系列的技术文档更新。博主全凭兴趣在更新和总结&#xff0c;很难做到一直…