【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

 

目录

一、引言

1.1 往期回顾

1.2 本期概要

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

三、总结


一、引言

在朴素的深度学习ctr预估模型中(如DNN),通常以一个行为为预估目标,比如通过ctr预估点击率。但实际推荐系统业务场景中,更多是多种目标融合的结果,比如视频推荐,会存在视频点击率、视频完整播放率、视频播放时长等多个目标,而多种目标如何更好的融合,在工业界与学术界均有较多内容产出,由于该环节对实际业务影响最为直接,特开此专栏对推荐系统深度学习多目标问题进行讲述。

1.1 往期回顾

上一篇文章主要介绍了推荐系统多目标算法中的“样本Loss加权”,该方法在训练时Loss乘以样本权重实现对多种目标的加权,通过引导Loss梯度的学习方向,让模型参数朝着你设定的权重方向去学习。

1.2 本期概要

今天进一步深化,主要介绍Shared-Bottom Multi-task Model算法,该算法中文可译为“底部共享多任务模型”,该算法设定多个任务,每个任务设定多个目标,通过“Loss计算时调整每个任务的权重”,亦或是“每个塔单元内,多目标Loss计算时调整每个目标的权重”进行多任务多目标的调整。

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

Shared-Bottom Multi-task Model(SBMM)全称为底层共享多任务模型,主要由底层共享网络、多任务塔、多目标输出构成。核心原理:通过构造多任务多目标样本数据,在Loss计算环节,将各任务Loss求和(或加权求和),对Loss求导(求梯度)后,逐步后向传播迭代。

  1. 底部网络:Shared-Bottom 网络通常位于底部,可以为一个DNN网络,或者emb+pooling+mlp的方式对input输入的稀疏(sparse)特征进行稠密(dense)化。
  2. ​​​​​​​多个任务塔:底部网络上层接N个任务塔(Tower),每个塔根据需要可以定义为简单或复杂的多层感知器(mlp)网络。每个塔可以对应特定的场景,比如一二级页面场景。
  3. 多个目标:每个任务塔(Tower)可以输出多个学习目标,每个学习目标还可以像上一篇文章一样进行样本Loss加权。每个目标可以对应一种特定的指标行为,比如点击、时长、下单等。

2.2 技术优缺点

相比于上一篇文章提到的样本Loss加权融合法,以及后续文章将会介绍的MoE、MMoE方法,有如下优缺点:

优点:

  • 可以对多级场景任务进行建模,使得ctcvr等点击后转化问题可以被深度学习
  • 浅层参数共享,互相补充学习,任务相关性越高,模型的loss可以降低到更低

缺点: 

  • 跷跷板问题:任务没有好的相关性时,这种Hard parameter sharing会损害效果

2.3 业务代码实践

我们以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

跨场景多目标建模:我们定义一个SBMM算法结构,底层是一个3层的MLP(64,32,16),MLP出来后接一级场景Tower和二级场景Tower,一级场景任务中分别定义视频一级页“是否停留”、“停留时长”、“是否点击”,二级场景任务中分别定义“点击后播放时长”,“播放后是否点赞”

伪代码:

导入 pytorch 库
定义 SharedBottomMultiTaskModel 类 继承自 nn.Module:
    定义 __init__ 方法 参数 (self, 输入维度, 隐藏层1大小, 隐藏层2大小, 隐藏层3大小, 输出任务1维度, 输出任务2维度):
        初始化共享底部的三层全连接层
        初始化任务1的三层全连接层
        初始化任务2的三层全连接层

    定义 forward 方法 参数 (self, 输入数据):
        计算输入数据通过共享底部后的输出
        从共享底部输出分别计算任务1和任务2的结果
        返回任务1和任务2的结果

生成虚拟样本数据:
    创建训练集和测试集

实例化模型对象
定义损失函数和优化器
训练循环:
    前向传播: 获取预测值
    计算每个任务的损失
    反向传播和优化

PyTorch版本:

算法逻辑

  1. 导入必要的库。
  2. 定义一个类来表示共享底部和特定任务头部的模型结构。
  3. 在初始化方法中定义共享底部和两个独立的任务头部网络层。
  4. 实现前向传播函数,处理输入数据通过共享底部后分发到不同的任务头部。
  5. 生成虚拟样本数据。
  6. 定义损失函数和优化器。
  7. 编写训练循环。
  8. 进行模型预测。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SharedBottomMultiTaskModel(nn.Module):
    def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_task1_dim, output_task2_dim):
        super(SharedBottomMultiTaskModel, self).__init__()
        # 定义共享底部的三层全连接层
        self.shared_bottom = nn.Sequential(
            nn.Linear(input_dim, hidden1_dim),
            nn.ReLU(),
            nn.Linear(hidden1_dim, hidden2_dim),
            nn.ReLU(),
            nn.Linear(hidden2_dim, hidden3_dim),
            nn.ReLU()
        )

        # 定义任务1的三层全连接层
        self.task1_head = nn.Sequential(
            nn.Linear(hidden3_dim, hidden2_dim),
            nn.ReLU(),
            nn.Linear(hidden2_dim, output_task1_dim)
        )

        # 定义任务2的三层全连接层
        self.task2_head = nn.Sequential(
            nn.Linear(hidden3_dim, hidden2_dim),
            nn.ReLU(),
            nn.Linear(hidden2_dim, output_task2_dim)
        )

    def forward(self, x):
        # 计算输入数据通过共享底部后的输出
        shared_output = self.shared_bottom(x)

        # 从共享底部输出分别计算任务1和任务2的结果
        task1_output = self.task1_head(shared_output)
        task2_output = self.task2_head(shared_output)

        return task1_output, task2_output

# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
task1_dim = 3
task2_dim = 2
num_samples = 1000
X_train = torch.randn(num_samples, input_dim)
y_train_task1 = torch.randn(num_samples, task1_dim)  # 假设任务1的输出维度为task1_dim
y_train_task2 = torch.randn(num_samples, task2_dim)  # 假设任务2的输出维度为task2_dim

# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 实例化模型对象
model = SharedBottomMultiTaskModel(input_dim, 64, 32, 16, task1_dim, task2_dim)

# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):
        # 前向传播: 获取预测值
        outputs_task1, outputs_task2 = model(X_batch)

        # 计算每个任务的损失
        loss_task1 = criterion_task1(outputs_task1, y_task1_batch)
        loss_task2 = criterion_task2(outputs_task2, y_task2_batch)
        #print(f'loss_task1:{loss_task1},loss_task2:{loss_task2}')
        total_loss = loss_task1 + loss_task2

        # 反向传播和优化
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

# 模型预测
model.eval()
with torch.no_grad():
    test_input = torch.randn(1, input_dim)  # 构造一个测试样本
    pred_task1, pred_task2 = model(test_input)

    print(f'任务1预测结果: {pred_task1}')
    print(f'任务2预测结果: {pred_task2}')

三、总结

本文从技术原理、技术优缺点方面对推荐系统深度学习多任务多目标“Shared-Bottom Multi-task Model”算法进行讲解,该模型使用深度学习模型对多个任务场景多个目标的业务问题进行建模,使得用户在多个场景连续性行为可以被学习,在现实推荐系统业务中是比较基础的方法,后面本专栏还会陆续介绍MoE、MMoE等多任务多目标算法,期待您的关注和支持。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/951159.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

探秘MetaGPT:革新软件开发的多智能体框架(22/30)

一、MetaGPT 引发的 AI 变革浪潮 近年来,人工智能大模型领域取得了令人瞩目的进展,GPT-3、GPT-4、PaLM 等模型展现出了惊人的自然语言处理能力,仿佛为 AI 世界打开了一扇通往无限可能的大门。它们能够生成流畅的文本、回答复杂的问题、进行创…

LabVIEW软件Bug的定义与修改

在LabVIEW软件开发过程中,bug(程序错误或缺陷)指的是程序中导致不符合预期行为的任何问题。Bug可能是由于编码错误、逻辑漏洞、硬件兼容性问题、系统资源限制等因素引起的。它可能会导致程序崩溃、功能无法正常执行或输出结果不符合预期。理解…

高性能网络模式:Reactor 和 Proactor

Reactor Reactor 采用I/O多路复用监听事件,收到事件后,根据事件类型分配给某个进程/线程。 实际应用中用到的模型: 单 Reactor 单进程 单 Reactor 多线程 优点:能充分利用多核CPU性能。 缺点:存在多线程竞争共享资源…

扩散模型论文概述(三):Stability AI系列工作【学习笔记】

视频链接:扩散模型论文概述(三):Stability AI系列工作_哔哩哔哩_bilibili 本期视频讲的是Stability AI在图像生成的工作。 同样,第一张图片是神作,总结的太好了! 介绍Stable Diffusion之前&…

大数据技术-Hadoop(四)Yarn的介绍与使用

目录 一、Yarn 基本结构 1、Yarn基本结构 2、Yarn的工作机制 二、Yarn常用的命令 三、调度器 1、Capacity Scheduler(容量调度器) 1.1、特点 1.2、配置 1.2.1、yarn-site.xml 1.2.2、capacity-scheduler.xml 1.3、重启yarn、刷新队列 测试 向hi…

玩转大语言模型——ollama导入huggingface下载的模型

ollama导入huggingface模型 前言gguf模型查找相关模型下载模型 导入Ollama配置参数文件导入模型查看导入情况 safetensfors模型下载模型下载llama.cpp配置环境并转换 前言 ollama在大语言模型的应用中十分的方便,但是也存在一定的问题,比如不能使用自己…

apollo内置eureka dashboard授权登录

