【深度学习】多目标融合算法(四):多门混合专家网络MMOE(Multi-gate Mixture-of-Experts)

目录

一、引言

二、MMoE(Multi-gate Mixture-of-Experts,多门混合专家网络)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

2.3.1 业务场景与建模

2.3.2 模型代码实现

2.3.3 模型训练与推理测试

2.3.4 打印模型结构 

三、总结


一、引言

上一篇我们讲了MoE混合专家网络,通过引入Gate门控,针对不同的Input分布,对多个专家网络赋予不同的权重,解决多场景或多目标任务task的底层信息共享及个性化问题。但MoE网络对于不同的Expert专家网络,采用同一个Gate门控网络,仅对不同的Input分布实现了个性化,对不同目标任务task的个性化刻画能力不足,今天在MoE的基础上,引入MMoE网络,为每一个task任务构建专属的Gate门控网络,这样的改进可以针对不同的task得到不同的Experts权重,从而实现对Experts专家的选择利用,不同的任务task对应的gate门控网络可以学习到不同的Experts网络组合模式,更容易捕捉到不容task间的相关性和差异性。

二、MMoE(Multi-gate Mixture-of-Experts,多门混合专家网络)

2.1 技术原理

MMoE(Multi-gate Mixture-of-Experts)全称为多门混合专家网络,主要由多个专家网络、多个任务塔、多个门控网络构成。核心原理:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim;同时,样本数据分别输入目标task对应的门控网络Gate A及Gate B,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(采用softmax将输出归一化,各个维度加起来和为1);对于每一个Task,将各自对应专家网络的输出,基于对应gate门控网络的softmax加权平均,作为各自Task的输入,所有Task的输入统一维度均为output_experts_dim。在每次反向传播迭代时,对Gate A、Gate B和num_experts个专家参数进行更新,Gate A、Gate B和专家网络的参数受任务Task A、B共同影响。

  • 专家网络:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim。
  • 门控网络:样本数据分别输入目标task对应的门控网络Gate A及Gate B,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(采用softmax将输出归一化,各个维度加起来和为1)
  • 任务网络:对于每一个Task,将各自对应专家网络的输出,基于对应gate门控网络的softmax加权平均,作为各自Task的输入,所有Task的输入统一维度均为output_experts_dim。

2.2 技术优缺点

相较于MoE网络,MMoE的本质是每个task自带Gate门控网络对多个专家的预估结果进行选择,相当于给每个task安排了一个个人助理,对专家的结果进行评审(而MoE对于所有task仅有一个公共助理,对task的专属需求了解不深)。相较于MoE网络:

优点:

  • 对每个task安排专属的gate网络,在专家网络赋值时更加个性化
  • 更容易捕捉到不容task间的相关性和差异性。

缺点: 

  • MMOE中所有的Expert是被所有task共享的,这可能无法捕捉到任务之间更复杂的关系,从而给部分任务带来一定的噪声
  • 不同的Expert之间没有交互,联合优化的效果有所折扣,虽然可以缓解负迁移问题,但跷跷板现象仍然存在。

2.3 业务代码实践

2.3.1 业务场景与建模

我们还是以小红书推荐场景为例,针对一个视频,用户可以点红心(互动),也可以点击视频进行播放(点击),针对互动和点击两个目标进行多目标建模

我们构建一个100维特征输入,4个experts专家网络,2个task目标,2个门控的MMoE网络,用于建模多目标学习问题,模型架构图如下:

​​​​​​​​​​​​​​​​​​​​​

如架构图所示,其中有几个注意的点:

  • num_experts:门控gate的输出维度和专家数相同,均为num_experts,因为gate的用途是对专家网络最后一层进行加权平均,gate维度与专家数是直接对应关系。
  • output_experts_dim:专家网络的输出维度和task网络的输入维度相同,task网络承接的是专家网络各维度的加权平均值,experts网络与task网络是直接对应关系。
  • Softmax:Gate门控网络对最后一层采用Softmax归一化,保证专家网络加权平均后值域相同

2.3.2 模型代码实现

