LLM - Llama 3 的 Pre/Post Training 阶段 Loss 以及 logits 和 logps 概念

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145056912


Llama 3 是 Meta 公司发布的开源大型语言模型,包括具有 80 亿和 700 亿参数的预训练和指令微调的语言模型,支持广泛的应用场景。在多个行业标准基准测试中展示了最先进的性能,特别是在推理、代码生成和指令遵循方面表现出色,超过了同等规模的商业模型。

Llama 3 Paper: The Llama 3 Herd of Models,Llama3 模型群

参考:大模型训练 RLHF 阶段的 PPO/DPO 策略公式与源码

1. Llama 3 Loss

Llama 3 主要包括 2个阶段,即 预训练阶段(Pre-Training) 和 后训练阶段(Post-Training)。

Llama 3 的网络架构,如下:

Llama 3

1.1 预训练阶段(Pre-Training)

预训练阶段(Pre-Training),包括:

  1. 初始预训练(Initial Pre-Training):使用 AdamW 优化器对 Llama 3 405B 进行预训练,峰值学习率为 8 × 1 0 − 5 8 × 10^{−5} 8×105;线性预热步数为 8,000 步,采用余弦学习率调度,在 1,200,000 步内衰减至 8 × 1 0 − 7 8 × 10^{−7} 8×107 ;为了提高训练稳定性,在训练初期使用较小的 BatchSize,在后续逐步增加以提高效率。具体:
    1. 初始使用 Batch Size 是 4M Tokens 和 序列(Sequences)长度是 4,096 (4K) Tokens,预训练 252M 个 Token。
    2. 将 Batch Size 和序列长度,增加至 8M Tokens (batch size) 和 8,192 (8K) Tokens (sequences),预训练 2.87T 个 Token。
    3. 再次将批量大小增加到 16M。
    4. 降低到损失值的突刺(Spikes),不需要进行干预以纠正模型训练的发散(Divergence)。
  2. 长上下文预训练(Long Context Pre-Training):使用 800B 个训练 Token,上下文长度增加到 6 个阶段,从 8K 的上下文窗口开始,达到 128K。在长上下文预训练中,自注意力层的计算量,随着序列长度的平方增长。评估模型适应长上下文的标准:
    1. 模型在 短上下文 评估中的性能是否完全恢复。
    2. 在特定长度中,模型是否能够完美解决 大海捞针(needle in a haystack) 任务。
  3. 退火阶段(Annealing):最后 400M 个 token 预训练,线性的(Linearly) 将学习率退火到 0,同时,保持上下文长度 128K Token。使用少量高质量的代码和数学数据,进行退火,提高预训练模型在关键基准测试上的性能。退火阶段,在 8B 模型中效果明显,但是,在 405B 模型中改进较小。

损失函数都是 交叉熵损失,其中 w i w_{i} wi 是第 i i i 个词, N N N 是序列长度, p p p 是概率(softmax), C C C 是 类别数(也就是输出维度),即:

L o s s = − 1 N ∑ i = 1 N l o g   p ( w i ∣ w 1 , w 2 , . . . , w i − 1 ) L o s s = − 1 N ∑ i = 1 N ∑ j = 1 C y i j l o g ( p i j ) \begin{align} Loss &= -\frac{1}{N}\sum_{i=1}^{N}log \ p(w_{i}|w_{1},w_{2},...,w_{i-1}) \\ Loss &= -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C}y_{ij}log(p_{ij}) \end{align} LossLoss=N1i=1Nlog p(wiw1,w2,...,wi1)=N1i=1Nj=1Cyijlog(pij)

这 2 个公式的含义是一样的,第 1 个是选择 w i w_{i} wi ,第 2 个是通过 y i j = 1 y_{ij}=1 yij=1 确定 w i w_{i} wi

Post-Training

1.2 后训练阶段(Post-Training)

