【域适应】基于深度域适应MMD损失的典型四分类任务实现

关于

MMD (maximum mean discrepancy)是用来衡量两组数据分布之间相似度的度量。一般地,如果两组数据分布相似,那么MMD 损失就相对较小,说明两组数据/特征处于相似的特征空间中。基于这个想法,对于源域和目标域数据,在使用深度学习进行特征提取中,使用MMD损失,可以让模型提取两个域的共有特征/空间,从而实现源域到目标域的迁移。

参考论文:https://arxiv.org/abs/1409.6041

工具

Python

 

方法实现

定义mmd函数
#!/usr/bin/env python
# encoding: utf-8

import torch

# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
#             = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def mmd_linear(f_of_X, f_of_Y):
    delta = f_of_X - f_of_Y
    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
    return loss

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    n_samples = int(source.size()[0])+int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0-total1)**2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)#/len(kernel_val)


def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    loss = 0
    for i in range(batch_size):
        s1, s2 = i, (i+1)%batch_size
        t1, t2 = s1+batch_size, s2+batch_size
        loss += kernels[s1, s2] + kernels[t1, t2]
        loss -= kernels[s1, t2] + kernels[s2, t1]
    return loss / float(batch_size)

def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    YX = kernels[batch_size:, :batch_size]
    loss = torch.mean(XX + YY - XY -YX)
    return loss
定义基于mmd特征对齐CNN模型
# encoding=utf-8

import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(1, 3)),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 3)),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=64 * 98, out_features=100),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=100, out_features=2)
        )

    def forward(self, src, tar):
        x_src = self.conv1(src)
        x_tar = self.conv1(tar)
        
        x_src = self.conv2(x_src)
        x_tar = self.conv2(x_tar)
        #print(x_src.shape)
        x_src = x_src.reshape(-1, 64 * 98)
        x_tar = x_tar.reshape(-1, 64 * 98)
        
        x_src_mmd = self.fc1(x_src)
        x_tar_mmd = self.fc1(x_tar)
        
        #x_src = self.fc1(x_src)
        #x_tar = self.fc1(x_tar)
        
        #x_src_mmd = self.fc2(x_src)
        #x_tar_mmd = self.fc2(x_tar)
        
        y_src = self.fc2(x_src_mmd)
        
        return y_src, x_src_mmd, x_tar_mmd
    

代码获取

后台私信;

其他相关域适应问题和代码开发,欢迎沟通和交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/535534.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

24/04/11总结

IO流(First edition): IO流:用于读入写出文件中的数据 流的方向(输入指拿出来,输出指写进去) 输入流:读取 输出流:写出 操作文件类型 字节流:所有类型文件 字符流:纯文本 字节流: InputStream的子类:FileInputStream:操作本地文件的字节输入流 OutputSt…

工作流引擎常见API(以camunda为例)

在Camunda中,API的继承关系主要体现在各个服务接口之间。以下是Camunda中一些常见服务接口的继承关系: ProcessEngineServices 接口: RepositoryService: 负责管理流程定义和部署。RuntimeService: 负责管理流程实例的…

2023年通用人工智能AGI等级保护白皮书

今天分享的是人工智能专题系列深度研究报告:《人工智能专题:2023年通用人工智能AGI等级保护白皮书》。 通用人工智能发展现状 本章主要介绍通用人工智能的基本情况,包括其发展历史、现状以及组成架构等内容。本文还将通过从技术角度出发来分…

PTA 应急救援站选址(floyd+打印路径)

大学城虎溪社区有很多居民小区,居民小区道路图是连通的。现要在该社区新建一个应急救援站,且该应急救援站要和某个小区建在一起。为了使应急救援最快速,经各部门商量决定:应急救援站建好后,离应急救援站最远的小区到应…

大话设计模式之命令模式

命令模式是一种行为型设计模式,它将请求或操作封装成一个对象,从而允许客户端参数化操作。这意味着客户端将一个请求封装为一个对象,这样可以将请求的参数化、队列化和记录日志,以及支持可撤销的操作。 命令模式主要由以下几个角…

kaggle 泰坦尼克号1(根据男女性存活率)

kaggle竞赛 泰坦尼克号 流程 下载kaggle数据集导入所要使用的包引入kaggle的数据集csv文件查看数据集的大小和长度去除冗余数据建立特征工程导出结果csv文件 1.下载kaggle数据集 2.导入所要使用的包 import pandas as pd import numpy as np import matplotlib.pyplot as …

leetcode328.奇偶链表

1. 题目描述 在线练习 2. 解题思路 这道题,官方给的是中等难度。其实是一道基础题,大家应该都可以写得出来。 题目中给的示例可以清楚的看到,合并前后的奇偶链的各自包含的节点的顺序是不变的,我们基本可以确定使用尾插法来合并…

Ansys Mechanical | 软件介绍:业界一流的有限元求解器

