Deep Crossing:深度交叉网络在推荐系统中的应用

实验和完整代码

完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main

引言

在机器学习和深度学习领域,特征工程一直是一个关键步骤,尤其是对于大规模的推荐系统和广告点击率预测任务。传统的特征工程通常依赖于手动设计的组合特征,这些特征虽然有效,但在大规模数据场景下,其开发和维护成本极高。Deep Crossing 是一种新型的深度学习模型,能够自动学习特征组合,无需手动设计组合特征,从而在大规模数据上实现高效建模。

背景知识

Deep Crossing 是由微软研究院提出的一种深度神经网络模型,专门用于处理大规模稀疏特征数据。该模型的核心思想是通过嵌入层(Embedding Layer)、残差单元(Residual Units)和评分层(Scoring Layer)自动学习特征之间的复杂交互关系。Deep Crossing 的主要贡献在于它能够自动发现重要的特征组合,而无需依赖于手动设计的组合特征。

1. 模型结构

Deep Crossing 的网络结构主要包括以下几个部分:

  1. Embedding 层
    • 将稀疏的类别特征嵌入到低维的稠密向量中。每个类别特征都有一个对应的嵌入矩阵,嵌入矩阵的大小为 (类别数, 嵌入维度)
    • 例如,对于用户 ID 和项目 ID 等类别特征,可以将其嵌入到一个低维的稠密向量中,以便神经网络能够更好地处理。
  2. 残差单元(Residual Units)
    • 残差单元是 Deep Crossing 的核心部分,用于学习特征之间的复杂交互关系。每个残差单元包含两个全连接层(nn.Linear),中间通过非线性激活函数(ReLU)和批量归一化(BatchNorm)进行处理。
    • 残差单元的输出通过残差连接(Residual Connection)与输入相加,从而保留了输入的特征信息,避免了梯度消失问题。
  3. 评分层(Scoring Layer)
    • 评分层是一个全连接层,用于将经过残差单元处理后的特征向量映射到最终的预测值。输出层通常使用 Sigmoid 函数将输出值映射到 [0, 1] 范围内,表示预测的概率。

模型结构如下:

其中Feature #1 和 Features #n都是分类型数据,Feature #2是数值型数据

残差模块结构如下:

随着网络的加深,梯度在反向传播过程中可能会逐渐衰减(梯度消失)或指数级增长(梯度爆炸)。残差连接(Residual Connection) 通过 恒等映射(Identity Mapping),使梯度可以直接沿着跳跃连接传播,从而减轻梯度消失或爆炸的问题。这对于深度神经网络(DNN)而言尤为重要。

数学上,假设残差模块的输入为 x \mathbf{x} x,非线性变换为 F ( x ) F(\mathbf{x}) F(x),则输出为:

y = F ( x ) + x y=F(x)+x y=F(x)+x

这样,在反向传播时,梯度可以通过 F ( x ) F(\mathbf{x}) F(x) 传播,也可以通过恒等映射直接传播:

∂ y ∂ x = ∂ F ( x ) ∂ x + 1 \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \frac{\partial F(\mathbf{x})}{\partial \mathbf{x}}+ 1 xy=xF(x)+1

这保证了梯度不会因层数加深而过度衰减。


此外,从模型的表达能力来看,由于残差模块能够直接建模

F ( x ) = H ( x ) − x F(x) = H(x) - x F(x)=H(x)x

模型学习的是输入和输出之间的残差,而不是直接拟合输出 H ( X ) H(X) H(X),使得模型更容易优化,也能学习到更复杂的特征交互关系。

2. 模型理论框架

2.1 整体架构

Deep Crossing采用经典的Embedding+MLP范式,其数学表达为:

y ^ = σ ( W ( L ) ⋅ h ( L − 1 ) + b ( L ) ) \hat{y} = \sigma(W^{(L)} \cdot h^{(L-1)} + b^{(L)}) y^=σ(W(L)h(L1)+b(L))