要确保访问Eureka Server时要求输入账户和密码,需要确保以下几点: 确保 eurekaSecurityEnabled 配置为 true:这个配置项控制是否启用Eureka的安全认证。如果它被设置为 false,即使配置了用户名和密码,也不会启用安全认…

【Dify】Dify自定义模型设置 | 对接DMXAPI使用打折 Openai GPT 或 Claude3.5系列模型方法详解

一、Dify & DMXAPI 1、Dify DIFY(Do It For You)是一种自动化工具或服务,旨在帮助用户简化操作,减少繁琐的手动操作,提升工作效率。通过DIFY,用户能够快速完成任务、获取所需数据,并且可以…

【深度学习】布匹寻边:抓边误差小于3px【附完整链接】

布匹寻边 项目简介 布匹寻边是指布料裁剪过程中,通过AI寻边技术自动识别布匹的边缘,将检测到的边缘信息输出,确保裁剪的准确性,减少浪费,并提高生产效率。 项目需求 将打满针眼的布匹边缘裁剪掉,且误差小…

http range 下载大文件分片

摘自:https://www.jianshu.com/p/32c16103715a 上传分片下载也能分 HTTP 协议范围请求允许服务器只发送 HTTP 消息的一部分到客户端。范围请求在传送大的媒体文件,或者与文件下载的断点续传功能搭配使用时非常有用。 检测服务器端是否支持范围请求 假…

解决WordPress出现Fatal error: Uncaught TypeError: ftp_nlist()致命问题

错误背景 WordPress版本:wordpress-6.6.2-zh_CN WooCommerce版本:woocommerce.9.5.1 WordPress在安装了WooCommerce插件后,安装的过程中没有问题,在安装完成后提示: 此站点遇到了致命错误,请查看您站点管理…

用户使用LLM模型都在干什么?

Anthropic 对用户与 Claude 3.5 Sonnet 的大量匿名对话展开分析,主要发现及相关情况如下: 使用用途分布 软件开发主导:在各类使用场景中,软件开发占比最高,其中编码占 Claude 对话的 15% - 25%,网页和移动应…

【巨实用】Git客户端基本操作

本文主要分享Git的一些基本常规操作,手把手教你如何配置~ ● 一个文件夹中初始化Git git init ● 为了方便以后提交代码需要对git进行配置(第一次使用或者需求变更的时候),告诉git未来是谁在提交代码 git config --global user.na…

腾讯云AI代码助手编程挑战赛:自动生成漂亮的网页

在当今数字化时代,网页设计和开发已经成为一项至关重要的技能。在当今时代,借助AI的力量,这部分工作变得简单。本文借助腾讯云AI代码助手——“自动生成需要的网页”。本文将详细介绍如何利用AI代码助手生成网页素材,帮助你轻松打…

多台PC共用同一套鼠标键盘

当环境中有多个桌面 pc 需要操作的时候,在 多台 pc 之间切换会造成很多的不方便 可以通过远程进行连接,但是有一个更好的方案是让多台机器之间共用同一套键盘鼠标 常用的解决方案 synergy 和 sharemouse,通过移动光标在不同的 pc 间切换 s…

UOS系统mysql服务安装

UOS系统mysql服务安装 背景 1、安装环境:kvm虚拟机2、运行环境:uos server-1060e3、架构:x864、安装mysql版本:mysql-5.71、安装准备 # Mysql官网 https://downloads.mysql.com/archives/community/ # 下载安装包 wget -i -c …

Binlog实现MySQL主从同步

主从复制原理 ● Master 数据库只要发生变化,立马记录到Binary log 日志文件中 ● Slave数据库启动一个I/O thread连接Master数据库,请求Master变化的二进制日志 ● Slave I/O获取到的二进制日志,保存到自己的Relay log 日志文件中。 ● Sla…

matlab离线安装硬件支持包

MATLAB 硬件支持包离线安装 本文章提供matlab硬件支持包离线安装教程,因为我的matlab安装的某种原因(破解),不支持硬件支持包的安装,相信也有很多相同情况的朋友,所以记录一下我是如何离线安装的&#xff…

C#进阶-在Ubuntu上部署ASP.NET Core Web API应用

随着云计算和容器化技术的普及,Linux 服务器已成为部署 Web 应用程序的主流平台之一。ASP.NET Core 作为一个跨平台、高性能的框架,非常适合在 Linux 环境中运行。本篇博客将详细介绍如何在 Linux 服务器上部署 ASP.NET Core Web API 应用,包…

从光子到图像——相机如何捕获世界?

引言 你是否想过为何我们按一下相机快门就可以将眼前广袤多彩的世界显示于一个小小的相机屏幕上?本期推文中将带着大家重现从光子转换为电子、电子转换为图像中数字驱动值的整个流程。 ▲人们通过相机捕获眼前的场景 从光子到电子的转换 光线首先通过光学镜头进入相…