Ansys Mechanical 有限元分析软件 Ansys Mechanical 是业界一流的有限元求解器,具有结构、热学、声学、瞬态和非线性功能,可帮助改进建模。 ​ 软件概览 Ansys Mechanical 创建了一个使用有限元仿真分析软件(FEA)进行结构分析…

猝不及防 CCF-B ICPP 2024投稿延期至4月22日提交摘要 机会来了别错过

会议之眼 快讯 第53届ICPP(International Conference on Parallel Processing)即国际并行处理会议将于 2024年 8月12日-15日在瑞典哥特兰岛举行!ICPP是世界上最古老的连续举办的并行计算计算机科学会议之一。它是学术界、工业界和政府的研究…

欢迎加入PenPad Season 2 ,获得勋章以及海量 Scroll 生态权益

PenPad 是 Scroll 生态中的首个 LaunchPad 平台,该平台继承了 Scroll 生态的技术优势,具备包括隐私在内的系列特点,同时且也被认为是 Scroll 生态最重要的价值入口之一。Penpad 与 Scroll 官方始终保持着合作,同时该项目自启动以来…

你一定不能错过的多模态大模型!阿里千问开源Qwen-VL!具备图文解读等能力

1. Qwen-VL简介 1.1. 介绍 Qwen-VL的多语言视觉语言模型系列,基于Qwen-7B语言模型。该模型通过视觉编码器和位置感知的视觉语言适配器,赋予语言模型视觉理解能力。 Qwen-VL采用了三阶段的训练流程,并在多个视觉语言理解基准测试中取得了领先的成绩。该模型支持多语言、多图…

办公室电脑监控软件哪个最好用

办公室电脑监控软件哪个最好用 办公室监控软件主要用于帮助企业管理员监控员工在工作时间内的电脑使用情况,以提高工作效率、保障数据安全、遵守合规要求和维护良好的工作秩序。以下是一些推荐的办公室监控软件。 1、安企神 (1) 强大的监控功能:域智盾…

【文献分享】机器学习 + 分子动力学(LAMMPS 输入文件)+ 第一性原理 + 热学性质 + 动力学性质

分享一篇关于机器学习 分子动力学 第一性原理 热学性质 动力学性质的文章。 感谢论文的原作者! 关键词: 1. Machine learning, 2. Deep potential, 3. Molecular dynamics 4. Molten salts 5. Thermophysical properties 6. Phase diagram 主…

并查集加训

1.模板 #include<iostream> using namespace std; const int N 1e4 10; int p[N]; int n, m;int fd(int x){if(x ! p[x]){p[x] fd(p[x]);}return p[x]; }int main(){scanf("%d%d", &n, &m);for(int i 1; i < n; i){p[i] i;}int z, x, y;while(…

nvm更新node版本

1、nvm安装和管理多个 Node.js 版本&#xff1a;NVM 允许用户在计算机上同时安装多个不同版本的 Node.js。这使得开发人员可以轻松地在不同的项目中使用不同的 Node.js 版本&#xff0c;而无需手动安装或卸载。 2、nvm切换 Node.js 版本&#xff1a;通过 NVM&#xff0c;用户可…

一辆新能源汽车需要多少颗传感器?

随着科技的发展和环保意识的日益提高&#xff0c;新能源汽车&#xff08;包括纯电动汽车、混合动力汽车等&#xff09;在全球范围内越来越受到欢迎。这些汽车不仅减少了碳排放&#xff0c;还推动了汽车产业的创新。然而&#xff0c;这些高科技汽车的背后&#xff0c;隐藏着许多…

git lfs 大文件管理

简介 git-lfs 是 Git Large File Storage 的缩写&#xff0c;是 Git 的一个扩展&#xff0c;用于处理大文件的版本控制。 它允许你有效地管理和存储大型二进制文件&#xff0c;而不会使 Git 仓库变得过大和不稳定。以下是一些与 git-lfs 相关的常见命令和解释&#xff1a; 常…

zabbix“专家坐诊”第236期问答

问题一 Q&#xff1a;我的trap里已经可以收到信息了&#xff0c;后续要怎么创建监控项呀&#xff1f; A&#xff1a;参考&#xff1a; 问题二 Q&#xff1a;snmp和snmp trap咋搞&#xff1f; A&#xff1a;你指的是如何开启这些协议还是如何做监控项&#xff1f; Q&#xff1…

KVM部署

1、检查虚拟化支持 首先&#xff0c;确认你的系统处理器支持硬件虚拟化&#xff0c;在Linux终端中&#xff0c;使用以下命令&#xff1a; egrep -c (vmx|svm) /proc/cpuinfo2、安装KVM及其工具 yum update yum install qemu-kvm libvirt libvirt-python libguestfs-tools vi…

[大模型]基于 ChatGLM3 和 LangChain 搭建知识库助手

基于 ChatGLM3 和 LangChain 搭建知识库助手 环境配置 在已完成 ChatGLM3 的部署基础上&#xff0c;还需要安装以下依赖包&#xff1a; pip install langchain0.0.292 pip install gradio4.4.0 pip install chromadb0.4.15 pip install sentence-transformers2.2.2 pip inst…