基于pytorch,实现上述网络架构,如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class MMoEModel(nn.Module):
    def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts):
        super(MMoEModel, self).__init__()
        # 初始化函数外使用初始化变量需要赋值,否则默认使用全局变量
        # 初始化函数内使用初始化变量不需要赋值 
        self.num_experts = num_experts
        self.output_experts_dim = output_experts_dim

        # 初始化多个专家网络
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, experts_hidden1_dim),
                nn.ReLU(),
                nn.Linear(experts_hidden1_dim, experts_hidden2_dim),
                nn.ReLU(),
                nn.Linear(experts_hidden2_dim, output_experts_dim),
                nn.ReLU()
            ) for _ in range(num_experts)
        ])

        # 定义任务1的输出层
        self.task1_head = nn.Sequential(
                nn.Linear(output_experts_dim, task_hidden1_dim),
                nn.ReLU(),
                nn.Linear(task_hidden1_dim, task_hidden2_dim),
                nn.ReLU(),
                nn.Linear(task_hidden2_dim, output_task1_dim),
                nn.Sigmoid()
            ) 

        # 定义任务2的输出层
        self.task2_head = nn.Sequential(
                nn.Linear(output_experts_dim, task_hidden1_dim),
                nn.ReLU(),
                nn.Linear(task_hidden1_dim, task_hidden2_dim),
                nn.ReLU(),
                nn.Linear(task_hidden2_dim, output_task2_dim),
                nn.Sigmoid()
            ) 

        # 初始化门控网络1
        self.gating1_network = nn.Sequential(
            nn.Linear(input_dim, gate_hidden1_dim),
            nn.ReLU(),
            nn.Linear(gate_hidden1_dim, gate_hidden2_dim),
            nn.ReLU(),
            nn.Linear(gate_hidden2_dim, num_experts),
            nn.Softmax(dim=1)
        )
        # 初始化门控网络2
        self.gating2_network = nn.Sequential(
            nn.Linear(input_dim, gate_hidden1_dim),
            nn.ReLU(),
            nn.Linear(gate_hidden1_dim, gate_hidden2_dim),
            nn.ReLU(),
            nn.Linear(gate_hidden2_dim, num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # 计算输入数据通过门控网络后的权重
        gates1 = self.gating1_network(x)
        gates2 = self.gating2_network(x)
        #print(gates)
        batch_size, _ = x.shape
        task1_inputs = torch.zeros(batch_size, self.output_experts_dim)
        task2_inputs = torch.zeros(batch_size, self.output_experts_dim)

        # 计算每个专家的输出并加权求和
        for i in range(self.num_experts):
            expert_output = self.experts[i](x)

            task1_inputs += expert_output * gates1[:, i].unsqueeze(1)
            task2_inputs += expert_output * gates2[:, i].unsqueeze(1)

        task1_outputs = self.task1_head(task1_inputs)
        task2_outputs = self.task2_head(task2_inputs)

        return task1_outputs, task2_outputs


# 实例化模型对象
num_experts = 4  # 假设有4个专家
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 1
output_task2_dim = 1

# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 100
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)  # 假设任务1的输出维度为1
y_train_task2 = torch.rand(num_samples, output_task2_dim)  # 假设任务2的输出维度为1

# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

model = MMoEModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts)

# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):
        # 前向传播: 获取预测值
        #print(batch_idx, X_batch )
        #print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')
        outputs_task1, outputs_task2 = model(X_batch)

        # 计算每个任务的损失
        loss_task1 = criterion_task1(outputs_task1, y_task1_batch)
        loss_task2 = criterion_task2(outputs_task2, y_task2_batch)

        total_loss = loss_task1 + loss_task2

        # 反向传播和优化
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
    if epoch % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

print(model)
#for param_tensor in model.state_dict():
#    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型预测
model.eval()
with torch.no_grad():
    test_input = torch.randint(0, 2, (1, input_dim)).float()  # 构造一个测试样本
    pred_task1, pred_task2 = model(test_input)

    print(f'互动目标预测结果: {pred_task1}')
    print(f'点击目标预测结果: {pred_task2}')

相比于上一篇MoE中的代码,MMoE初始化了gating1_network和gating2_network两个门控网络,在forward前向传播网络结构定义中,两个gate分别以input为输入,通过多层MLP后得到task相对应的加权平均权重。

2.3.3 模型训练与推理测试

运行上述代码,模型启动训练,Loss逐渐收敛,测试结果如下:

2.3.4 打印模型结构 ​​​​​​​

