传知代码-图神经网络长对话理解(论文复现)

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

概述

情感识别是人类对话理解的关键任务。随着多模态数据的概念,如语言、声音和面部表情,任务变得更加具有挑战性。作为典型解决方案,利用全局和局部上下文信息来预测对话中每个单个句子(即话语)的情感标签。具体来说,全局表示可以通过对话级别的跨模态交互建模来捕获。局部表示通常是通过发言者的时间信息或情感转变来推断的,这忽略了话语级别的重要因素。此外,大多数现有方法在统一输入中使用多模态的融合特征,而不利用模态特定的表示。针对这些问题,我们提出了一种名为“关系时序图神经网络与辅助跨模态交互(CORECT)”的新型神经网络框架,它以模态特定的方式有效捕获了对话级别的跨模态交互和话语级别的时序依赖,用于对话理解。大量实验证明了CORECT的有效性,通过在IEMOCAP和CMUMOSEI数据集上取得了多模态ERC任务的最新成果。

模型整体架构

在这里插入图片描述

特征提取

文本采用transformerde方式进行编码
在这里插入图片描述

音频,视频都采用全连接的方式进行编码
在这里插入图片描述

通过添加相应的讲话者嵌入来增强技术增强
在这里插入图片描述

关系时序图卷积网络(RT-GCN)

解读:RT-GCN旨在通过利用话语之间以及话语与其模态之间的多模态图来捕获对话中每个话语的局部上下文信息,关系时序图在一个模块中同时实现了上下文信息,与模态之间的信息的传递。对话中情感识别需要跨模态学习到信息,同时也需要学习上下文的信息,整合成一个模块的作用将两部分并行处理,降低模型的复杂程度,降低训练成本,降低训练难度。

建图方式,模态与模态之间有边相连,对话之间有边相连:

在这里插入图片描述

建图之后,用图transformer融合不同模态,以及不同语句的信息,得到处理之后特征向量:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

两两交叉模态特征交互

跨模态的异质性经常提高了分析人类语言的难度。利用跨模态交互可能有助于揭示跨模态之间的“不对齐”特性和长期依赖关系。受到这一思想的启发(Tsai等人,2019),我们将配对的跨模态特征交互(P-CM)方法设计到我们提出的用于对话理解的框架中。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

线性分类器

最后就是根据提取出来的特征进行情感分类了:

在这里插入图片描述

代码修改

这是对话中多模态情感识别(视觉,音频,文本)在数据集IEMOCAP目前为止的SOTA。在离线系统已经取得了相当不错的表现。(离线系统的意思是,是一段已经录制好的视频,而不是事实录制如线上开会)

但是却存在一个问题,输入的数据是已经给定的一个视频,分析某一句话的情感状态的时候,论文的方法使用了过去的信息,也使用了未来的信息,这样会在工业界实时应用场景存在一定的问题。

比如在开线上会议,需要检测开会双方的情绪,不可能用未来将要说的话预测现在的情绪。因为未来的话都还没被说话者说出来,此时,就不能参考到未来的语句来预测现在语句的情感信息。但是原文的方法在数据结构图的构建的时候,连接上了未来语句和现在语句的边,用图神经网络学习了之间的关联。

因此,修改建图方式,不考虑未来的情感信息,重新训练网络,得到了还可以接受的效果,精度大概在82%左右,原文的精度在84%左右,2%精度的牺牲解决了是否能实时的问题其实是值得的。

演示效果

在这里插入图片描述
在这里插入图片描述

核心逻辑

在这里可以粘贴您的核心代码逻辑:

# start

#模型核心部分

import torch
import torch.nn as nn
import torch.nn.functional as F

from .Classifier import Classifier
from .UnimodalEncoder import UnimodalEncoder
from .CrossmodalNet import CrossmodalNet
from .GraphModel import GraphModel
from .functions import multi_concat, feature_packing
import corect

log = corect.utils.get_logger()