其中 h ( l ) h^{(l)} h(l)表示第 l l l层隐藏状态,包含以下核心组件:

1. 特征嵌入层

​ 对类别型特征 c i ∈ R d i c_i \in \mathbb{R}^{d_i} ciRdi进行降维:

e i = E i T c i , E i ∈ R d i × k e_i = E_i^T c_i, \quad E_i \in \mathbb{R}^{d_i \times k} ei=EiTci,EiRdi×k

​ 数值型特征直接标准化处理:

v j = x j − μ j σ j v_j = \frac{x_j - \mu_j}{\sigma_j} vj=σjxjμj

2. 特征堆叠层

​ 将各特征向量拼接:

h ( 0 ) = [ e 1 ; e 2 ; . . . ; e m ; v 1 ; v 2 ; . . . ; v n ] h^{(0)} = [e_1; e_2; ...; e_m; v_1; v_2; ...; v_n] h(0)=[e1;e2;...;em;v1;v2;...;vn]

3. 残差层

采用改进的残差单元(受ResNet启发):

h ( l ) = f ( W 2 ( l ) ⋅ ReLU ( W 1 ( l ) h ( l − 1 ) + b 1 ( l ) ) + b 2 ( l ) ) + h ( l − 1 ) h^{(l)} = f(W_2^{(l)} \cdot \text{ReLU}(W_1^{(l)} h^{(l-1)} + b_1^{(l)}) + b_2^{(l)}) + h^{(l-1)}\\ h(l)=f(W2(l)ReLU(W1(l)h(l1)+b1(l))+b2(l))+h(l1)
其中f为激活函数,实验表明ReLU效果最优。

4. 评分层

最终预测层实现为:

p = sigmoid ( W ( L ) h ( L − 1 ) + b ( L ) ) p = \text{sigmoid}(W^{(L)} h^{(L-1)} + b^{(L)}) p=sigmoid(W(L)h(L1)+b(L))

3. 代码实现

残差模块

#残差网络块
class ResidualUnit(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout_rate):
        super(ResidualUnit, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, input_dim),
            nn.BatchNorm1d(input_dim),
            nn.Dropout(dropout_rate)
        )
        self.relu = nn.ReLU()
    
    def forward(self, x):
        residual = self.layers(x)
        return self.relu(x + residual)
    

Deep Crossing模块

class DeepCrossing(nn.Module):
    def __init__(self, cat_sizes, num_sizes, config):
        super(DeepCrossing, self).__init__()

        #Embedding层
        self.embeddings = nn.ModuleList([
            nn.Embedding(size, config.embedding_dim ) for size in cat_sizes #生成对应 Embedding层    
        ])

        #计算总特征维度
        total_dim = len(cat_sizes) * config.embedding_dim + num_sizes

        #多层Residual units
        self.res_uint = nn.Sequential()
        for _ in range(config.num_residual_units):
            self.res_uint.append(
                ResidualUnit(total_dim, config.hidden_dim, config.dropout_rate)
            )

        #scoring层
        self.fc = nn.Linear(total_dim,1)

    def forward(self, x_cat, x_num):
        #处理类别特征,注意x_cat 每一列都是一个类别特征,采用类似Ordinal Encoder
        embeddings = []
        for i in range(len(self.embeddings)):
            embeddings.append(self.embeddings[i](x_cat[:,i]))
        
        x = torch.cat(embeddings, dim = 1) #拼接起来

        #拼接数值特征
        x = torch.cat([x,x_num], dim = 1)

        #残差单元
        x = self.res_uint(x)

        #输出层
        return torch.sigmoid(self.fc(x)).squeeze()
        
        

4. 实验

由于没有合适的数据,使用sklearn中make_classification方法生成的数据进行实验如下:
在这里插入图片描述

实验结果表明,Deep Crossing 模型在训练和测试集上都表现良好,损失逐渐减小,AUC 分数逐渐提高,且训练和测试结果接近,说明模型能够有效地学习特征之间的交互关系,并具有良好的泛化能力。这些结果验证了 Deep Crossing 模型在处理大规模稀疏数据和自动特征学习方面的优势。

