特征交叉-CAN学习笔记代码解读

一 核心模块coaction

  1. 对于每个特征对(feature_pairs)
  2. weight, bias 来自于P_induction
  3. P_fead是MLP的input

举个例子:如果是用户ID和产品ID的co-action,且产品ID是做induction,用户ID是做feed。

  • step1 用户ID/产品ID都先形成一个向量:对于产品ID,用parameter lookup获取一个可学习的P_induction(这个维度是(wi+bi) * L depth of mlp); 用户ID则直接形成一个向量P_fead
  • step2 P_induction 这个向量逐层(MLP层),reshape成MLP网络的weight 和bias;
  • step3 weight和bias作为MLP的参数,利用P_feed 作为input,进行MLP前向运算,得到特征交互结果
  1. 代码解读
#### CAN config #####
weight_emb_w = [[16, 8], [8,4]] # micro-mlp的参数dimension
weight_emb_b = [0, 0]           # bias参数
orders = 3  # 特征的阶数,文章提到了,要做高阶特征交叉,直接是P_feed^c, c就是阶数
order_indep = False # True
WEIGHT_EMB_DIM = (sum([w[0]*w[1] for w in weight_emb_w]) + sum(weight_emb_b)) # * orders 这个是供每一个micro-mlp拆解w&b需要的dimension总和
INDEP_NUM = 1
if order_indep:
    INDEP_NUM *= orders