class CORECT(nn.Module):
    def __init__(self, args):
        super(CORECT, self).__init__()

        self.args = args
        self.wp = args.wp
        self.wf = args.wf
        self.modalities = args.modalities
        self.n_modals = len(self.modalities)
        self.use_speaker = args.use_speaker
        g_dim = args.hidden_size
        h_dim = args.hidden_size

        ic_dim = 0
        if not args.no_gnn:
            ic_dim = h_dim * self.n_modals

            if not args.use_graph_transformer and (args.gcn_conv == "gat_gcn" or args.gcn_conv == "gcn_gat"):
                ic_dim = ic_dim * 2

            if args.use_graph_transformer:
                ic_dim *= args.graph_transformer_nheads
        
        if args.use_crossmodal and self.n_modals > 1:
            ic_dim += h_dim * self.n_modals * (self.n_modals - 1)

        if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):
            ic_dim = h_dim * self.n_modals

        
        a_dim = args.dataset_embedding_dims[args.dataset]['a']
        t_dim = args.dataset_embedding_dims[args.dataset]['t']
        v_dim = args.dataset_embedding_dims[args.dataset]['v']
        
        dataset_label_dict = {
            "iemocap": {"hap": 0, "sad": 1, "neu": 2, "ang": 3, "exc": 4, "fru": 5},
            "iemocap_4": {"hap": 0, "sad": 1, "neu": 2, "ang": 3},
            "mosei": {"Negative": 0, "Positive": 1},
        }

        dataset_speaker_dict = {
            "iemocap": 2,
            "iemocap_4": 2,
            "mosei":1,
        }
        
        
        tag_size = len(dataset_label_dict[args.dataset])
        self.n_speakers = dataset_speaker_dict[args.dataset]

        self.wp = args.wp
        self.wf = args.wf
        self.device = args.device


        self.encoder = UnimodalEncoder(a_dim, t_dim, v_dim, g_dim, args)
        self.speaker_embedding = nn.Embedding(self.n_speakers, g_dim)

        print(f"{args.dataset} speakers: {self.n_speakers}")
        if not args.no_gnn:
            self.graph_model = GraphModel(g_dim, h_dim, h_dim, self.device, args)
            print('CORECT --> Use GNN')

        if args.use_crossmodal and self.n_modals > 1:
            self.crossmodal = CrossmodalNet(g_dim, args)
            print('CORECT --> Use Crossmodal')
        elif self.n_modals == 1:
            print('CORECT --> Crossmodal not available when number of modalitiy is 1')

        self.clf = Classifier(ic_dim, h_dim, tag_size, args)

        self.rlog = {}


    def represent(self, data):

        # Encoding multimodal feature
        a = data['audio_tensor'] if 'a' in self.modalities else None
        t = data['text_tensor'] if 't' in self.modalities else None
        v = data['visual_tensor'] if 'v' in self.modalities else None

        a, t, v = self.encoder(a, t, v, data['text_len_tensor'])


        # Speaker embedding
        if self.use_speaker:
            emb = self.speaker_embedding(data['speaker_tensor'])
            a = a + emb if a != None else None
            t = t + emb if t != None else None
            v = v + emb if v != None else None

        # Graph construct
        multimodal_features = []

        if a != None:
            multimodal_features.append(a)
        if t != None:
            multimodal_features.append(t)
        if v != None:
            multimodal_features.append(v)

        out_encode = feature_packing(multimodal_features, data['text_len_tensor'])
        out_encode = multi_concat(out_encode, data['text_len_tensor'], self.n_modals)

        out = []

        if not self.args.no_gnn:
            out_graph = self.graph_model(multimodal_features, data['text_len_tensor'])
            out.append(out_graph)


        if self.args.use_crossmodal and self.n_modals > 1:
            out_cr = self.crossmodal(multimodal_features)

            out_cr = out_cr.permute(1, 0, 2)
            lengths = data['text_len_tensor']
            batch_size = lengths.size(0)
            cr_feat = []
            for j in range(batch_size):
                cur_len = lengths[j].item()
                cr_feat.append(out_cr[j,:cur_len])

            cr_feat = torch.cat(cr_feat, dim=0).to(self.device)
            out.append(cr_feat)
        
        if self.args.no_gnn and (not self.args.use_crossmodal or self.n_modals == 1):
            out = out_encode
        else:
            out = torch.cat(out, dim=-1)

        return out

    def forward(self, data):
        graph_out = self.represent(data)
        out = self.clf(graph_out, data["text_len_tensor"])

        return out
    
    def get_loss(self, data):
        graph_out = self.represent(data)
        loss = self.clf.get_loss(
                graph_out, data["label_tensor"], data["text_len_tensor"])
        
        return loss

    def get_log(self):
        return self.rlog


        

#图神经网络
import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv, TransformerConv

import corect