后训练阶段(Post-Training),包括:

  1. 奖励模型(Reward Modeling, RM):偏好数据(Preference data) 包含 3 个经过排序的响应,即 编辑后(edited) > 所选(chosen) > 被拒(rejected)。

Reward Model 训练 Loss 函数:
L o s s = − l o g   σ ( r ϕ ( x , y w i n ) − r ϕ ( x , y l o s s ) ) Loss = -log \ \sigma(r_{\phi}(x,y_{win}) - r_{\phi}(x,y_{loss})) Loss=log σ(rϕ(x,ywin)rϕ(x,yloss))
源码:

loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
  1. 监督微调 (Supervised Finetuning, SFT):使用奖励模型对于人工标注提示进行拒绝采样(Rejection Sampling)。 SFT 的 Loss 函数与 PreTraining 阶段一致,数据略有不同。

  2. 直接偏好优化 (Direct Preference Optimization, DPO),其中,在 Llama3 中,学习率 L R = 1 0 − 5 LR=10^{-5} LR=105,超参数 β = 0.1 \beta=0.1 β=0.1 α = 0.2 \alpha=0.2 α=0.2,即:

L o s s D P O = L o s s D P O + α L o s s N L L = − l o g   σ ( β l o g π θ ( y w i n ∣ x ) π r e f ( y w i n ∣ x ) − β l o g π θ ( y l o s e ∣ x ) π r e f ( y l o s e ∣ x ) ) − α l o g   π θ ( y w i n ∣ x ) \begin{align} Loss_{DPO} &= Loss_{DPO} + \alpha Loss_{NLL} \\ &= - log \ \sigma(\beta log\frac{\pi_{\theta}(y_{win}|x)}{\pi_{ref}(y_{win}|x)} - \beta log\frac{\pi_{\theta}(y_{lose}|x)}{\pi_{ref}(y_{lose}|x)}) - \alpha log \ \pi_{\theta}(y_{win}|x) \end{align} LossDPO=LossDPO+αLossNLL=log σ(βlogπref(ywinx)πθ(ywinx)βlogπref(ylosex)πθ(ylosex))αlog πθ(ywinx)

DPO 训练的改进如下:

  1. 在 DPO 损失中,在 所选(chosen) 和 被拒(rejected) 的响应中,屏蔽 特殊格式令牌(Special Formatting Tokens)。学习这些标记,导致模型出现不期望的行为,例如尾部重复或突然生成终止标记。可能是由于 DPO 损失的对比性质,在选定和拒绝的响应中都存在常见标记,这导致了冲突的学习目标,因为模型需要同时增加和减少这些标记的出现概率。
  2. DPO 损失 增加 NLL (Negative Log-Likelihood, 负对数似然) 损失,避免 所选(chosen) 响应的对数概率下降。参考论文 Iterative Reasoning Preference Optimization 。
  1. 模型平均(Model Averaging):训练多个模型,在不同的数据集、不同的初始化、不同的学习率或其他超参数设置下进行训练,对于模型权重进行平均,可以使用简单的算术平均,也可以使用加权平均。
  2. 迭代轮次(Iterative Rounds):划分 6 个轮次应用上述方法(RM、SFT、DPO、MA)。在每个轮次中,收集新的偏好标注和 SFT数据,以及从最新模型中采样合成数据。

Llama 3 的网络参数:

Llama 3


2. 补充内容

前置概念:

  1. PreTraining 阶段 与 SFT 阶段 的差异
  2. logits 含义
  3. logps 含义

1.1 PreTraining 阶段 与 SFT 阶段 的差异