总结

Deep Crossing 通过 Residual Network 深度建模特征交互,避免了手工特征工程的复杂性,并在 CTR 预估等任务中表现优异。相比于传统神经网络,残差结构的加入有效缓解了梯度消失问题,使得深度学习在推荐系统领域取得更大突破。

Reference

[1]. Y. Shan, T. R. Hoens, J. Jiao, H. Wang, D. Yu, and J. C. Mao, “Deep Crossing: Web-Scale Modeling without Manually Crafted Combinatorial Features,” in Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016, pp. 255-262.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/964513.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++多线程编程——基于策略模式、单例模式和简单工厂模式的可扩展智能析构线程

1. thread对象的析构问题 在 C 多线程标准库中,创建 thread 对象后,必须在对象析构前决定是 detach 还是 join。若在 thread 对象销毁时仍未做出决策,程序将会终止。 然而,在创建 thread 对象后、调用 join 前的代码中&#xff…

Rust中使用ORM框架diesel报错问题

1 起初环境没有问题:在Rust开发的时候起初使用的是mingw64平台加stable-x86_64-pc-windows-gnu编译链,当使用到diesel时会报错,如下: x86_64-w64-mingw32/bin/ld.exe: cannot find -lmysql具体信息很长这是主要信息是rust找不到链…

Fastdds学习分享_xtpes_发布订阅模式及rpc模式

在之前的博客中我们介绍了dds的大致功能,与组成结构。本篇博文主要介绍的是xtypes.分为理论和实际运用两部分.理论主要用于梳理hzy大佬的知识,对于某些一带而过的部分作出更为详细的阐释,并在之后通过实际案例便于理解。案例分为普通发布订阅…

Windows:AList+RaiDrive挂载阿里云盘至本地磁盘

零、前言 电脑存储的文件多了,出现存储空间不够用的情况。又没前买新的硬盘或者笔记本电脑没有额外的插槽提供给新的硬盘。遇到这种情况,我想到可以使用网盘,但单纯的网盘又要上传下载,极其麻烦。看到WPS云盘可以直接挂载本地&…

Unity游戏(Assault空对地打击)开发(3) 摄像机的控制

详细步骤 打开My Assets或者Package Manager。 选择Unity Registry。 搜索Cinemachine,找到 Cinemachine包,点击 Install按钮进行安装。 关闭窗口,新建一个FreeLook Camera,如下。 接着新建一个对象Pos,拖到Player下面…

保姆级教程Docker部署Zookeeper官方镜像

目录 1、安装Docker及可视化工具 2、创建挂载目录 3、运行Zookeeper容器 4、Compose运行Zookeeper容器 5、查看Zookeeper运行状态 6、验证Zookeeper是否正常运行 1、安装Docker及可视化工具 Docker及可视化工具的安装可参考:Ubuntu上安装 Docker及可视化管理…

力扣73矩阵置零

