GNN 训练点击-购买 预测模型

搜集到click-buy数据集,

数据集分享在网盘

通过百度网盘分享的文件:数据集_20241031_220915
链接:https://pan.baidu.com/s/1qcXAO_P1h3Vrrui5qFbYLw?pwd=6f3m 

其中 yoochoose-buys.dat

特征含义buy_df.columns = ['session_id', 'timestamp', 'item_id', 'price', 'quantity']

420374,2014-04-06T18:44:58.314Z,214537888,12462,1

yoochoose-clicks.dat

特征含义click_df.columns = ['session_id', 'timestamp', 'item_id', 'category']

1,2014-04-07T10:51:09.277Z,214536502,0

我们需要构建图结构通过训练该图神经网络来预测click-buy的关系

构建训练用的Dataset, 在下面代码模板上重构process 方法,用于处理原始数据并将其转化为图数据。

from torch_geometric.data import InMemoryDataset
import torch

class ChooseBinaryDataset(InMemoryDataset):
    def __init__(self, root, clicks_file, buys_file, transform=None, pre_transform=None):
        self.clicks_file = clicks_file
        self.buys_file = buys_file
        super(ChooseBinaryDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = self.process()

    @property
    def processed_file_names(self):
        return ['data.pt']  # 指定处理后保存的数据文件名
    
    def download(self):
        pass

    def process(self):
        # Implement your data processing logic here
        pass

代码的按照流程完成下面任务

  • 数据分组
    使用 groupby 方法按 session_id 将原始数据分成多个会话组,便于处理每个用户的行为序列。

  • 标签编码
    使用 LabelEncoder 对商品 ID 进行编码,将原始字符串转换为整数形式,便于后续处理。

  • 节点特征
    从每个会话组中提取商品 ID(item_id)并创建节点特征,确保每个商品 ID 唯一。

  • 边缘索引
    创建源节点和目标节点,用于构建图中的边缘。源节点是会话中的前一个商品,目标节点是后续商品。

  • 图数据对象
    使用 PyTorch Geometric 的 Data 类构建图数据对象,并将其添加到 data_list 中。

  • 数据保存
    使用 collate 方法将图数据对象合并,并保存到指定的路径。

    def process(self):
        # 加载数据
        clicks_df = pd.read_csv(self.clicks_file, sep=';', names=['session_id', 'timestamp', 'item_id', 'category'])
        buys_df = pd.read_csv(self.buys_file, sep=';', names=['session_id', 'timestamp', 'item_id', 'price', 'quantity'])

        # 创建购买标签
        buys_df['buy'] = 1  # 标记购买的商品
        clicks_df = clicks_df.merge(buys_df[['session_id', 'item_id', 'buy']], on=['session_id', 'item_id'], how='left')
        clicks_df['buy'] = clicks_df['buy'].fillna(0)  # 填充未购买的商品为0

        # 存储图数据对象
        data_list = []
        grouped = clicks_df.groupby('session_id')

        for session_id, group in tqdm(grouped, desc="Processing sessions"):
            # 标签编码
            sess_item_id = LabelEncoder().fit_transform(group.item_id)
            group = group.reset_index(drop=True)
            group['sess_item_id'] = sess_item_id

            # 节点特征
            node_features = group['sess_item_id'].values
            node_features = torch.LongTensor(node_features).unsqueeze(1)

            # 创建边缘索引
            target_nodes = group.sess_item_id.values[1:]
            source_nodes = group.sess_item_id.values[:-1]
            edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)

            # 创建图数据对象
            x = node_features
            y = torch.FloatTensor([group.buy.values[0]])  # 使用第一个商品的购买标签

            data = Data(x=x, edge_index=edge_index, y=y)
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        return data, slices
  • clicks_df 包含会话 ID (session_id)、时间戳 (timestamp)、商品 ID (item_id) 和类别 (category)。
  • buys_df 包含会话 ID、时间戳、商品 ID、价格 (price) 和数量 (quantity)。
  • 在购买数据框中为每个购买的商品添加标签 buy,值为 1。对点击数据框进行合并,将购买标签添加到点击数据中。对未购买的商品,buy 标签填充为 0。
  • session_id 对点击数据进行分组,以便逐会话处理。遍历每个session的点击数据。将商品 ID 转换为标签编码,生成 sess_item_id,便于后续处理。

group = group.reset_index(drop=True)

  • 重置索引并将编码后的商品 ID 添加到数据框中。

