【PYG】Cora数据集分类任务计算损失,cross_entropy为什么不能直接替换成mse_loss

  • cross_entropy计算误差方式,输入向量z为[1,2,3],预测y为[1],选择数为2,计算出一大坨e的式子为3.405,再用-2+3.405计算得到1.405
  • MSE计算误差方式,输入z为[1,2,3],预测向量应该是[1,0,0],和输入向量维度相同

在这里插入图片描述
将cross_entropy直接替换成mse_loss报错RuntimeError: The size of tensor a (7) must match the size of tensor b (140) at non-singleton dimension 1

cross_entropy 换成 mse_loss 会报错的原因是,这两个损失函数的输入和输出形状要求不同。cross_entropy 是一个分类损失函数,它期望输入是未归一化的logits(形状为 [batch_size, num_classes]),而标签是整数类别(形状为 [batch_size])。mse_loss 是一个回归损失函数,它期望输入和标签的形状相同。

如果你想使用 mse_loss 来替代 cross_entropy,你需要对标签进行one-hot编码,使它们与模型的输出形状匹配。下面是如何修改代码以使用 mse_loss 的示例:

修改代码以使用 mse_loss

  1. 加载必要的库
    你需要一个工具来将标签转换为one-hot编码。这里我们使用 torch.nn.functional.one_hot

  2. 修改训练函数
    在训练函数中,将标签转换为one-hot编码,然后计算 mse_loss

核心测试代码讲解

out=model(data)模型输出形状为torch.Size([140, 7])
data.y中测试数据输出形状为torch.Size([140]),打印第一个数据为3,7个类别中的第4个类别
将3转化为7位置独热码计算MSE,对应train_labels_one_hot第一个数据[0., 0., 0., 1., 0., 0., 0.]为4
out形状为torch.Size([140, 7]),train_labels_one_hot的形状为[140, 7]

torch.Size([140, 7]) torch.Size([140])
tensor([-0.0166,  0.0191, -0.0036, -0.0053, -0.0160,  0.0071, -0.0042],
       device='cuda:0', grad_fn=<SelectBackward0>) tensor(3, device='cuda:0')
tensor([[0., 0., 0., 1., 0., 0., 0.],
		...
		[0., 1., 0., 0., 0., 0., 0.]], device='cuda:0')
train_labels_one_hot shape torch.Size([140, 7])
test out torch.Size([2708, 7])
train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()
print(out[data.train_mask].shape, data.y[data.train_mask].shape)
print(out[data.train_mask][0], data.y[data.train_mask][0])
print(train_labels_one_hot)
print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")
loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)

解释

  1. 加载库:我们使用 torch.nn.functional.one_hot 将标签转换为one-hot编码。
  2. 修改训练函数
    • 将标签 train_labels 转换为one-hot编码,train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()
    • 使用 mse_loss 计算均方误差损失 loss = F.mse_loss(train_out, train_labels_one_hot)
  3. 保持评估函数不变:评估函数仍然使用 argmax 提取预测类别,并计算准确性。

魔改完整代码

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora',  transform=NormalizeFeatures())
data = dataset[0]

# 定义GCN模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x
        # return F.log_softmax(x, dim=1)

# 初始化模型和优化器
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')

# 打印归一化后的特征
print(data.x[0])

print(f"data.train_mask{data.train_mask}")

# 训练模型
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    # print(f"out[data.train_mask] {data.train_mask.shape} {out[data.train_mask].shape} {out[data.train_mask]}")
    # loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()
    print(out[data.train_mask].shape, data.y[data.train_mask].shape)
    print(out[data.train_mask][0], data.y[data.train_mask][0])
    print(train_labels_one_hot)
    print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")
    loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)
    loss.backward()
    optimizer.step()
    return loss.item()