###### 这一部分对应图中绿色和橙色部分,主要是把P_feed&P_induction的嵌入表示得到 ##########
if self.use_coaction:
   # batch_ph batch输入的数据;his_batch_ph历史批次数据; his_batch_embedded 历史嵌入表示
   ph_dict = {
       "item": [self.mid_batch_ph, self.mid_his_batch_ph, self.mid_his_batch_embedded],
       "cate": [self.cate_batch_ph, self.cate_his_batch_ph, self.cate_his_batch_embedded]
   }
   ### p_induction ####
   self.mlp_batch_embedded = [] # induction embedding
   with tf.device(device):
       # 定义可训练的嵌入矩阵,在这里n_mid是item id的数量
       self.item_mlp_embeddings_var = tf.get_variable("item_mlp_embedding_var", [n_mid, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)
       self.cate_mlp_embeddings_var = tf.get_variable("cate_mlp_embedding_var", [n_cate, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)
       # 通过embedding_lookup在上一步初始化好的矩阵中找到对应的embedding表示
       self.mlp_batch_embedded.append(tf.nn.embedding_lookup(self.item_mlp_embeddings_var, ph_dict['item'][0]))
       self.mlp_batch_embedded.append(tf.nn.embedding_lookup(self.cate_mlp_embeddings_var, ph_dict['cate'][0]))
       #########P_feed input ########
       self.input_batch_embedded = []
       self.item_input_embeddings_var = tf.get_variable("item_input_embedding_var", [n_mid, weight_emb_w[0][0] * INDEP_NUM], trainable=True)
       self.cate_input_embeddings_var = tf.get_variable("cate_input_embedding_var", [n_cate, weight_emb_w[0][0] * INDEP_NUM], trainable=True)  
         self.input_batch_embedded.append(tf.nn.embedding_lookup(self.item_input_embeddings_var, ph_dict['item'][1]))
       self.input_batch_embedded.append(tf.nn.embedding_lookup(self.cate_input_embeddings_var, ph_dict['cate'][1]))
################这一部分是P_induction&P_feed在MLP的使用#######################
if self.use_coaction:
    # p_feed/input
    input_batch = self.input_batch_embedded
    tmp_sum, tmp_seq = [], []
    if INDEP_NUM == 2:
        # 文章说明了是feature pairs,mlp_batch&input_batch都包含了两个部分,要分别组合
        for i, mlp_batch in enumerate(self.mlp_batch_embedded):
            for j, input_batch in enumerate(self.input_batch_embedded):
                coaction_sum, coaction_seq = gen_coaction(
                    mlp_batch[:, WEIGHT_EMB_DIM * j:  WEIGHT_EMB_DIM * (j+1)], 
                    input_batch[:, :, weight_emb_w[0][0] * i: weight_emb_w[0][0] * (i+1)],  
                    EMBEDDING_DIM, 
                    mode=CALC_MODE,
                    mask=self.mask) 
                
                tmp_sum.append(coaction_sum)
                tmp_seq.append(coaction_seq)
    else:
        for i, (mlp_batch, input_batch) in enumerate(zip(self.mlp_batch_embedded, self.input_batch_embedded)):
            coaction_sum, coaction_seq = gen_coaction(
                  mlp_batch[:, :INDEP_NUM * WEIGHT_EMB_DIM], 
                  input_batch[:, :, :weight_emb_w[0][0]],  
                  EMBEDDING_DIM, 
                  mode=CALC_MODE, 
                  mask=self.mask) 
            
            tmp_sum.append(coaction_sum)
            tmp_seq.append(coaction_seq)
            
    self.coaction_sum = tf.concat(tmp_sum, axis=1) # sum pooling
    self.cross.append(self.coaction_sum)   # concat              
###### core interaction 核心运算 #########
def gen_coaction(ad, his_items, dim, mode="can", mask=None):
    """
    ad: induct
    his_items 待交互seq
    """
    weight, bias = [], []
    idx = 0
    weight_orders = []
    bias_orders = []
    # 拆解得到weight&bias参数
    for i in range(orders):
        for w, b in zip(weight_emb_w, weight_emb_b):
            weight.append(tf.reshape(ad[:, idx:idx+w[0]*w[1]], [-1, w[0], w[1]]))
            idx += w[0] * w[1]
            if b == 0:
                bias.append(None)
            else:
                bias.append(tf.reshape(ad[:, idx:idx+b], [-1, 1, b]))
                idx += b
        weight_orders.append(weight)
        bias_orders.append(bias)
        if not order_indep:
            break
 
    if mode == "can":
        out_seq = []
        hh = []
        # 高阶特征处理,explicit deal with
        for i in range(orders):
            hh.append(his_items**(i+1))
        #hh = [sum(hh)]
        for i, h in enumerate(hh):
            if order_indep:
                weight, bias = weight_orders[i], bias_orders[i]
            else:
                weight, bias = weight_orders[0], bias_orders[0]
            # 模拟MLP forward calculation
            for j, (w, b) in enumerate(zip(weight, bias)):
                h  = tf.matmul(h, w)
                if b is not None:
                    h = h + b
                if j != len(weight)-1:
                    h = tf.nn.tanh(h)
                out_seq.append(h)
        out_seq = tf.concat(out_seq, 2)
        if mask is not None:
            mask = tf.expand_dims(mask, axis=-1) 
            out_seq = out_seq * mask
            
    # 序列交互结果做sum_pooling
    out = tf.reduce_sum(out_seq, 1)
    if keep_fake_carte_seq and mode=="emb":
        return out, out_seq
    return out, None

二 文章中的应用
整体的模型结构两部分构成:

  • co-action作为核心形成的一部分,对于用户的序列特征,一一作用后做sum-pooling,对于非序列特征,作用后直接输出
  • DIEN作为核心形成的一部分

两部分concat以后加一个DNN常规操作,看起来就像是用co-action做显式的特征交叉,然后DIEN做之前的序列建模。
在这里插入图片描述
三 一些其他细节补充

  1. can 部分高阶特征处理: 直接把待交叉特征p_fead 做c阶运算后,再与p_induction进行作用
  2. 在文章场景,p_induction是target_item,也就是产品

四 用tf2/torch重构

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class CAN_Model(nn.Module):
    def __init__(self, n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE, use_negsampling=False, use_softmax=True, use_coaction=False, use_cartes=False):
        super(CAN_Model, self).__init__()
        
        self.n_uid = n_uid
        self.n_mid = n_mid
        self.n_cate = n_cate
        self.n_carte = n_carte
        self.EMBEDDING_DIM = EMBEDDING_DIM
        self.HIDDEN_SIZE = HIDDEN_SIZE
        self.ATTENTION_SIZE = ATTENTION_SIZE
        self.use_negsampling = use_negsampling
        self.use_softmax = use_softmax
        self.use_coaction = use_coaction
        self.use_cartes = use_cartes

        self.uid_embeddings = nn.Embedding(n_uid, EMBEDDING_DIM)
        self.mid_embeddings = nn.Embedding(n_mid, EMBEDDING_DIM)
        self.cate_embeddings = nn.Embedding(n_cate, EMBEDDING_DIM)

        if use_cartes:
            self.carte_embeddings = nn.ModuleList([nn.Embedding(num, EMBEDDING_DIM) for num in n_carte])

        if self.use_coaction:
            self.item_mlp_embeddings = nn.Parameter(torch.randn(n_mid, INDEP_NUM * WEIGHT_EMB_DIM))
            self.cate_mlp_embeddings = nn.Parameter(torch.randn(n_cate, INDEP_NUM * WEIGHT_EMB_DIM))
            self.input_batch_embeddings = nn.ModuleList([nn.Embedding(n_mid, weight_emb_w[0][0] * INDEP_NUM), nn.Embedding(n_cate, weight_emb_w[0][0] * INDEP_NUM)])

        self.fc1 = nn.Linear(200, 80)
        self.fc2 = nn.Linear(80, 2 if use_softmax else 1)

    def forward(self, uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr, carte=None):
        # Embedding lookups
        uid_emb = self.uid_embeddings(uid)
        mid_emb = self.mid_embeddings(mid)
        cate_emb = self.cate_embeddings(cate)
        mid_his_emb = self.mid_embeddings(mid_his)
        cate_his_emb = self.cate_embeddings(cate_his)

        if self.use_cartes:
            carte_emb = [emb(carte[:, i, :]) for i, emb in enumerate(self.carte_embeddings)]

        # Co-action logic (if enabled)
        if self.use_coaction:
            # This is a simplified version of the co-action implementation from the original TensorFlow code
            mlp_embedded_item = self.item_mlp_embeddings[mid]
            mlp_embedded_cate = self.cate_mlp_embeddings[cate]
            input_embedded_item = self.input_batch_embeddings[0](mid_his)
            input_embedded_cate = self.input_batch_embeddings[1](cate_his)
            # Further coaction operations can be added based on your logic

        # Concatenate item and category embeddings
        item_eb = torch.cat([mid_emb, cate_emb], dim=1)
        item_his_eb = torch.cat([mid_his_emb, cate_his_emb], dim=2)
        item_his_eb_sum = item_his_eb.sum(dim=1)

        if self.use_negsampling:
            # Assuming the negative sampling implementation would need its own logic.
            pass

        # FC layers
        x = self.fc1(item_eb)
        x = F.relu(x)
        x = self.fc2(x)

        # Loss computation
        if self.use_softmax:
            y_hat = F.softmax(x, dim=-1)
            loss = F.cross_entropy(y_hat, target)
        else:
            y_hat = torch.sigmoid(x)
            loss = F.binary_cross_entropy_with_logits(x, target)

        return loss, y_hat

    def auxiliary_loss(self, h_states, click_seq, noclick_seq, mask):
        mask = mask.float()
        click_input = torch.cat([h_states, click_seq], dim=-1)
        noclick_input = torch.cat([h_states, noclick_seq], dim=-1)
        click_prop = self.auxiliary_net(click_input)[:, :, 0]
        noclick_prop = self.auxiliary_net(noclick_input)[:, :, 0]
        click_loss = -torch.log(click_prop) * mask
        noclick_loss = -torch.log(1.0 - noclick_prop) * mask
        loss = (click_loss + noclick_loss).mean()
        return loss

    def auxiliary_net(self, in_):
        x = F.relu(self.fc1(in_))
        x = F.relu(self.fc2(x))
        return x

    def train_step(self, data, optimizer):
        optimizer.zero_grad()
        loss, y_hat = self(data)
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate(self, data):
        with torch.no_grad():
            loss, y_hat = self(data)
        return loss.item(), y_hat

# Example of using the model
n_uid = 1000
n_mid = 1000
n_cate = 500
n_carte = [10, 20]  # Example carte sizes
EMBEDDING_DIM = 128
HIDDEN_SIZE = 256
ATTENTION_SIZE = 128

model = CAN_Model(n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example data
uid = torch.randint(0, n_uid, (32,))
mid = torch.randint(0, n_mid, (32,))
cate = torch.randint(0, n_cate, (32,))
mid_his = torch.randint(0, n_mid, (32, 5))
cate_his = torch.randint(0, n_cate, (32, 5))
mask = torch.ones(32, 5)
target = torch.randint(0, 2, (32,))
seq_len = torch.randint(1, 5, (32,))
lr = 0.001

# Training step
loss = model.train_step((uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr), optimizer)
print(f"Loss: {loss}")

Reference:

  1. 文章形成思路历程
  2. CAN: Feature Co-Action for Click-Through Rate Prediction-21年,阿里
  3. Implementation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/937373.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Java面试】深拷贝、浅拷贝和引用拷贝三者的区别

浅拷贝:浅拷贝会在堆上创建一个新的对象(区别于引用拷贝的一点),不过,如果原对象内部的属性是引用类型的话,浅拷贝会直接复制内部对象的引用地址,也就是说拷贝对象和原对象共用同一个内部对象。…

EasyGBS点对点穿透P2P远程访问技术在安防视频监控中的应用

随着信息技术的快速发展,安防视频监控系统在公共安全领域的应用变得越来越广泛。传统的视频监控系统多依赖于中心服务器进行视频流的集中处理和分发,这不仅增加了网络带宽的负担,还可能成为系统性能瓶颈。为了解决这些问题,P2P&am…

Vue入门到精通:核心语法—模板语法

Vue入门到精通:核心语法—模板语法 Vue.js因其简单、易用和高效的特点,自推出以来一直受到广泛关注。Vue.js的核心概念和技术包括模板语法、计算属性、事件监听、动态样式绑定、条件渲染指令(如v-if)、列表渲染指令(如…

C++中如何实现接口继承与实现继承,以及它们的区别?

概念 在 C 中,接口继承和实现继承是两种不同的继承方式,它们在设计模式、代码复用和多态性方面有着不同的应用。下面将分别解释这两者的概念、实现方式及其区别。 接口继承 接口继承指的是只继承类的接口(即公共的成员函数声明&#xff09…

WPF+MVVM案例实战与特效(三十八)- 封装一个自定义的数字滚动显示控件

文章目录 1、运行效果2、案例实现1、功能设计2、页面布局3、控件使用4、运行效果3、拓展:多数字自定义控件1、控件应用4、总结1、运行效果 在Windows Presentation Foundation (WPF)应用程序中,自定义控件允许开发者创建具有特定功能和外观的独特UI元素。本博客将介绍一个名…

2024年12月HarmonyOS应用开发者高级认证全新题库

注意事项:切记在考试之外的设备上打开题库进行搜索,防止切屏三次考试自动结束,题目是乱序,每次考试,选项的顺序都不同,作者已于2024年12月15日又更新了一波题库,题库正确率99%! 新版…

【Java学习笔记】JUnit

一、为什么需要 JUnit 二、基本介绍 三、实现方法 第一次添加: 在需要测试的方法处输入 Test注解,快捷键AltInsert选择添加版本(常用JUnit5.4) 出现绿色箭头可进行测试和编译

MySQL误删除 binlog 还原 恢复已删除数据 实战 超详细

硬盘有价,数据无价,数据库执行,谨慎操作! binlog日志还原不适用于直接删表删库的误操作! 目录 实战恢复 1、导出相关时间binlog数据 2、找到对应语句以及pos区间 3、导出改动区间的sql 4、将binlog导出的sql转换…

百度地图JavaScript API核心功能指引

百度地图JavaScript API是一套由JavaScript语言编写的应用程序接口,它能够帮助您在网站中构建功能丰富、交互性强的地图应用,包含了构建地图基本功能的各种接口,提供了诸如本地搜索、路线规划等数据服务。百度地图JavaScript API支持HTTP和HT…

《拉依达的嵌入式\驱动面试宝典》—C/CPP基础篇(五)

《拉依达的嵌入式\驱动面试宝典》—C/CPP基础篇(五) 你好,我是拉依达。 感谢所有阅读关注我的同学支持,目前博客累计阅读 27w,关注1.5w人。其中博客《最全Linux驱动开发全流程详细解析(持续更新)-CSDN博客》已经是 Linux驱动 相关内容搜索的推荐首位,感谢大家支持。 《拉…

C语言简单日志宏

最近调试C代码,发现要写很多打印的内容不是很方便,于是简单写一下C语言的日志来方便自己调试: 1. 简单打印带标识的日志信息 #include "stdio.h" #define PRINT(...) \do \{ …

【算法】—— 前缀和

一、区间求和问题 给定一个长度为n的序列a,有m次查询,每次查询输出一个连续区间的和。 使用暴力做法求解是将每次查询都遍历该区间求和 //暴力做法import java.util.Scanner;public class Test {public static void main(String[] args){Scanner scan…

详解下c语言下的多维数组和指针数组

在实际c语言编程中,三维及以上数组我们使用的很少,二维数组我们使用得较多。说到数组,又不得关联到指针,因为他们两者的联系太紧密了。今天我们就详细介绍下c语言下的多维数组(主要是介绍二维数组)和指针。 一、二维数组 1.1&am…

【实验】【H3CNE邓方鸣】交换机端口安全实验+2024.12.11

实验来源:邓方鸣交换机端口安全实验 软件下载: 华三虚拟实验室: 华三虚拟实验室下载 wireshark:wireshark SecureCRT v8.7 版本: CRT下载分享与破解 文章目录 dot1x 开启802.1X身份验证 开启802.1X身份验证,需要在系统视图和接口视…

leetcode-73.矩阵置零-day5

class Solution {public void setZeroes(int[][] mat) {int m mat.length, n mat[0].length;// 1. 扫描「首行」和「首列」记录「首行」和「首列」是否该被置零boolean r0 false, c0 false;for (int i 0; i < m; i) {if (mat[i][0] 0) {r0 true;break;}}for (int j …

C++ webrtc开发(非原生开发,linux上使用libdatachannel库)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、libdatachannel库的下载和build二、开始使用 1.2.引入库3.开始使用 总结 前言 使用c开发webrtc在互联网上留下的资料甚少&#xff0c;经过我一段时间的探…

windows11 专业版 docker desktop 安装指南

家庭中文版需升级专业版&#xff0c;家庭版没有hyper-v。 开始运行optionalfeatures.exe打开windows功能 安装wsl2 步骤 1 - 启用适用于 Linux 的 Windows 子系统步骤 2 - 检查运行 WSL 2 的要求步骤 3 - 启用虚拟机功能步骤 4 - 下载 Linux 内核更新包 步骤 1 - 启用适用于 L…

解锁前端开发速度的秘密武器【Vite】

在前端开发的江湖中&#xff0c;有人偏爱 Webpack 的强大与稳定&#xff0c;有人钟情于 Rollup 的轻量与高效。而 Vite&#xff0c;这个后来居上的工具&#xff0c;却以“极致的快”和“极简的易”赢得了开发者的芳心。众所周知万事都有缘由&#xff0c;接下来我们就来深度剖析…

AI发展与LabVIEW程序员就业

人工智能&#xff08;AI&#xff09;技术的快速发展确实对许多行业带来了变革&#xff0c;包括自动化、数据分析、软件开发等领域。对于LabVIEW程序员来说&#xff0c;AI的崛起确实引发了一个值得关注的问题&#xff1a;AI会不会取代他们的工作&#xff0c;导致大量失业&#x…

决策曲线分析(DCA)中平均净收益用于评价模型算法(R自定义函数)

决策曲线分析&#xff08;DCA&#xff09;中平均净收益用于评价模型算法 DCA分析虽然不强调用来评价模型算法或者变量组合的优劣&#xff0c;但是实际应用过程中感觉DCA曲线的走势和模型的效能具有良好的一致性&#xff0c;其实这种一致性也可以找到内在的联系&#xff0c;比如…