PreTraining 与 SFT 在训练过程中,没有任何区别,主要区别在于数据的组成形式上,包括 6 点,即:

  1. PreTraining 样本数据都是满编 4K / 8K;SFT 样本数据保持不变,原始多长就是多长。
  2. SFT 引入 PreTraining 阶段中未见过的 special_token,让模型学习全新的语义。
  3. SFT 让模型 重新学习 最重要的 eos_token,停止生成;PreTraining 阶段 eos_token 只是作为样本的一部分,无法停止生成。
  4. SFT 借助 special_token,把语料切分成不同的角色,标配包括 systemuserassistant,根据业务需求也可以自定义。
  5. SFT 的 prompt 部分不做 loss 反传,原因是 prompt 的同质化严重,如果不做 loss_mask,同样的一句话会被翻来覆去的学。如果保证每条 prompt 都是独一无二的,也可以省略 promptloss_mask 部分。
  6. SFT 的 session 数据(多轮对话),可以每一个 answer 都计算 loss,也可以只对最后一轮的 answer 计算 loss

两者的训练目的也不一样,PreTraining 是在背书,纯粹的学习知识;SFT 则是在做题,学习的是指令 follow 能力。

整体的数据源码,参考 Llama-Factory 的 src/llamafactory/data/preprocess.py 文件,包括不同阶段的数据处理,即 ptsftrmkto 等。

其中,PreTraining 数据源码(src/llamafactory/data/processors/pretrain.py),

def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
    
    tokenized_examples = tokenizer(text_examples, add_special_tokens=False)  # add_special_tokens=False
    concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
    total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
    block_size = data_args.cutoff_len
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    return result

其中,SFT 的 prompt mask 源码(src/llamafactory/data/processors/supervised.py),参考:

IGNORE_INDEX = -100

if train_on_prompt:
    source_label = source_ids
elif template.efficient_eos:
    # mask 掉 prompt(source) 部分,保留 answer(target) 部分
    source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
else:
    source_label = [IGNORE_INDEX] * source_len

if mask_history and turn_idx != 0:  # train on the last turn only
    # 只训练最后 1 轮
    target_label = [IGNORE_INDEX] * target_len
else:
    target_label = target_ids

PreTraining 阶段的数据

  • 数据组成形式:
    • 输入 input: <bos> X1 X2 X3
    • 标签 labels:X1 X2 X3 </s>
  • 典型的 Decoder 架构的数据训练方式

SFT 阶段的数据

  • 数据组成形式:
    • 输入 input:<bos> prompt response
    • 标签 labels: -100 ... -100 response </s>
  • labels 的重点在于prompt部分的被 -100 所填充

训练过程源码,参考 transformers/src/transformers/models/llama/modeling_llama.py,即:

logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)

其中,shift_logitsshift_labels 数据:

输入文本<bos><eos>
shift_logits<bos>
shift_labels<eos>

1.2 logits 含义

模型输出的数值部分,被称为 logits,其英文含义是 log-odds,即对数几率,即:logits = self.lm_head(hidden_states)

  • logits 是模型在 应用激活函数(如 softmax)之前 的原始输出。假设有一个分类问题,模型的输出层有多个神经元,每个神经元对应一个类别。这些神经元的输出值就是 logits。例如,对于一个三分类问题,模型输出层的三个神经元输出值可能是[2.1, -1.5, 0.7],这就是 logits。
  • logits 可以看作是每个类别的 对数几率,表示某个事件发生的几率的对数。在分类问题中,假设某个类别发生的概率是 p,那么该类别的几率是 p/(1 - p)。对数几率就是 log(p/(1 - p))。模型输出的 logits 通过一些变换(如 softmax 函数)来近似地表示这种对数几率。在 softmax 函数中,使用 e 为底数,将 对数几率(logits) 转换为概率分布,使得输出值在 0 到 1 之间,并且所有类别的概率之和为 1。

Softmax 公式,其中 x 就是 logits ,即:
s o f t m a x ( X ) = e x ∑ i = 1 n e x i softmax(X) = \frac{e^x}{\sum_{i=1}^{n}e^{x_{i}}} softmax(X)=i=1nexiex
Sigmoid 公式,其中 x 就是 logits ,即:
σ ( x ) = 1 1 + e − x \sigma(x)=\frac{1}{1+e^{-x}} σ(x)=1+ex1
例如源码:

