特征交叉-MaskNet文章总结代码实现

MaskNet 这个模型是微博21年提出的,23年twitter(X)开源的推荐系统排序模块使用的backbone结构。
核心思想是认为DNN为主的特征交叉是addictive,交叉效率不高;所以设计了一种multiplicatvie的特征交叉
如何设计muliplicative特征交叉呢?
1)首先设计了一个instance-guide-mask,下图是instance-guide-mask的设计,其实就是两层feed-forward-layer,第一层把原始输入维度扩增,第二层再还原回去,总结而言,就是这个公式:
V m a s k = W d 2 ( R e l u ( W d 1 V e m b + β d 1 ) ) + β d 2 V_{mask} = W_{d2}(Relu(W_{d1} V_{emb} + \beta_{d1})) + \beta_{d2} Vmask=Wd2(Relu(Wd1Vemb+βd1))+βd2

𝑉 𝑒 𝑚 𝑏 ∈ R 𝑚 = 𝑓 × 𝑘 𝑉_{𝑒𝑚𝑏}∈ R^{𝑚=𝑓 ×𝑘} VembRm=f×k 输入的embedding结果, f是输入特征的数量,k是特征embedding维度。
最终输出的是一个处理后的embedding向量,后面简称为mask
instance-guide-mask2) 这个embedding得到后,怎么使用呢,这个就是MaskBlock干的事情。
主要有两种使用,一个是对embedding进行处理,图里LN-EMB里的LN指的是Layer Normalization
V 𝑚 𝑎 𝑠 𝑘 𝑒 𝑑 𝐸 𝑀 𝐵 = V 𝑚 𝑎 𝑠 𝑘 ⊙ L N _ E M B ( V 𝑒 𝑚 b ) V_{𝑚𝑎𝑠𝑘𝑒𝑑𝐸𝑀𝐵} = V_{𝑚𝑎𝑠𝑘} ⊙ LN\_EMB(V_{𝑒𝑚b}) VmaskedEMB=VmaskLN_EMB(Vemb), 把mask的结果和LN-EMB 进行element-wide product, 然后在接一个linear,LN后应用在Relu做下非线性激活,这个就是MaskBLock的全部了, 总结成一个公式:
V o u t p u t = L N _ H I D ( W i V m a s k e d E M B ) = R e L U ( L N ( W i ( V m a s k ⊙ L N _ E M B ( V 𝑒 𝑚 b ) ) ) V_{output} = LN\_HID(W_i V_{maskedEMB}) = ReLU(LN(W_i (V_{mask} ⊙ LN\_EMB(V_{𝑒𝑚b}))) Voutput=LN_HID(WiVmaskedEMB)=ReLU(LN(Wi(VmaskLN_EMB(Vemb)))
Embedding除了对Embedding进行element-wide-product,还可以对神经网络的输出再和mask做一次处理,这个就是另一种mask的应用方式:
V o u t p u t = L N _ H I D ( W i V m a s k d H I D ) = R e L U ( L N ( W i ( V m a s k ⊙ V o u t p u t p ) ) ) V_{output} = LN\_HID(W_i V_{maskdHID}) = ReLU(LN(W_i(V_{mask} ⊙ V_{output}^p))) Voutput=LN_HID(WiVmaskdHID)=ReLU(LN(Wi(VmaskVoutputp)))
在这里插入图片描述
1) 2) 结束之后,文章的核心内容也基本结束,后面3)是MaskBlock的应用
3)MaskNet
所有特征都和Instance-guide-mask进行运算,可以是串行的也可以是并行的。
串行的第一个是一个MaskBlock on feature embedding,后面接的都是MaskBlock on MaskBlock;
并行的比较简单,每一个都是一个MaskBlock on feature embedding,然后concat到一起
在这里插入图片描述

二 Implementation
1)torch 代码实现,摘录自twitter开源代码:

def _init_weights(module):
  if isinstance(module, torch.nn.Linear):
    torch.nn.init.xavier_uniform_(module.weight)
    torch.nn.init.constant_(module.bias, 0)