target_nodes = group.sess_item_id.values[1:]
source_nodes = group.sess_item_id.values[:-1]
edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)

上面是对click 1 -> click 2 -> click 3 -> ...-> click N 这样的链表数据做处理,构建出source-> target的图结构的边关系

node_features = group['sess_item_id'].values
node_features = torch.LongTensor(node_features).unsqueeze(1)

session 里的所有点击item 作为该session的节点特征

下面就是样本和标签的定义

x = node_features
y = torch.FloatTensor([group.buy.values[0]])

节点特征作为样本,session里购买的商品取排序第一个商品作为标签然后转为torch张量

data = Data(x=x, edge_index=edge_index, y=y)
data_list.append(data)

图结构这样构建出来以后,把它都加入到data_list

data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])

  • 将所有图数据对象合并为一个批处理(使用 collate 函数)。

torch_geometric 中,collate 方法用于将多个图数据对象(Data)合并成一个批量(batch)

collate 方法会将每个图的数据(例如节点特征 x)合并成一个大的张量。它会在节点特征的维度上进行拼接,以便在一个批次中包含来自多个图的信息。边缘索引(edge_index)的合并也会进行调整。对于每个图,边缘索引会被重定向,以保证在新的批次中每个图的节点索引是唯一的,除了节点特征和边缘索引,collate 方法还会处理图数据对象中的其他属性,比如图的标签 y、图的边缘特征 edge_attr 等,确保它们在合并后仍能对应到正确的图。

  • 保存处理后的数据和切片信息,以便后续使用。

构建图神经网络模型,该模型参考额SAGE项目里的图模型, 需要有下面层结构

  • 卷积层 (SAGEConv):定义了三个图卷积层,用于提取节点特征。
  • 池化层 (TopKPooling):用于降低图的复杂度,减少节点数量。
  • 嵌入层 (Embedding):用于将商品 ID 映射到低维空间。
  • 线性层 (Linear):用于特征转换和输出。
  • 标准化层 (BatchNorm1d) 和激活函数 (ReLU):用于加速训练和增加模型的非线性能力。
class GraphNet(torch.nn.Module):
    """Graph Neural Network for binary classification tasks."""

    def __init__(self, emb_dim):
        super(GraphNet, self).__init__()
        # 定义图卷积层
        self.conv1 = SAGEConv(emb_dim, 128)
        self.pool1 = TopKPooling(128, ratio=0.8)

        self.conv2 = SAGEConv(128, 128)
        self.pool2 = TopKPooling(128, ratio=0.8)

        self.conv3 = SAGEConv(128, 128)
        self.pool3 = TopKPooling(128, ratio=0.8)

        # 定义嵌入层
        self.item_embedding = torch.nn.Embedding(num_embeddings=emb_dim + 10, embedding_dim=emb_dim)

        # 定义线性层
        self.lin1 = torch.nn.Linear(128, 128)
        self.lin2 = torch.nn.Linear(128, 64)
        self.lin3 = torch.nn.Linear(64, 1)

        # 定义标准化层和激活函数
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.act1 = torch.nn.ReLU()
        self.act2 = torch.nn.ReLU()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # 使用嵌入层将节点特征进行编码
        x = self.item_embedding(x).squeeze(1)  # n * 128

        # 第一个卷积层和池化
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = gap(x, batch)  # 全局池化

        # 第二个卷积层和池化
        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = gap(x, batch)  # 全局池化

        # 第三个卷积层和池化
        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = gap(x, batch)  # 全局池化

        # 将三个尺度的全局特征相加
        x = x1 + x2 + x3

        # 通过线性层进行特征转换
        x = self.act1(self.lin1(x))
        x = self.act2(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)

        # 最后一层使用Sigmoid激活函数进行二元分类
        x = torch.sigmoid(self.lin3(x)).squeeze(1)  # batch个结果
        return x
  • 特征提取:首先通过嵌入层将节点特征编码,然后经过多个图卷积层和池化层提取特征。
  • 池化操作:通过 gap 函数获取全局特征。
  • 特征融合:将不同尺度的全局特征相加,以获得综合特征。
  • 线性变换和激活:通过多个线性层和激活函数进行最终的特征处理。
  • 输出:最后通过 Sigmoid 激活函数将输出压缩到 [0, 1] 之间,适合二分类任务。

编写训练代码

model = GraphNet(emb_dim=100).to(device)  # 使用适当的嵌入维度
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()

# 创建数据加载器
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 训练循环
for epoch in range(100):  # 设置适当的 epoch 数
    model.train()
    total_loss = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}')