# 评估模型
def test():
    model.eval()
    out = model(data)
    print(f"test out {out.shape}")
    print(f"test out[0] {out[0].shape} {out[0]}")
    print(f"test out[0:1,:] {out[0:1,:].shape} {out[0:1,:]}")
    print(f"test out[0:1,:].argmax(dim=1) {out[0:1,:].argmax(dim=1)}")
    pred = out.argmax(dim=1)
    print(f"test pred {pred[data.test_mask].shape} {pred[data.test_mask]}")
    print(f"data {data.y[data.test_mask].shape} {data.y[data.test_mask]}")
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    return acc

for epoch in range(1):
    loss = train()
    acc = test()
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

原始代码

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 加载Cora数据集,并应用NormalizeFeatures变换
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

# 计算训练、验证和测试集的大小
num_train = data.train_mask.sum().item()
num_val = data.val_mask.sum().item()
num_test = data.test_mask.sum().item()

print(f'Number of training nodes: {num_train}')
print(f'Number of validation nodes: {num_val}')
print(f'Number of test nodes: {num_test}')

# 定义GCN模型
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x  # 返回未归一化的logits

# 初始化模型和优化器
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')

# 训练模型
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)  # out 的形状是 [num_nodes, num_classes]
    train_out = out[data.train_mask]  # 选择训练集节点的输出
    train_labels = data.y[data.train_mask]  # 选择训练集节点的标签

    # 将标签转换为one-hot编码
    train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()

    # 计算均方误差损失
    loss = F.mse_loss(train_out, train_labels_one_hot)
    loss.backward()
    optimizer.step()
    return loss.item()

# 评估模型
def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)  # 提取预测类别
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    return acc

for epoch in range(200):
    loss = train()
    acc = test()
    print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

通过这些修改,你可以将交叉熵损失函数替换为均方误差损失函数,并确保输入和标签的形状匹配,从而避免报错。

  • 简单版本的的答案

Cross Entropy vs. MSE Loss

  1. Cross Entropy Loss:

    • 输入:模型的logits,形状为 ([N, C]),其中 (N) 是批次大小,(C) 是类别数量。
    • 目标:目标类别的索引,形状为 ([N])。
  2. MSE Loss:

    • 输入:模型的预测值,形状为 ([N, C])。
    • 目标:实际值,形状为 ([N, C])(通常是 one-hot 编码)。

要将 cross_entropy 换成 mse_loss,需要确保输入和目标的形状匹配。具体来说,你需要将目标类别索引转换为 one-hot 编码。

示例代码

假设你有一个分类任务,其中模型输出的是 logits,目标是类别索引。我们将这个设置转换为使用 MSE Loss。

import torch
import torch.nn.functional as F

# 假设有一个批次的模型输出和目标标签
logits = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], requires_grad=True)  # 模型输出
target = torch.tensor([0, 2])  # 目标类别

# 使用 cross_entropy
cross_entropy_loss = F.cross_entropy(logits, target)
print("Cross-Entropy Loss:")
print(cross_entropy_loss)

# 转换目标类别为 one-hot 编码
target_one_hot = F.one_hot(target, num_classes=logits.size(1)).float()
print("One-Hot Encoded Targets:")
print(target_one_hot)

# 计算 MSE Loss
mse_loss = F.mse_loss(F.softmax(logits, dim=1), target_one_hot)
print("MSE Loss:")
print(mse_loss)

输出

Cross-Entropy Loss:
tensor(1.4076, grad_fn=<NllLossBackward>)
One-Hot Encoded Targets:
tensor([[1., 0., 0.],
        [0., 0., 1.]])
MSE Loss:
tensor(0.2181, grad_fn=<MseLossBackward>)

解释

  1. logits: 模型的原始输出,形状为 ([N, C])。
  2. target: 原始目标类别索引,形状为 ([N])。
  3. target_one_hot: 将目标类别索引转换为 one-hot 编码,形状为 ([N, C])。
  4. F.mse_loss: 使用 F.softmax(logits, dim=1) 计算模型的概率分布,然后与 target_one_hot 计算 MSE 损失。

通过将目标类别转换为 one-hot 编码并确保输入和目标的形状匹配,可以成功地将 cross_entropy 换成 mse_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/765367.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python + 在线 + 中文文本转为英文语音,语音转为中文文本