给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 输入:matrix [[1,1,1],[1,0,1],[1,1,1]] 输出:[[1,0,1],[0,0,0],[1,0,1]] 输入:matrix [[0,1,2,0],[3,4,5,2],[…

【ArcGIS_Python】使用arcpy脚本将shape数据转换为三维白膜数据

说明: 该专栏之前的文章中python脚本使用的是ArcMap10.6自带的arcpy(好几年前的文章),从本篇开始使用的是ArcGIS Pro 3.3版本自带的arcpy,需要注意不同版本对应的arcpy函数是存在差异的 数据准备:准备一个…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.19 线性代数核武器:BLAS/LAPACK深度集成

2.19 线性代数核武器:BLAS/LAPACK深度集成 目录 #mermaid-svg-yVixkwXWUEZuu02L {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-yVixkwXWUEZuu02L .error-icon{fill:#552222;}#mermaid-svg-yVixkwXWUEZ…

记8(高级API实现手写数字识别

目录 1、Keras:2、Sequential模型:2.1、建立Sequential模型:modeltf.keras.Sequential()2.2、添加层:model.add(tf.keras.layers.层)2.3、查看摘要:model.summary()2.4、配置训练方法:model.compile(loss,o…

RK3568中使用QT opencv(显示基础图像)

文章目录 一、查看对应的开发环境是否有opencv的库二、QT使用opencv一、查看对应的开发环境是否有opencv的库 在开发板中的/usr/lib目录下查看是否有opencv的库: 这里使用的是正点原子的ubuntu虚拟机,在他的虚拟机里面已经安装好了opencv的库。 二、QT使用opencv 在QT pr…

el-table表格点击单元格实现编辑

使用 el-table 和 el-table-column 创建表格。在单元格的默认插槽中,使用 div 显示文本内容,单击时触发编辑功能。使用 el-input 组件在单元格中显示编辑框。data() 方法中定义了 tableData,tabClickIndex: null,tabClickLabel: ,用于判断是否…

idea隐藏无关文件

idea隐藏无关文件 如果你想隐藏某些特定类型的文件(例如 .log 文件或 .tmp 文件),可以通过以下步骤设置: 打开设置 在菜单栏中选择 File > Settings(Windows/Linux)或 IntelliJ IDEA > Preference…

备考蓝桥杯嵌入式4:使用LCD显示我们捕捉的PWM波

上一篇博客我们提到了定时器产生PWM波,现在,我们尝试的想要捕获我们的PWM波,测量它的频率,我们应该怎么做呢?答案还是回到我们的定时器上。 我们知道,定时器是一个高级的秒表(参考笔者的比喻&a…

ChatGPT-4o和ChatGPT-4o mini的差异点

在人工智能领域,OpenAI再次引领创新潮流,近日正式发布了其最新模型——ChatGPT-4o及其经济实惠的小型版本ChatGPT-4o Mini。这两款模型虽同属于ChatGPT系列,但在性能、应用场景及成本上展现出显著的差异。本文将通过图文并茂的方式&#xff0…

Codeforces Round 1002 (Div. 2)(部分题解)

补题链接 A. Milya and Two Arrays 思路&#xff1a;题意还是比较好理解&#xff0c;分析的话我加了一点猜的成分&#xff0c;对a&#xff0c;b数组的种类和相加小于4就不行&#xff0c;蒋老师的乘完后小于等于2也合理。 AC代码&#xff1a; #include <bits/stdc.h> u…

机器学习中的关键概念:通过SKlearn的MNIST实验深入理解

欢迎来到我的主页&#xff1a;【Echo-Nie】 本篇文章收录于专栏【机器学习】 1 sklearn相关介绍 Scikit-learn 是一个广泛使用的开源机器学习库&#xff0c;提供了简单而高效的数据挖掘和数据分析工具。它建立在 NumPy、SciPy 和 matplotlib 等科学计算库之上&#xff0c;支持…

【Linux系统】信号:信号保存 / 信号处理、内核态 / 用户态、操作系统运行原理(中断)

理解Linux系统内进程信号的整个流程可分为&#xff1a; 信号产生 信号保存 信号处理 上篇文章重点讲解了 信号的产生&#xff0c;本文会讲解信号的保存和信号处理相关的概念和操作&#xff1a; 两种信号默认处理 1、信号处理之忽略 ::signal(2, SIG_IGN); // ignore: 忽略#…

OpenAI新商标申请曝光:AI硬件、机器人、量子计算全线布局?

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

python学opencv|读取图像(五十六)使用cv2.GaussianBlur()函数实现图像像素高斯滤波处理

【1】引言 前序学习了均值滤波和中值滤波&#xff0c;对图像的滤波处理有了基础认知&#xff0c;相关文章链接为&#xff1a; python学opencv|读取图像&#xff08;五十四&#xff09;使用cv2.blur()函数实现图像像素均值处理-CSDN博客 python学opencv|读取图像&#xff08;…