三、总结

本文详细介绍了MMoE多任务模型的算法原理、算法优势,并以小红书业务场景为例,构建网络结构并使用pytorch代码实现对应的网络结构、训练流程。相比于MoE,MMoE可以更好的学习不同Task任务的相关性和差异性。是深度学习推荐系统中多目标或多场景类问题中必须掌握的根基模型。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

【深度学习】多目标融合算法(三):混合专家网络MOE(Mixture-of-Experts) 

 【深度学习】多目标融合算法(四):多门混合专家网络MMOE(Multi-gate Mixture-of-Experts)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/969437.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

sqli-labs靶场实录(四): Challenges

sqli-labs靶场实录: Challenges Less54确定字段数获取数据库名获取表名获取列名提取密钥值 Less55Less56Less57Less58爆库构造爆表构造爆列构造密钥提取构造 Less59Less60Less61Less62爆库构造 Less63Less64Less65免责声明: Less54 本关开始上难度了 可以看到此关仅…

使用Redis实现分布式锁,基于原本单体系统进行业务改造

一、单体系统下&#xff0c;使用锁机制实现秒杀功能&#xff0c;并限制一人一单功能 1.流程图&#xff1a; 2.代码实现&#xff1a; Service public class VoucherOrderServiceImpl extends ServiceImpl<VoucherOrderMapper, VoucherOrder> implements IVoucherOrderSe…

Python + WhisperX:解锁语音识别的高效新姿势

大家好&#xff0c;我是烤鸭&#xff1a; 最近在尝试做视频的质量分析&#xff0c;打算利用asr针对声音判断是否有人声&#xff0c;以及识别出来的文本进行进一步操作。asr看了几个开源的&#xff0c;最终选择了openai的whisper&#xff0c;后来发现性能不行&#xff0c;又换了…

【Linux】Ubuntu Linux 系统——Node.js 开发环境

ℹ️大家好&#xff0c;我是练小杰&#xff0c;今天星期五了&#xff0c;同时也是2025年的情人节&#xff0c;今晚又是一个人的举个爪子&#xff01;&#xff01; &#x1f642; 本文是有关Linux 操作系统中 Node.js 开发环境基础知识&#xff0c;后续我将添加更多相关知识噢&a…

Oracle查看执行计划

方式一&#xff08;查看的真实的使用到的索引&#xff09; 1.执行解释计划 2.查看结果 可以看到使用了RANGE SCAN范围扫描的索引 方式二&#xff08;查看的是预测的可能会用到的索引&#xff09; 1.执行解释计划sql explain plan for select * from COURSE where COURSE_…

百度 AI开源!将在6月30日开源文心大模型4.5系列

【大力财经】直击互联网最前线&#xff1a;百度近期动作频频&#xff0c;先是宣布将在未来数月陆续推出文心大模型4.5系列&#xff0c;并于6月30日正式开源。 据大力财经了解&#xff0c;自DeepSeek开源之风盛行全球后&#xff0c;开源闭源路径的选择就成为AI领域的热门话题&a…

【DDD系列-2】风暴出的领域模型

为什么使用DDD​ 三个问题​ 1.为什么我们的系统越做越多&#xff0c;越来越庞大&#xff0c;还需要不断的重构&#xff1f;​ 2.为什么我们的系统业务越来越复杂&#xff0c;服务层的代码越来越多难以维护&#xff0c;不敢维护&#xff1f;​ 3.为什么一旦业务变化或者数据…

基于YALMIP和cplex工具箱的微电网最优调度算法matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 系统建模 4.2 YALMIP工具箱 4.3 CPLEX工具箱 5.完整工程文件 1.课题概述 基于YALMIP和cplex工具箱的微电网最优调度算法matlab仿真。通过YALMIP和cplex这两个工具箱&#xff0c;完成微电网的最优调…

visual studio导入cmake项目后打开无法删除和回车

通过Cmakelists.txt导入的项目做删除和回车无法响应&#xff0c;需要点击项目&#xff0c;然后选择配置项目就可以了

npm安装时无法访问github域名的解决方法

个人博客地址&#xff1a;npm安装时无法访问github域名的解决方法 | 一张假钞的真实世界 今天在用npm install的时候出现了github项目访问不了的异常&#xff1a; npm ERR! Error while executing: npm ERR! /bin/git ls-remote -h -t https://github.com/nhn/raphael.git np…