logits = self.lm_head(hidden_states)
logits = logits.float()

1.3 logps 含义

logpslog probabilities 的缩写,即对数概率。具体来说,logps 是模型输出的对数概率分布,通过对 logits 应用 softmax 函数,并且取 对数(log) 得到的。

源码测试:

  1. 初始化数据 logitslabels
import torch
from torch import nn
import torch.nn.functional as F
torch.manual_seed(42)

logits = torch.randn(3, 4)  # 模拟 logits,3个样本,维度是4
# 真实标签,即4个维度中,正确的是0位置、1位置、1位置,也就是说让 logits 的这些位置的对数几率最大。
labels = torch.tensor([0, 1, 1])  
print(f"[Info] logits: {logits}")
print(f"[Info] labels: {labels}")

[Info] logits: tensor([[ 0.3367,  0.1288,  0.2345,  0.2303],
        [-1.1229, -0.1863,  2.2082, -0.6380],
        [ 0.4617,  0.2674,  0.5349,  0.8094]])
[Info] labels: tensor([0, 1, 1])
  1. 计算 logps 使用 log_softmax ,即 softmax -> log,先概率化,再转换成对数(负值)
logps = torch.log(nn.Softmax(dim=1)(logits))
print(f"[Info] logps1: {logps}")
logps = F.log_softmax(logits, dim=1)
print(f"[Info] logps2: {logps}")

[Info] logps1: tensor([[-1.2849, -1.4928, -1.3871, -1.3912],
        [-3.5008, -2.5643, -0.1698, -3.0160],
        [-1.4621, -1.6565, -1.3889, -1.1144]])
[Info] logps2: tensor([[-1.2849, -1.4928, -1.3871, -1.3912],
        [-3.5008, -2.5643, -0.1698, -3.0160],
        [-1.4621, -1.6565, -1.3889, -1.1144]])
  1. 再调用 负对数似然(Negative Log-Likelihood,NLL) Loss,把多个样本的 logps 均值,再计算负值,因为负负得正,即 loss 值(正值),优化越来越小
v_nll = - (-1.2849 + -2.5643 + -1.6565) / 3  # 0,1,1 位置
print(f"[Info] NLL1: {v_nll}")
v_nll = nn.NLLLoss()(torch.log(nn.Softmax(dim=1)(logits)), labels)
print(f"[Info] NLL2: {v_nll}")

[Info] NLL1: 1.8352333333333333
[Info] NLL2: 1.8352112770080566
  1. 直接使用 CrossEntropyLoss 函数,代替这些操作,即 CrossEntropyLoss = Softmax+log+NLLLoss
v_nll = nn.CrossEntropyLoss()(logits, labels)
print(f"[Info] NLL3: {v_ce}")

[Info] NLL3: 1.835211157798767

参考 PPL (Perplexities,困惑度) 源码 (scripts/stat_utils/cal_ppl.py):

outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()
perplexities.extend(sentence_logps.exp().tolist())

参考:

  • 知乎 - 详解 PyTorch 的损失函数:NLLLoss()和CrossEntropyLoss()
  • 知乎 - 大模型Pretrain和SFT阶段的Loss分析
  • 大模型训练(SFT)实践总结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/951561.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【python基础——异常BUG】

什么是异常(BUG) 检测到错误,py编译器无法继续执行,反而出现错误提示 如果遇到错误能继续执行,那么就捕获(try) 1.得到异常:try的执行,try内只可以捕获一个异常 2.预案执行:except后面的语句 3.传入异常:except … as uestcprint(uestc) 4.没有异常:else… 5.鉴定完毕,收尾的语…

(长期更新)《零基础入门 ArcGIS(ArcMap) 》实验六----流域综合处理(超超超详细!!!)