class GNN(nn.Module):
    def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, args):
        super(GNN, self).__init__()
        self.args = args

        self.num_modals = num_modals
        
        if args.gcn_conv == "rgcn":
            print("GNN --> Use RGCN")
            self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)

        if args.use_graph_transformer:
            print("GNN --> Use Graph Transformer")
           
            in_dim = h1_dim
                
            self.conv2 = TransformerConv(in_dim, h2_dim, heads=args.graph_transformer_nheads, concat=True)
            self.bn = nn.BatchNorm1d(h2_dim * args.graph_transformer_nheads)
            

    def forward(self, node_features, node_type, edge_index, edge_type):

        if self.args.gcn_conv == "rgcn":
            x = self.conv1(node_features, edge_index, edge_type)
        
        if self.args.use_graph_transformer:
            x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))
        
        return x

使用方式&部署方式

首先建议安装conda,因为想要复现深度学习的代码,github上不同项目的环境差别太大,同时处理多个项目的时候很麻烦,在这里就不做conda安装的教程了,请自行学习。

安装pytorch:
请到pytorch官网找安装命令,尽量不要直接pip install
https://pytorch.org/get-started/previous-versions/

给大家直接对着我安装版本来下载,因为图神经网络的包版本要求很苛刻,版本对应不上很容易报错:
在这里插入图片描述
在这里插入图片描述

只要环境配置好了,找到这个文件,里面的代码粘贴到终端运行即可
在这里插入图片描述

温馨提示

1.数据集和已训练好的模型都在.md文件中有百度网盘链接,直接下载放到指定文件夹即可
2.注意,训练出来的模型是有硬件要求的,我是用cpu进行训练的,模型只能在cpu跑,如果想在gpu上跑,请进行重新训练
3.如果有朋友希望用苹果的gpu进行训练,虽然现在pytorch框架已经支持mps(mac版本的cuda可以这么理解)训练,但是很遗憾,图神经网络的包还不支持,不过不用担心,这个模型的训练量很小,我全程都是苹果笔记本完成训练的。

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/789778.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据库-ubuntu环境下安装配置mysql