class MaskBlock(torch.nn.Module):
  def __init__(self, 
              mask_block_config: config.MaskBlockConfig, 
              input_dim: int, 
              mask_input_dim: int) -> None:
    super(MaskBlock, self).__init__()
    self.mask_block_config = mask_block_config
    output_size = mask_block_config.output_size

    if mask_block_config.input_layer_norm: 
       # twitter实现的这里layer normalization做了可配置的
      self._input_layer_norm = torch.nn.LayerNorm(input_dim)
    else:
      self._input_layer_norm = None
    # instace-guide-mask第一层aggregation的神经元数量配置
    # 如果指定了压缩量,就是input * 压缩量;如果没有,那么也可以手动指定大小
    if mask_block_config.reduction_factor:
      aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
    elif mask_block_config.aggregation_size is not None:
      aggregation_size = mask_block_config.aggregation_size
    else:
      raise ValueError("Need one of reduction factor or aggregation size.")

    # instance-guide-mask is here
    # 两层Linear
    self._mask_layer = torch.nn.Sequential(
      torch.nn.Linear(mask_input_dim, aggregation_size),
      torch.nn.ReLU(),
      torch.nn.Linear(aggregation_size, input_dim),
    )
    # 参数初始化
    self._mask_layer.apply(_init_weights)
    self._hidden_layer = torch.nn.Linear(input_dim, output_size)
    self._hidden_layer.apply(_init_weights)
    self._layer_norm = torch.nn.LayerNorm(output_size)

  def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
    # LN
    if self._input_layer_norm:
      net = self._input_layer_norm(net)
    # self._mask_layer(mask_input)-- V_mask
    # net * V_mask
    hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
    return self._layer_norm(hidden_layer_output)


class MaskNet(torch.nn.Module):
  def __init__(self, 
               mask_net_config: config.MaskNetConfig, 
               in_features: int):
    super().__init__()
    self.mask_net_config = mask_net_config
    mask_blocks = []
    if mask_net_config.use_parallel:
      total_output_mask_blocks = 0
      # 从local_prod参数看,用了4个block
      for mask_block_config in mask_net_config.mask_blocks:
        mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
        total_output_mask_blocks += mask_block_config.output_size
      self._mask_blocks = torch.nn.ModuleList(mask_blocks)
    else:
      input_size = in_features
      for mask_block_config in mask_net_config.mask_blocks:
        mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
        input_size = mask_block_config.output_size

      self._mask_blocks = torch.nn.ModuleList(mask_blocks)
      total_output_mask_blocks = mask_block_config.output_size

    if mask_net_config.mlp:
      self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
      self.out_features = mask_net_config.mlp.layer_sizes[-1]
    else:
      self.out_features = total_output_mask_blocks
    self.shared_size = total_output_mask_blocks

  def forward(self, inputs: torch.Tensor):
    if self.mask_net_config.use_parallel:
      # 并行化的网络结构实现
      mask_outputs = []
      # 对于多个Block,每一个block输入都是一样,只是其中学习到的参数有所不同
      for mask_layer in self._mask_blocks:
        # mask_input,net 都是inputs
        mask_outputs.append(mask_layer(mask_input=inputs, net=inputs)) 
      # Share the outputs of the MaskBlocks.
      all_mask_outputs = torch.cat(mask_outputs, dim=1)
      # 最终输出处理
      output = (
        all_mask_outputs
        if self.mask_net_config.mlp is None
        else self._dense_layers(all_mask_outputs)["output"])
      return {"output": output, "shared_layer": all_mask_outputs}
    else:
      # 串行
      net = inputs
      for mask_layer in self._mask_blocks:
        # mask_input 是inputs,net输入是上一层的输出
        net = mask_layer(net=net, mask_input=inputs)
      # Share the output of the stacked MaskBlocks.
      output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
      return {"output": output, "shared_layer": net}

2)tensorflow实现
摘录自EasyRec(阿里开源推荐工具)

# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.layers import Layer
from easy_rec.python.layers.keras.blocks import MLP
from easy_rec.python.layers.keras.layer_norm import LayerNormalization
from easy_rec.python.layers.utils import Parameter