模型 平台&#xff1a;https://huggingface.co/ars-语言转文本: pipeline("automatic-speech-recognition", model"openai/whisper-large-v3", device0 )tts-文本转语音&#xff1a;pipeline("text-to-speech", "microsoft/speecht5_tts&q…

Java8新特性之Optional、Lambda表达式、Stream、新日期Api使用总结

标题 Optional类常用Api使用说明Optional API 使用建议 Lambda表达式Lambda 表达式的局限性与解决方案 Stream案例实操优化 Stream 操作 新的日期和时间APILocalDateTime在SpringBoot中的应用 函数式接口&#xff08;Functional&#xff09; Optional博客参考地址1 Stream博客参…

【Linux从入门到放弃】探究进程如何退出以进程等待的前因后果

&#x1f9d1;‍&#x1f4bb;作者&#xff1a; 情话0.0 &#x1f4dd;专栏&#xff1a;《Linux从入门到放弃》 &#x1f466;个人简介&#xff1a;一名双非编程菜鸟&#xff0c;在这里分享自己的编程学习笔记&#xff0c;欢迎大家的指正与点赞&#xff0c;谢谢&#xff01; 进…

【C++】STL-priority_queue

目录 1、priority_queue的使用 2、实现没有仿函数的优先级队列 3、实现有仿函数的优先级队列 3.1 仿函数 3.2 真正的优先级队列 3.3 优先级队列放自定义类型 1、priority_queue的使用 priority_queue是优先级队列&#xff0c;是一个容器适配器&#xff0c;不满足先进先出…

pdf拆分,pdf拆分在线使用,pdf拆分多个pdf

在数字化的时代&#xff0c;pdf文件已经成为我们日常办公、学习不可或缺的文档格式。然而&#xff0c;有时候我们可能需要对一个大的pdf文件进行拆分&#xff0c;以方便管理和分享。那么&#xff0c;如何将一个pdf文件拆分成多个pdf呢&#xff1f;本文将为你推荐一种好用的拆分…

监控平台zabbix介绍与部署

目录 1.为什么要做监控 2.zabbix是什么&#xff1f; 3.zabbix 监控原理 4.Zabbix 6.0 新特性 5.Zabbix 6.0 功能组件 6.部署zabbix 6.1 部署 Nginx PHP 环境并测试 6.2 部署数据库 6.3 向数据库导入 zabbix 数据 6.4 编译安装 zabbix Server 服务端 6.5 修改 zabbix…

中小企业如何防止被查盗

在当前的商业环境中&#xff0c;小企业面临诸多挑战&#xff0c;其中之一便是如何在有限的预算内满足日常运营的技术需求。由于正版软件的高昂成本&#xff0c;一些小企业可能会选择使用盗版软件来降低成本。 我们联网之后存在很多风险&#xff0c;你可以打开自己的可以联网的电…

【面试系列】机器学习工程师高频面试题及详细解答

欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;欢迎订阅相关专栏&#xff1a; ⭐️ 全网最全IT互联网公司面试宝典&#xff1a;收集整理全网各大IT互联网公司技术、项目、HR面试真题. ⭐️ AIGC时代的创新与未来&#xff1a;详细讲解AIGC的概念、核心技术、…

Java--创建对象内存分析

1.如图所示&#xff0c;左边为一个主程序&#xff0c;拥有main方法&#xff0c;右边定义了一个Pet类&#xff0c;通过debug不难看出&#xff0c;当启动该方法时&#xff0c;有以下该步骤 1.运行左边的实例化Pet类对象 2.跳转至右边定义Pet类的语句 3.跳回至左边获取Pet类基本属…

代码生成器使用指南,JeecgBoot低代码平台

JeecgBoot 提供强大的代码生成器&#xff0c;让前后端代码一键生成&#xff0c;实现低代码开发。支持单表、树列表、一对多、一对一等数据模型&#xff0c;增删改查功能一键生成&#xff0c;菜单配置直接使用。 同时提供强大模板机制&#xff0c;支持自定义模板&#xff0c;目…