文章目录 什么是数据库?一、ubuntu环境下安装mysql二、配置mysql配置文件1.先登上root账号2.配置文件的修改 mysql和mysqld数据库的基础操作登录mysql创建数据库显示当前数据库使用数据库创建表插入students表数据打印students表数据select * from students; ![在这…

响应式布局下关于gird栅格布局的一些构思

1、传列数,根据列数计算元素容器宽度 好处是子元素可以写百分比宽度,不用固定某一种宽度,反正知道列数通过计算间距就能得到外层容器的宽度。 举个简单的例子: (ps:以下用例皆在html中去模拟,就不另外起r…

零基础STM32单片机编程入门(十二) HC-SR04超声波模块测距实战含源码

文章目录 一.概要二.HC-SR04主要参数1.模块引脚定义2.模块电气参数3.模块通讯时序4.模块原理图 三.STM32单片机超声波模块测距实验四.CubeMX工程源代码下载五.小结 一.概要 HC-SR04超声波模块常用于机器人避障、物体测距、液位检测、公共安防、停车场检测等场所。HC-SR04超声波…

宏碁F5-572G-59K3笔记本笔记本电脑拆机清灰教程(详解)

1. 前言 我的笔记本开机比较慢,没有固态,听说最近固态比较便宜,就想入手一个,于是拆笔记本看一下有没有可以安的装位置。(友情提示,在拆机之前记得洗手并擦干,以防静电损坏电源器件&#xff09…

视频号的视频,一键就下载了,方法全在这儿了!

居然还有人不知道:视频号里面的视频是没有地址的,只能有微信自带的浏览器中打开。 所以很多人在视频号找到想要的素材,却无法下载,表示很苦恼。 几天每天都有人群里求助:“求好心人帮我下载一下这个视频!…

昇思学习打卡-14-ResNet50迁移学习

文章目录 数据集可视化预训练模型的使用部分实现 推理 迁移学习:在一个很大的数据集上训练得到一个预训练模型,然后使用该模型来初始化网络的权重参数或作为固定特征提取器应用于特定的任务中。本章学习使用的是前面学过的ResNet50,使用迁移学…

STM32 GPIO的工作原理

STM32的GPIO管脚有下面8种可能的配置:(4输入 2 输出 2 复用输出) (1)浮空输入_IN_FLOATING 在上图上,阴影的部分处于不工作状态,尤其是下半部分的输出电路,实际上是与端口处于隔离状态。黄色的高亮部分显示…

在攻防演练中遇到的一个“有马蜂的蜜罐”

在攻防演练中遇到的一个“有马蜂的蜜罐” 有趣的结论,请一路看到文章结尾 在前几天的攻防演练中,我跟队友的气氛氛围都很好,有说有笑,恐怕也是全场话最多、笑最多的队伍了。 也是因为我们遇到了许多相当有趣的事情,其…

Mac窗口辅助管理工具:Magnet for mac激活版

magnet mac版是一款运行在苹果电脑上的一款优秀的窗口大小控制工具,拖拽窗口到屏幕边缘可以自动半屏,全屏或者四分之一屏幕,还可以设定快捷键完成分屏。这款专业的窗口管理工具当您每次将内容从一个应用移动到另一应用时,当您需要…

python学习-容器类型

列表 列表(list)是一种有序容器,可以向其中添加或删除任意元素. 列表数据类型是一种容器类型,列表中可以存放不同数据类型的值,代码示例如下: 列表中可以实现元素的增、删、改、查。 示例代码如下: 增 …

医疗器械网络安全 | 漏洞扫描、渗透测试没有发现问题,是否说明我的设备是安全的?

尽管漏洞扫描、模糊测试和渗透测试在评估系统安全性方面是非常重要和有效的工具,但即使这些测试没有发现任何问题,也不能完全保证您的医疗器械是绝对安全的。这是因为安全性的评估是一个多维度、复杂且持续的过程,涉及多个方面和因素。以下是…

AI提示词:打造爆款标题生成器

打开GPT输入以下内容: # Role 爆款标题生成器## Profile - author: 姜小尘 - version: 02 - LLM: Kimi - language: 中文 - description: 利用心理学和市场趋势,生成吸引眼球的自媒体文章标题。## Background 一个吸引人的标题是提升文章点击率和传播力…

优化爬虫体验:揭秘IP重复率过高问题解决方案

在当今信息爆炸的时代,网络中蕴藏着大量宝贵的数据,而爬虫技术成为我们提取这些数据的重要工具。然而,随着爬虫的广泛使用,IP重复率高的问题也随之而来。本篇博文将揭秘解决这一问题的关键方法——使用IP代理。 一、 IP高重复问题…

手慢无,速看︱PMO大会内部学习资料

全国PMO专业人士年度盛会 每届PMO大会,组委会都把所有演讲嘉宾的PPT印刷在了会刊里面,供大家会后回顾与深入学习。 第十三届中国PMO大会-会刊 《2024第十三届中国PMO大会-会刊》 (内含演讲PPT) 会刊:750个页码&…

TK 检查输入框是否为空

在Python的Tkinter库中,你可以使用事件绑定或者在按钮点击事件中检查输入框的值是否为空来实现这个功能。以下是一个简单的例子: import tkinter as tk from tkinter import messageboxdef check_input():entry input_box.get()if not entry:messagebo…

《梦醒蝶飞:释放Excel函数与公式的力量》10.4 IMREAL函数

第四节 10.4 IMREAL函数 10.4.1 函数简介 IMREAL函数是Excel中的一个工程函数,用于提取复数的实部。在复数运算中,实部是复数的一部分,表示没有虚部参与的部分。IMREAL函数提供了一个简单的方法来获取复数的实部,便于进一步计算…

Java 8革新:现代编程的全新标准与挑战

文章目录 一、方法引用二、接口默认方法三、接口静态方法四、集合遍历forEach()方法 一、方法引用 方法引用是Java 8中一种简化Lambda表达式的方式,通过直接引用现有方法来代替Lambda表达式。 方法引用使得代码更加简洁和易读,特别是在处理函数式接口时&…

ns3学习笔记(四):路由概述

基于官网文档的 Routing Overview 部分详细研究一下ns3中路由是怎么工作的 文档链接16.4. Routing overview — Model Library 一、概述 NS3整体的工作架构如下: 路由部分的工作架构如下: 路由部分目前大多数用到的算法都包含在Ipv4RoutingProtocol部分…

常用网络概念

📑打牌 : da pai ge的个人主页 🌤️个人专栏 : da pai ge的博客专栏 ☁️宝剑锋从磨砺出,梅花香自苦寒来 ​​ 目录 了解组织 局域网技术 …

【Ty CLI】一个开箱即用的前端脚手架

目录 资源链接基础命令模板创建命令帮助选择模板开始创建开发模板 开发背景npm 发布流程问题记录模板创建超时 更新日志 资源链接 文档:https://ty.cli.vrteam.top/ 源码:https://github.com/bosombaby/ty-cli 基础命令 1. npm 全局安装 npm i ty-cli…