class MaskBlock(Layer):
  """MaskBlock use in MaskNet.
  Args:
    projection_dim: project dimension to reduce the computational cost.
    Default is `None` such that a full (`input_dim` by `aggregation_size`) matrix
    W is used. If enabled, a low-rank matrix W = U*V will be used, where U
    is of size `input_dim` by `projection_dim` and V is of size
    `projection_dim` by `aggregation_size`. `projection_dim` need to be smaller
    than `aggregation_size`/2 to improve the model efficiency. In practice, we've
    observed that `projection_dim` = d/4 consistently preserved the
    accuracy of a full-rank version.
  """
  def __init__(self, params, name='mask_block', reuse=None, **kwargs):
    super(MaskBlock, self).__init__(name=name, **kwargs)
    self.config = params.get_pb_config()
    self.l2_reg = params.l2_regularizer
    self._projection_dim = params.get_or_default('projection_dim', None)
    self.reuse = reuse
    self.final_relu = Activation('relu', name='relu')

  def build(self, input_shape):
    if type(input_shape) in (tuple, list):
      assert len(input_shape) >= 2, 'MaskBlock must has at least two inputs'
      input_dim = int(input_shape[0][-1])
      mask_input_dim = int(input_shape[1][-1])
    else:
      input_dim, mask_input_dim = input_shape[-1], input_shape[-1]
    # 这里实现和pytorch一样
    if self.config.HasField('reduction_factor'):
      aggregation_size = int(mask_input_dim * self.config.reduction_factor)
    elif self.config.HasField('aggregation_size') is not None:
      aggregation_size = self.config.aggregation_size
    else:
      raise ValueError('Need one of reduction factor or aggregation size for MaskBlock.')
    # instance-guide-mask第一层      
    self.aggr_layer = Dense(
        aggregation_size,
        activation='relu',
        kernel_initializer='he_uniform',
        kernel_regularizer=self.l2_reg,
        name='aggregation')
    # instance-guide-mask第二层
    self.weight_layer = Dense(input_dim, name='weights')
    # 对比pytorch实现,增加了projection_dim, 低秩矩阵(详见DCN)
    if self._projection_dim is not None:
      logging.info('%s project dim is %d', self.name, self._projection_dim)
      self.project_layer = Dense(
          self._projection_dim,
          kernel_regularizer=self.l2_reg,
          use_bias=False,
          name='project')
    if self.config.input_layer_norm:
      # 推荐在调用MaskBlock之前做好 layer norm,否则每一次调用都需要对input做ln
      if tf.__version__ >= '2.0':
        self.input_layer_norm = tf.keras.layers.LayerNormalization(
            name='input_ln')
      else:
        self.input_layer_norm = LayerNormalization(name='input_ln')

    if self.config.HasField('output_size'):
      self.output_layer = Dense(
          self.config.output_size, use_bias=False, name='output')
          
    # tensorflow遗留问题,兼容1/2
    if tf.__version__ >= '2.0':
      self.output_layer_norm = tf.keras.layers.LayerNormalization(
          name='output_ln')
    else:
      self.output_layer_norm = LayerNormalization(name='output_ln')
    super(MaskBlock, self).build(input_shape)

  def call(self, inputs, training=None, **kwargs):
    if type(inputs) in (tuple, list):
      net, mask_input = inputs[:2]
    else:
      net, mask_input = inputs, inputs
    # LN
    if self.config.input_layer_norm:
      net = self.input_layer_norm(net)
    # tensorflow实现aggregate层和projection层是分开的,上面pytorch是用一个sequence
    if self._projection_dim is None:
      aggr = self.aggr_layer(mask_input)
    else:
      u = self.project_layer(mask_input)
      aggr = self.aggr_layer(u)
   # 得到mask结果
    weights = self.weight_layer(aggr)
    # elemnet-wide product
    masked_net = net * weights

    if not self.config.HasField('output_size'):
      return masked_net
    # 最终处理,一个Liner+layer norm层
    hidden = self.output_layer(masked_net)
    ln_hidden = self.output_layer_norm(hidden)
    return self.final_relu(ln_hidden)