流域综合处理 流域综合治理是根据流域自然和社会经济状况及区域国民经济发展的要求,以流域水流失治理为中心,以提高生态经济效益和社会经济持续发展为目标,以基本农田优化结构和高效利用及植被建设为重点,建立具有水土保持兼高效生态经济功能的半山区流域综合治理模式。数字高程…

设计模式与游戏完美开发(3)

更多内容可以浏览本人博客&#xff1a;https://azureblog.cn/ &#x1f60a; 该文章主体内容来自《设计模式与游戏完美开发》—蔡升达 第二篇 基础系统 第五章 获取游戏服务的唯一对象——单例模式&#xff08;Singleton&#xff09; 游戏实现中的唯一对象 在游戏开发过程中…

VSCode 在Windows下开发时使用Cmake Tools时输出Log乱码以及CPP文件乱码的终极解决方案

在Windows11上使用VSCode开发C程序的时候&#xff0c;由于使用到了Cmake Tools插件&#xff0c;在编译运行的时候&#xff0c;会出现输出日志乱码的情况&#xff0c;那么如何解决呢&#xff1f; 这里提供了解决方案&#xff1a; 当Settings里的Cmake: Output Log Encoding里设…

Solidity入门: 函数

函数 Solidity语言的函数非常灵活&#xff0c;可以进行各种复杂操作。在本教程中&#xff0c;我们将会概述函数的基础概念&#xff0c;并通过一些示例演示如何使用函数。 我们先看一下 Solidity 中函数的形式: function <function name>(<parameter types>) {in…

基于 Python 自动化接口测试(踩坑与实践)

文档&#xff1a;基于 Python 的自动化接口测试 目录 背景问题描述与解决思路核心代码修改点及其详细解释最终测试结果后续优化建议 1. 问题背景 本项目旨在使用 Python 模拟浏览器的请求行为&#xff0c;测试文章分页接口的可用性。测试目标接口如下&#xff1a; bashcoder…

Spring Boot教程之五十一:Spring Boot – CrudRepository 示例

Spring Boot – CrudRepository 示例 Spring Boot 建立在 Spring 之上&#xff0c;包含 Spring 的所有功能。由于其快速的生产就绪环境&#xff0c;使开发人员能够直接专注于逻辑&#xff0c;而不必费力配置和设置&#xff0c;因此如今它正成为开发人员的最爱。Spring Boot 是…

web-app uniapp监测屏幕大小的变化对数组一行展示数据作相应处理

web-app uniapp监测屏幕大小的变化对数组一行展示数据作相应处理 1.uni.getSystemInfoSync().screenWidth; 获取屏幕宽度 2.uni.onWindowResize&#xff08;&#xff09; 实时监测屏幕宽度变化 3.根据宽度的大小拿到每行要展示的数量itemsPerRow 4.为了确保样式能够根据 items…

使用强化学习训练神经网络玩俄罗斯方块

一、说明 在 2024 年暑假假期期间&#xff0c;Tim学习并应用了Q-Learning &#xff08;一种强化学习形式&#xff09;来训练神经网络玩简化版的俄罗斯方块游戏。在本文中&#xff0c;我将详细介绍我是如何做到这一点的。我希望这对任何有兴趣将强化学习应用于新领域的人有所帮助…

计算机网络 (32)用户数据报协议UDP

前言 用户数据报协议&#xff08;UDP&#xff0c;User Datagram Protocol&#xff09;是计算机网络中的一种重要传输层协议&#xff0c;它提供了无连接的、不可靠的、面向报文的通信服务。 一、基本概念 UDP协议位于传输层&#xff0c;介于应用层和网络层之间。它不像TCP那样提…

如何将 DotNetFramework 项目打包成 NuGet 包并发布