解锁豆瓣高清海报(三)从深度爬虫到URL构造,实现极速下载

脚本地址: 项目地址: Gazer PosterBandit_v2.py 前瞻 之前的 PosterBandit.py 是按照深度爬虫的思路一步步进入海报界面来爬取, 是个值得学习的思路, 但缺点是它爬取慢, 仍然容易碰到豆瓣的 418 错误, 本文也会指出彻底解决旧版 418 错误的方法并提高爬取速度. 现在我将介绍…

一维差分算法篇:高效处理区间加减

那么在正式介绍我们的一维差分的原理前&#xff0c;我们先来看一下一维差分所应用的一个场景&#xff0c;那么假设我们现在有一个区间为[L,R]的一个数组&#xff0c;那么我要在这个数组中的某个子区间比如[i,m] (L<i<m<R)进行一个加k值或者减去k值的一个操作&#xff…

信息收集-Web应用JS架构URL提取数据匹配Fuzz接口WebPack分析自动化

知识点&#xff1a; 1、信息收集-Web应用-JS提取分析-人工&插件&项目 2、信息收集-Web应用-JS提取分析-URL&配置&逻辑 FUZZ测试 ffuf https://github.com/ffuf/ffuf 匹配插件 Hae https://github.com/gh0stkey/HaE JS提取 JSFinder https://github.com/Threez…

Python基础语法精要

文章目录 一、Python的起源二、Python的用途三、Python的优缺点优点缺点 四、基础语法&#xff08;1&#xff09;常量和表达式&#xff08;2&#xff09;变量变量的语法&#xff08;i&#xff09;定义变量&#xff08;ii&#xff09;变量命名的规则 &#xff08;3&#xff09;变…

测试方案整理

搜索引擎放在那里&#xff1f;研发 查看问题样本或者在提取再批量入录等情况&#xff0c;一旦我没有勾选或者全选中已经批量入录的样本&#xff0c;那么在直接点击批量提取或查看问题样本的后&#xff0c;会自动默认为选择全选样本还是按照输入错误处理&#xff1f; 批量查看返…

开启对话式智能分析新纪元——Wyn商业智能 BI 携手Deepseek 驱动数据分析变革

2月18号&#xff0c;Wyn 商业智能 V8.0Update1 版本将重磅推出对话式智能分析&#xff0c;集成Deepseek R1大模型&#xff0c;通过AI技术的深度融合&#xff0c;致力于打造"会思考的BI系统"&#xff0c;让数据价值触手可及&#xff0c;助力企业实现从数据洞察到决策执…

政策赋能科技服务,CES Asia 2025将展北京科技新貌

近日&#xff0c;《北京市支持科技服务业高质量发展若干措施》正式印发&#xff0c;为首都科技服务业的腾飞注入了强大动力。 该《若干措施》提出了三方面14条政策措施。在壮大科技服务业市场主体方面&#xff0c;不仅支持科技服务业企业向平台化和综合性服务机构发展&#xf…

2024春秋杯网络安全联赛冬季赛wp

web flask 根据题目描述&#xff0c;很容易想到ssti注入 测试一下 确实存在 直接打payload {{lipsum.globals[‘os’].popen(‘cat f*’).read()}} file_copy 看到题目名字为file_copy&#xff0c; 当输入路径时会返回目标文件的大小&#xff0c; 通过返回包&#xff0c…

Mac os部署本地deepseek+open UI界面

一.部署本地deepseek 使用ollama部署&#xff0c;方便快捷。 ollama介绍&#xff1a;Ollama 是一个高效、便捷的人工智能模型服务平台&#xff0c;提供多样化的预训练模型&#xff0c;涵盖自然语言处理、计算机视觉、语音识别等领域&#xff0c;并支持模型定制和微调。其简洁…

分布式技术

一、为什么需要分布式技术&#xff1f; 1. 科学技术的发展推动下 应用和系统架构的变迁&#xff1a;单机单一架构迈向多机分布式架构 2. 数据大爆炸&#xff0c;海量数据处理场景面临问题 二、分布式系统概述 三、分布式、集群 四、负载均衡、故障转移、伸缩性 负载均衡&a…