class MaskNet(Layer):
  def __init__(self, params, name='mask_net', reuse=None, **kwargs):
    super(MaskNet, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self.params = params
    self.config = params.get_pb_config()
    if self.config.HasField('mlp'):
      p = Parameter.make_from_pb(self.config.mlp)
      p.l2_regularizer = params.l2_regularizer
      self.mlp = MLP(p, name='mlp', reuse=reuse)
    else:
      self.mlp = None

    self.mask_layers = []
    for i, block_conf in enumerate(self.config.mask_blocks):
      params = Parameter.make_from_pb(block_conf)
      params.l2_regularizer = self.params.l2_regularizer
      mask_layer = MaskBlock(params, name='block_%d' % i, reuse=self.reuse)
      self.mask_layers.append(mask_layer)

    if self.config.input_layer_norm:
      if tf.__version__ >= '2.0':
        self.input_layer_norm = tf.keras.layers.LayerNormalization(
            name='input_ln')
      else:
        self.input_layer_norm = LayerNormalization(name='input_ln')

  def call(self, inputs, training=None, **kwargs):
    # 与pytorch版本对比,对输入也进行了一次layer norm
    if self.config.input_layer_norm:
      inputs = self.input_layer_norm(inputs)
    # 下面的并行/串行实现逻辑无差
    if self.config.use_parallel:
      mask_outputs = [
          mask_layer((inputs, inputs)) for mask_layer in self.mask_layers
      ]
      all_mask_outputs = tf.concat(mask_outputs, axis=1)
      if self.mlp is not None:
        output = self.mlp(all_mask_outputs, training=training)
      else:
        output = all_mask_outputs
      return output
    else:
      net = inputs
      for i, _ in enumerate(self.config.mask_blocks):
        mask_layer = self.mask_layers[i]
        net = mask_layer((net, inputs))

      if self.mlp is not None:
        output = self.mlp(net, training=training)
      else:
        output = net
      return output

Reference:
MaskNet: Introducing Feature-Wise Multiplication to CTR Ranking Models by Instance-Guided Mask
tesorflow实现
twitter-alg-ml

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/921196.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GRU (门控循环单元 - 基于RNN - 简化LSTM又快又好 - 体现注意力的思想) + 代码实现 —— 笔记3.5《动手学深度学习》

目录 0. 前言 1. 门控隐状态 1.1 重置门和更新门 1.2 候选隐状态 1.3 隐状态 2. 从零开始实现 2.1 初始化模型参数 2.2 定义模型 2.3 训练与预测 3 简洁实现 4. 小结 0. 前言 课程全部代码(pytorch版)已上传到附件看懂上一篇RNN的所有细节&am…

Java 基于SpringBoot+vue框架的老年医疗保健网站

大家好,我是Java徐师兄,今天为大家带来的是Java Java 基于SpringBootvue框架的老年医疗保健网站。该系统采用 Java 语言开发,SpringBoot 框架,MySql 作为数据库,系统功能完善 ,实用性强 ,可供大…

JavaWeb-表单-07

表单标签 介绍 code: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>HTML-表单</title> &…

计算机网络socket编程(4)_TCP socket API 详解

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 计算机网络socket编程(4)_TCP socket API 详解 收录于专栏【计算机网络】 本专栏旨在分享学习计算机网络的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&…

向量数据库FAISS之五:原理(LSH、PQ、HNSW、IVF)

1.Locality Sensitive Hashing (LSH) 使用 Shingling MinHashing 进行查找 左侧是字典&#xff0c;右侧是 LSH。目的是把足够相似的索引放在同一个桶内。 LSH 有很多的版本&#xff0c;很灵活&#xff0c;这里先介绍第一个版本&#xff0c;也是原始版本 Shingling one-hot …

Django启用国际化支持(2)—实现界面内切换语言:activate()

文章目录 ⭐注意⭐1. 配置项目全局设置&#xff1a;启用国际化2. 编写视图函数3. 配置路由4. 界面演示5、扩展自动识别并切换到当前语言设置语言并保存到Session设置语言并保存到 Cookie ⭐注意⭐ 以下操作依赖于 Django 项目的国际化支持。如果你不清楚如何启用国际化功能&am…

【初阶数据结构与算法】线性表之栈的定义与实现(含源码和有效的括号练习)

文章目录 一、栈的概念与结构1.栈的概念与操作2.栈的底层结构选型 二、栈的实现1.栈结构的定义2. 栈的初始化和销毁栈的初始化栈的销毁 3.栈的扩容与入栈栈的扩容入栈 4.判断栈是否为空和出栈判断栈是否为空出栈 5.取栈顶元素和获取栈中有效元素个数取栈顶元素获取栈中有效元素…

基于Spring Boot+Unipp的博物馆预约小程序(协同过滤算法、二维码识别)【原创】

&#x1f388;系统亮点&#xff1a;协同过滤算法、二维码识别&#xff1b; 一.系统开发工具与环境搭建 1.系统设计开发工具 后端使用Java编程语言的Spring boot框架 项目架构&#xff1a;B/S架构 运行环境&#xff1a;win10/win11、jdk17 前端&#xff1a; 技术&#xff1a;框…

Python 快速入门(上篇)❖ Python基础知识

Python 基础知识 Python安装**运行第一个程序:基本数据类型算术运算符变量赋值操作符转义符获取用户输入综合案例:简单计算器实现Python安装** Linux安装: yum install python36 -y或者编译安装指定版本:https://www.python.org/downloads/source/ wget https://www.pyt…

【MyBatisPlus·最新教程】包含多个改造案例,常用注解、条件构造器、代码生成、静态工具、类型处理器、分页插件、自动填充字段

文章目录 一、MyBatis-Plus简介二、快速入门1、环境准备2、将mybatis项目改造成mybatis-plus项目&#xff08;1&#xff09;引入MybatisPlus依赖&#xff0c;代替MyBatis依赖&#xff08;2&#xff09;配置Mapper包扫描路径&#xff08;3&#xff09;定义Mapper接口并继承BaseM…

【人工智能】PyTorch、TensorFlow 和 Keras 全面解析与对比:深度学习框架的终极指南

文章目录 PyTorch 全面解析2.1 PyTorch 的发展历程2.2 PyTorch 的核心特点2.3 PyTorch 的应用场景 TensorFlow 全面解析3.1 TensorFlow 的发展历程3.2 TensorFlow 的核心特点3.3 TensorFlow 的应用场景 Keras 全面解析4.1 Keras 的发展历程4.2 Keras 的核心特点4.3 Keras 的应用…

Chrome 浏览器 131 版本新特性

Chrome 浏览器 131 版本新特性 一、Chrome 浏览器 131 版本更新 1. 在 iOS 上使用 Google Lens 搜索 自 Chrome 126 版本以来&#xff0c;用户可以通过 Google Lens 搜索屏幕上看到的任何图片或文字。 要使用此功能&#xff0c;请访问网站&#xff0c;并点击聚焦时出现在地…

2.不同语音ai任务dataset类写法

主流语音任务 语音数据读取基本原则 直接保存语音会将该对象保存在内存中&#xff08;Dataset类调用__getitem__方法&#xff09; 所以一般保存这些数据的存储路径文档&#xff08;表单&#xff09;而不是数据的直接copy&#xff08;不然占用内存太大了&#xff09; 通常用nump…

K8S + Jenkins 做CICD

前言 这里会做整体CICD的思路和流程的介绍&#xff0c;会给出核心的Jenkins pipeline脚本&#xff0c;最后会演示一下 实验/实操 结果 由于整体内容较多&#xff0c;所以不打算在这里做每一步的详细演示 - 本文仅作自己的实操记录和日后回顾用 要看保姆式教学的可以划走了&…

力扣 LeetCode 701. 二叉搜索树中的插入操作(Day10:二叉树)

解题思路&#xff1a; 全部插入到叶子节点即可 class Solution {public TreeNode insertIntoBST(TreeNode root, int val) {if (root null) {TreeNode node new TreeNode(val);return node;}if (root.val < val) {root.right insertIntoBST(root.right, val);}if (root…

2024年11月22日Github流行趋势

项目名称&#xff1a;twenty 项目维护者&#xff1a;charlesBochet, lucasbordeau, Weiko, FelixMalfait, bosiraphael 项目介绍&#xff1a;正在构建一个由社区驱动的现代Salesforce替代方案。 项目star数&#xff1a;22,938 项目fork数&#xff1a;2,413 项目名称&#xff1…

Qt之QMainWidget相关

QMainWindow 继承于QWidget的子类 自带一个菜单栏,一个工具栏,可以设置状态栏与铆钉部件 菜单栏:QMenuBar 注意:一个窗口最多一个菜单栏 API: 创建 QMenuBar(parent) 获取QMainWindow自带的菜单栏 QMenuBar* menuBar() 添加菜单:QMenu addMenu(QMenu *menu); 菜单添加活动:QAct…

【深度学习之一】2024最新pytorch+cuda+cudnn下载安装搭建开发环境

兵马未动&#xff0c;粮草先行。作为深度学习的初学者&#xff0c;快速搭建一个属于自己的开发环境就是头等大事&#xff0c;可以让我们节省许多的时间。这一期我们主要讲一讲2024年最新pytorchcudacudnn下载安装搭建开发环境&#xff0c;以及安装过程中可能遇到的一些问题以及…

SQL 复杂查询

目录 复杂查询 一、目的和要求 二、实验内容 &#xff08;1&#xff09;查询出所有水果产品的类别及详情。 查询出编号为“00000001”的消费者用户的姓名及其所下订单。&#xff08;分别采用子查询和连接方式实现&#xff09; 查询出每个订单的消费者姓名及联系方式。 在…

如何在 UniApp 中实现 iOS 版本更新检测

随着移动应用的不断发展&#xff0c;保持应用程序的更新是必不可少的&#xff0c;这样用户才能获得更好的体验。本文将帮助你在 UniApp 中实现 iOS 版的版本更新检测和提示&#xff0c;适合刚入行的小白。我们将分步骤进行说明&#xff0c;每一步所需的代码及其解释都会一一列出…