如何将 DotNetFramework 项目打包成 NuGet 包并发布 在软件开发过程中&#xff0c;将项目打包成 NuGet 包并发布到 NuGet 库&#xff0c;可以让其他开发者方便地引用和使用你的项目成果。以下是将 WixWPFWizardBA 项目打包成 NuGet 包并发布的详细步骤&#xff1a; 1. 创建 .n…

解决GitHub上的README.md文件的图片内容不能正常显示问题

一、问题描述 我们将项目推送到GitHub上后&#xff0c;原本在本地编写配置好可展现的相对路径图片内容&#xff0c;到了GitHub上却不能够正常显示图片内容&#xff0c;我们希望能够在GitHub上正常显示图片&#xff0c;如下图所示&#xff1a; 二、问题分析 现状&#xff1a;REA…

如何解决 VS Code 调试时无法查看 std 中变量的问题

在使用 VS Code 调试 C 程序时&#xff0c;我们经常遇到查看 std 容器或字符串变量时只显示一串数字而看不到实际值的情况。这是由于调试器未启用 pretty-printing 功能导致的。为了解决这个问题&#xff0c;可以在 launch.json 中进行配置。 问题描述 在调试 C 程序时&…

安装MySQL的五种方法(Linux系统和Windows系统)

一.在Linux系统中安装MySQL 第一种方法:在线YUM仓库 首先打开MySQL官网首页 www.mysql.com 找到【DOWNLOADS】选项&#xff0c;点击 下拉&#xff0c;找到 【MySQL Community(GPL) Downloads】 在社区版下载页面中&#xff0c;【 MySQL Yum Repository 】链接为在线仓库安装…

基于mybatis-plus历史背景下的多租户平台改造

前言 别误会&#xff0c;本篇【并不是】 要用mybatis-plus自身的多租户方案&#xff1a;在表中加一个tenant_id字段来区分不同的租户数据。并不是的&#xff01; 而是在假设业务系统已经使用mybatis-plus多数据源的前提下&#xff0c;如何实现业务数据库隔开的多租户系统。 这…

RabbitMQ高级篇之MQ可靠性 数据持久化

文章目录 消息丢失的原因分析内存存储的缺陷如何确保 RabbitMQ 的消息可靠性&#xff1f;数据持久化的三个方面持久化对性能的影响持久化实验验证性能对比Spring AMQP 默认持久化总结 消息丢失的原因分析 RabbitMQ 默认使用内存存储消息&#xff0c;但这种方式带来了两个主要问…

Openssl1.1.1s rpm包构建与升级

rpmbuild入门知识 openssh/ssl二进制升级 文章目录 前言一、资源准备1.下载openssh、openssl二进制包2.安装rpmbuild工具3.拷贝源码包到SOURCES目录下4.系统开启telnet&#xff0c;防止意外导致shh无法连接5.编译工具安装6.补充说明 二、制作 OpenSSL RPM 包1.编写 SPEC 文件2.…

【Unity3D】apk加密(global-metadata.dat加密)

涉及&#xff1a;apk、aab、global-metadata.dat、jks密钥文件、APKTool、zipalign 使用7z打开apk文件观察发现有如下3个针对加密的文件。 xxx.apk\assets\bin\Data\Managed\Metadata\global-metadata.dat xxx.apk\lib\armeabi-v7a\libil2cpp.so xxx.apk\lib\arm64-v8a\libil…

[免费]微信小程序(高校就业)招聘系统(Springboot后端+Vue管理端)【论文+源码+SQL脚本】

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序(高校就业)招聘系统(Springboot后端Vue管理端)&#xff0c;分享下哈。 项目视频演示 【免费】微信小程序(高校就业)招聘系统(Springboot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项目介绍…

RNN心脏病预测-Pytorch版本

本文为为&#x1f517;365天深度学习训练营内部文章 原作者&#xff1a;K同学啊 一 导入数据 import numpy as np import pandas as pd import torch from torch import nn import torch.nn.functional as F import seaborn as sns from sklearn.preprocessing import Standard…