【bug报错已解决】ERROR: Could not find a version that satisfies the requirement

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 文章目录 引言一、问题描述1.1 报错示例1.2 报错分析 二、解决方法2.1 方法一2.2 方法二 三、总结 引言 有没有遇到过那种让人…

JSON JOLT常用示例整理

JSON JOLT常用示例整理 1、什么是jolt Jolt是用Java编写的JSON到JSON转换库&#xff0c;其中指示如何转换的"specification"本身就是一个JSON文档。以下文档中&#xff0c;我统一以 Spec 代替如何转换的"specification"json文档。以LHS(left hand side)代…

kaggle量化赛金牌方案(第七名解决方案)

获奖文章(第七名解决方案) 致谢 我要感谢 Optiver 和 Kaggle 组织了这次比赛。这个挑战提出了一个在金融市场时间序列预测领域中具有重大和复杂性的问题。 方法论 我的方法结合了 LightGBM 和神经网络模型,对神经网络进行了最少的特征工程。目标是结合这些模型以降低最终…

arco disign vue 日期组件的样式穿透

问题描述: 对日期组件进行样式穿透. 原因分析: 如图,日期组件被展开时它默认将dom元素挂载到body下, 我们的页面在idroot的div 里层, 里层想要穿透外层是万万行不通的. 解决问题: 其实官网提供了参数,但是并没有提供例子, 只能自己摸索着过河. 对于日期组件穿透样式,我们能…

收集了很久的全网好用的磁力搜索站列表分享

之前找资源的时候&#xff0c;收集了一波国内外大部分主流的磁力链接搜索站点。每一个站可能都有对应的优缺点&#xff0c;多试试&#xff0c;就能知道自己要哪个了。 全网好用的磁力链接 大部分的时候&#xff0c;我们用国内的就可以了&#xff0c;速度块&#xff0c;而且不…

Snappy使用

Snappy使用 Snappy是谷歌开源的压缩和解压的开发包&#xff0c;目标在于实现高速的压缩而不是最大的压缩 项目地址&#xff1a;GitHub - google/snappy&#xff1a;快速压缩器/解压缩器 Cmake版本升级 该项目需要比较新的cmake&#xff0c;CMake 3.16.3 or higher is requi…

51单片机第23步_定时器1工作在模式0(13位定时器)

重点学习51单片机定时器1工作在模式0的应用。 在51单片机中&#xff0c;定时器1工作在模式0&#xff0c;它和定时器0一样&#xff0c;TL1占低5位&#xff0c;TH1占高8位&#xff0c;合计13位&#xff0c;也是向上计数。 1、定时器1工作在模式0 1)、定时器1工作在模式0的框图…

8619 公约公倍

这个问题可以通过计算最大公约数 (GCD) 和最小公倍数 (LCM) 来解决。我们需要找到一个整数&#xff0c;它是 a, b, c 的 GCD 的倍数&#xff0c;同时也是 d, e, f 的 LCM 的约数。 以下是解决这个问题的步骤&#xff1a; 1. 计算 a, b, c 的最大公约数。 2. 计算 d, e, f 的最…

流处理系统对比:RisingWave vs ksqlDB

本文将从架构、部署与可扩展性、Source 和 Sink、生态系统与开发者工具几个方面比较 ksqlDB 和 RisingWave 这两款领先的流处理系统。 1. 架构 ksqlDB 是由 Confluent 开发和维护的流处理 SQL 引擎&#xff0c;专为 Apache Kafka 设计。它基于 Kafka Streams 构建&#xff0c;…

鸿蒙:路由Router原理

页面路由&#xff1a;在应用程序中实现不同页面之间的跳转和数据传递 典型应用&#xff1a;商品信息返回、订单等多页面跳转 页面栈最大容量为32个页面&#xff0c;当页面需要销毁可以使用router.clear()方法清空页面栈 router有两种页面跳转模式&#xff1a; router.pushUrl…