torch.save(model.state_dict(), 'graphnet_model.pth')
print("Model saved to graphnet_model.pth")

开始训练

图结构比较大,训练时间较长

看到图神经网络的loss可以训练的特别小

用相同构造图数据的方法构造下测试数据,加载训练的模型尝试进行预测

# 加载模型
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from model import GraphNet

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

loaded_model = GraphNet(emb_dim=100).to(device)
loaded_model.load_state_dict(torch.load('graphnet_model.pth'))
loaded_model.eval()  # 设置为评估模式

# 准备预测数据(假设你有新的点击数据)
# 这里你可以使用与训练时相同的处理方式来创建新的 Data 对象
# 例如,假设你有新的点击数据文件
new_clicks_df = pd.read_csv('test-yoochoose-clicks.dat', sep=';', names=['session_id', 'timestamp', 'item_id', 'category'])

# 创建新的图数据
new_data_list = []
for session_id, group in tqdm(new_clicks_df.groupby('session_id'), desc="Processing new sessions"):
    sess_item_id = LabelEncoder().fit_transform(group.item_id)
    group = group.reset_index(drop=True)
    group['sess_item_id'] = sess_item_id

    # 节点特征
    node_features = group['sess_item_id'].values
    node_features = torch.LongTensor(node_features).unsqueeze(1)

    # 创建边缘索引
    if len(group) > 1:  # 确保有足够的节点以构建边
        target_nodes = group.sess_item_id.values[1:]
        source_nodes = group.sess_item_id.values[:-1]
        edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)

        # 创建图数据对象
        x = node_features
        new_data = Data(x=x, edge_index=edge_index)

        new_data_list.append(new_data)

# 将新的图数据转换为 DataLoader
new_loader = DataLoader(new_data_list, batch_size=32, shuffle=False)

# 开始预测
predictions = []
with torch.no_grad():  # 禁用梯度计算以节省内存
    for data in new_loader:
        data = data.to(device)
        output = loaded_model(data)
        predictions.append(output.cpu().numpy())

# 将预测结果合并
predictions = np.concatenate(predictions)

# 打印预测结果
print("Predictions:", predictions)

代码提交在我的github上

https://github.com/chenrui2200/click_buy_graph_predict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/907567.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SpringMvc day1102

ok了家人们今天我们学习SpringMvc,之后学习SpringBoot,let‘s go 六.拦截器 6.1 拦截器概述 Spring MVC 的处理器拦截器类似于 Servlet 开发中的过滤器 Filter ,用于对处理器 ( 自己编写的 Controller) 进行预处理和后 处理。用户可以自…

项目管理(风险:范围、成本、时间、质量)

项目管理主要是围绕着范围、成本、时间、质量,每个部分都存在不同的风险。 存在潜在的风险方面,也有可能是法律风险、合作方带来的风险等。 减少风险: 提前与技术沟通方案。产品内部先讨论。每日例会:同步信息,减少…

6.0、静态路由

路由器最主要的功能就是转发数据包。路由器转发数据包时需要查找路由表(你可以理解为地图),管理员可以直接手动配置路由表,这就是静态路由。 1.什么是路由? 在网络世界中,路由是指数据包在网络中的传输路…

网络层3——IP数据报转发的过程

目录 一、基于终点的转发 1、理解 2、IP数据报转发过程 二、最长前缀匹配 1、理解 2、主机路由 3、默认路由 三、二叉线索查找 一、基于终点的转发 1、理解 理解什么叫终点转发 IP数据报的传递,交给路由器后 可不可以做到直接发送给目的主机呢?…

VMware虚拟机Debian扩展磁盘

一、 版本 VMware:Workstation 17 Pro虚拟机:Debian11 二、 VMware虚拟机扩展 虚拟机关机状态快照或者备份:以免扩容失败导致文件丢失虚拟机——设置——硬盘——磁盘使用工具——扩展——扩展磁盘容量——设置为想要的大小 三、 虚拟机…

新能源汽车的未来:车载电源与V2G技术的前景

近年来,新能源汽车在全球市场上发展迅速,尤其是在中国,新能源汽车的月销量已经超过了燃油车。随着新能源技术的不断发展,新能源汽车不仅仅是作为出行工具,而逐渐成为“移动能源站”。本文将探讨电动汽车的车载外放电功…

JavaScript知识点梳理及案例实践

1. Date对象 创建Date对象 //方法1:不指定参数 var nowd1new Date(); console.log(nowd1.toLocaleString( )); //方法2:参数为日期字符串 var d2new Date("2004/3/20 11:12"); console.log(d2.toLocaleString( )); var d3new Date("04/…

[vulnhub]DC:7

https://www.vulnhub.com/entry/dc-7,356/ 端口扫描主机发现 探测存活主机,178是靶机 nmap -sP 192.168.75.0/24 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-11-03 13:30 CST Nmap scan report for 192.168.75.1 Host is up (0.00037s l…

探索React源码:React Diff

本篇文章我们来了解一下Diff算法的实现过程。 相关概念 React中的各种节点 假设当前存在一个DOM节点,触发了一次更新,那么在协调的过程中,会有四种节点和该节点相关联: 该DOM节点本身。 workInProgress fiber,更新过程…

开源库 FloatingActionButton

开源库FloatingActionButton Github:https://github.com/Clans/FloatingActionButton 这个库是在前面这个库android-floating-action-button的基础上修改的,增加了一些更强大和实用的特性。 特性: Android 5.0 以上点击会有水波纹效果 可以选择自定义…

微服务设计模式 - 重试模式(Retry Pattern)

微服务设计模式 - 重试模式(Retry Pattern) 定义 重试模式(Retry Pattern)是一种微服务中的设计模式,用于在临时性失败(如网络故障或暂时不可用的服务)发生时,自动重新尝试请求&…

HTML 基础标签——链接标签 <a> 和 <iframe>

文章目录 1. `<a>` 标签属性详细说明示例2. `<iframe>` 标签属性详细说明示例注意事项总结链接标签在HTML中是实现网页导航的重要工具,允许用户从一个页面跳转到另一个页面或嵌入外部内容。主要的链接标签包括 <a> 标签和<iframe> 标签。本文将深入探…

Netty 组件介绍 - Future Promise

在异步处理时&#xff0c;经常用到这两个接口 netty 中的 Future 继承 jdk 中的 FutuFuture&#xff0c;而Promise 又对 netty Future 进行了扩展。 idk Future 只能同步等待任务结束&#xff08;或成功或失败)才能得到结果netty Future 可以同步等待任务结束得到结也可以异…

Excel:vba实现批量插入图片批注

实现的效果&#xff1a;实现的代码如下&#xff1a; Sub InsertImageNamesAndPictures()Dim PicPath As StringDim PicName As StringDim PicFullPath As StringDim RowNum As IntegerDim Name As StringDim Comment As CommentDim folder As FileDialog 定义文件选择对话框 清…

C++(类和对象-友元)

友元的作用 作用&#xff1a; 在C中&#xff0c;友元&#xff08;friend&#xff09;是一种特殊的类成员&#xff0c;它可以让一个函数或者类访问其他类的私有&#xff08;private&#xff09;和保护&#xff08;protected&#xff09;成员。 注意&#xff1a; 友元的使用应该谨…

ssm044基于java和mysql的多角色学生管理系统+jsp(论文+源码)_kaic

毕 业 设 计&#xff08;论 文&#xff09; 题目&#xff1a;学生管理系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本学生管理系统就是在这…

猜字谜 华为OD

源码 Java import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test;import java.util.ArrayList; import java.util.List; import java.util.TreeSet;public class GuessWord {public Input input;BeforeEachpublic void init() {input new Input("…

C++队列

好久没有发博客了&#xff0c;欢迎阅读&#xff1a;C队列。 1.队列的介绍 队列&#xff08;queue&#xff09;是一种先进先出的、操作受限的线性表。 数据结构是&#xff1a;先进入队列的先出去&#xff0c;后进入队列的后出去。 必须从队尾插入新元素&#xff0c;队列中的…

【ArcGISPro】制作简单的ArcGISPro-AI助手

【python】AI Navigator的使用及搭建本机大模型_anaconda ai navigator-CSDN博客 【Python】AI Navigator对话流式输出_ai大模型流式输出效果(打字效果) python-CSDN博客 【C#】调用本机AI大模型流式返回_怎么实现调用本地大模型时实现流式输出-CSDN博客 【ArcGISPro】宣布推…

小白从零开始配置pytorch环境

一、下载ANACONDA 官方网址Anaconda Installers and Packages 笔者选择的是Anaconda3-5.3.0-Windows-x86_64.exe版本。全程安装可以手机开热点&#xff0c;会快一点。 二、查看电脑是否有显卡 1、打开任务管理器 2、查看电脑CUBA版本&#xff0c;如上篇文章所提到查看CUDA-V…