【域适应】基于域分离网络的MNIST数据10分类典型方法实现

关于

大规模数据收集和注释的成本通常使得将机器学习算法应用于新任务或数据集变得异常昂贵。规避这一成本的一种方法是在合成数据上训练模型,其中自动提供注释。尽管它们很有吸引力,但此类模型通常无法从合成图像推广到真实图像,因此需要域适应算法来操纵这些模型,然后才能成功应用。现有的方法要么侧重于将表示从一个域映射到另一个域,要么侧重于学习提取对于提取它们的域而言不变的特征。然而,通过只关注在两个域之间创建映射或共享表示,他们忽略了每个域的单独特征。域分离网络可以实现对每个域的独特之处进行特征建模,,同时进行模型域不变特征的提取

参考文章: https://arxiv.org/abs/1608.06019

工具

 

方法实现

数据集定义
import torch.utils.data as data
from PIL import Image
import os


class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data
模型搭建
import torch.nn as nn
from functions import ReverseLayerF


class DSN(nn.Module):
    def __init__(self, code_size=100, n_class=10):
        super(DSN, self).__init__()
        self.code_size = code_size

        ##########################################
        # private source encoder
        ##########################################

        self.source_encoder_conv = nn.Sequential()
        self.source_encoder_conv.add_module('conv_pse1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,
                                                                padding=2))
        self.source_encoder_conv.add_module('ac_pse1', nn.ReLU(True))
        self.source_encoder_conv.add_module('pool_pse1', nn.MaxPool2d(kernel_size=2, stride=2))

        self.source_encoder_conv.add_module('conv_pse2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,
                                                                padding=2))
        self.source_encoder_conv.add_module('ac_pse2', nn.ReLU(True))
        self.source_encoder_conv.add_module('pool_pse2', nn.MaxPool2d(kernel_size=2, stride=2))

        self.source_encoder_fc = nn.Sequential()
        self.source_encoder_fc.add_module('fc_pse3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))
        self.source_encoder_fc.add_module('ac_pse3', nn.ReLU(True))

        #########################################
        # private target encoder
        #########################################

        self.target_encoder_conv = nn.Sequential()
        self.target_encoder_conv.add_module('conv_pte1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,
                                                                padding=2))
        self.target_encoder_conv.add_module('ac_pte1', nn.ReLU(True))
        self.target_encoder_conv.add_module('pool_pte1', nn.MaxPool2d(kernel_size=2, stride=2))

        self.target_encoder_conv.add_module('conv_pte2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,
                                                                padding=2))
        self.target_encoder_conv.add_module('ac_pte2', nn.ReLU(True))
        self.target_encoder_conv.add_module('pool_pte2', nn.MaxPool2d(kernel_size=2, stride=2))

        self.target_encoder_fc = nn.Sequential()
        self.target_encoder_fc.add_module('fc_pte3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))
        self.target_encoder_fc.add_module('ac_pte3', nn.ReLU(True))

        ################################
        # shared encoder (dann_mnist)
        ################################

        self.shared_encoder_conv = nn.Sequential()
        self.shared_encoder_conv.add_module('conv_se1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,
                                                                  padding=2))
        self.shared_encoder_conv.add_module('ac_se1', nn.ReLU(True))
        self.shared_encoder_conv.add_module('pool_se1', nn.MaxPool2d(kernel_size=2, stride=2))

        self.shared_encoder_conv.add_module('conv_se2', nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5,
                                                                  padding=2))
        self.shared_encoder_conv.add_module('ac_se2', nn.ReLU(True))
        self.shared_encoder_conv.add_module('pool_se2', nn.MaxPool2d(kernel_size=2, stride=2))

        self.shared_encoder_fc = nn.Sequential()
        self.shared_encoder_fc.add_module('fc_se3', nn.Linear(in_features=7 * 7 * 48, out_features=code_size))
        self.shared_encoder_fc.add_module('ac_se3', nn.ReLU(True))

        # classify 10 numbers
        self.shared_encoder_pred_class = nn.Sequential()
        self.shared_encoder_pred_class.add_module('fc_se4', nn.Linear(in_features=code_size, out_features=100))
        self.shared_encoder_pred_class.add_module('relu_se4', nn.ReLU(True))
        self.shared_encoder_pred_class.add_module('fc_se5', nn.Linear(in_features=100, out_features=n_class))

        self.shared_encoder_pred_domain = nn.Sequential()
        self.shared_encoder_pred_domain.add_module('fc_se6', nn.Linear(in_features=100, out_features=100))
        self.shared_encoder_pred_domain.add_module('relu_se6', nn.ReLU(True))

        # classify two domain
        self.shared_encoder_pred_domain.add_module('fc_se7', nn.Linear(in_features=100, out_features=2))

        ######################################
        # shared decoder (small decoder)
        ######################################

        self.shared_decoder_fc = nn.Sequential()
        self.shared_decoder_fc.add_module('fc_sd1', nn.Linear(in_features=code_size, out_features=588))
        self.shared_decoder_fc.add_module('relu_sd1', nn.ReLU(True))

        self.shared_decoder_conv = nn.Sequential()
        self.shared_decoder_conv.add_module('conv_sd2', nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5,
                                                                  padding=2))
        self.shared_decoder_conv.add_module('relu_sd2', nn.ReLU())

        self.shared_decoder_conv.add_module('conv_sd3', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5,
                                                                  padding=2))
        self.shared_decoder_conv.add_module('relu_sd3', nn.ReLU())

        self.shared_decoder_conv.add_module('us_sd4', nn.Upsample(scale_factor=2))

        self.shared_decoder_conv.add_module('conv_sd5', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3,
                                                                  padding=1))
        self.shared_decoder_conv.add_module('relu_sd5', nn.ReLU(True))

        self.shared_decoder_conv.add_module('conv_sd6', nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3,
                                                                  padding=1))

    def forward(self, input_data, mode, rec_scheme, p=0.0):

        result = []

        if mode == 'source':

            # source private encoder
            private_feat = self.source_encoder_conv(input_data)
            private_feat = private_feat.view(-1, 64 * 7 * 7)
            private_code = self.source_encoder_fc(private_feat)

        elif mode == 'target':

            # target private encoder
            private_feat = self.target_encoder_conv(input_data)
            private_feat = private_feat.view(-1, 64 * 7 * 7)
            private_code = self.target_encoder_fc(private_feat)

        result.append(private_code)

        # shared encoder
        shared_feat = self.shared_encoder_conv(input_data)
        shared_feat = shared_feat.view(-1, 48 * 7 * 7)
        shared_code = self.shared_encoder_fc(shared_feat)
        result.append(shared_code)

        reversed_shared_code = ReverseLayerF.apply(shared_code, p)
        domain_label = self.shared_encoder_pred_domain(reversed_shared_code)
        result.append(domain_label)

        if mode == 'source':
            class_label = self.shared_encoder_pred_class(shared_code)
            result.append(class_label)

        # shared decoder

        if rec_scheme == 'share':
            union_code = shared_code
        elif rec_scheme == 'all':
            union_code = private_code + shared_code
        elif rec_scheme == 'private':
            union_code = private_code

        rec_vec = self.shared_decoder_fc(union_code)
        rec_vec = rec_vec.view(-1, 3, 14, 14)

        rec_code = self.shared_decoder_conv(rec_vec)
        result.append(rec_code)

        return result




 模型训练
import random
import os
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
from model_compat import DSN
from data_loader import GetLoader
from functions import SIMSE, DiffLoss, MSE
from test import test

######################
# params             #
######################

source_image_root = os.path.join('.', 'dataset', 'mnist')
target_image_root = os.path.join('.', 'dataset', 'mnist_m')
model_root = 'model'
cuda = True
cudnn.benchmark = True
lr = 1e-2
batch_size = 32
image_size = 28
n_epoch = 100
step_decay_weight = 0.95
lr_decay_step = 20000
active_domain_loss_step = 10000
weight_decay = 1e-6
alpha_weight = 0.01
beta_weight = 0.075
gamma_weight = 0.25
momentum = 0.9

manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

#######################
# load data           #
#######################

img_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

dataset_source = datasets.MNIST(
    root=source_image_root,
    train=True,
    transform=img_transform
)

dataloader_source = torch.utils.data.DataLoader(
    dataset=dataset_source,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8
)

train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')

dataset_target = GetLoader(
    data_root=os.path.join(target_image_root, 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform
)

dataloader_target = torch.utils.data.DataLoader(
    dataset=dataset_target,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8
)

#####################
#  load model       #
#####################

my_net = DSN()

#####################
# setup optimizer   #
#####################


def exp_lr_scheduler(optimizer, step, init_lr=lr, lr_decay_step=lr_decay_step, step_decay_weight=step_decay_weight):

    # Decay learning rate by a factor of step_decay_weight every lr_decay_step
    current_lr = init_lr * (step_decay_weight ** (step / lr_decay_step))

    if step % lr_decay_step == 0:
        print 'learning rate is set to %f' % current_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    return optimizer


optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

loss_classification = torch.nn.CrossEntropyLoss()
loss_recon1 = MSE()
loss_recon2 = SIMSE()
loss_diff = DiffLoss()
loss_similarity = torch.nn.CrossEntropyLoss()

if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()
    loss_recon1 = loss_recon1.cuda()
    loss_recon2 = loss_recon2.cuda()
    loss_diff = loss_diff.cuda()
    loss_similarity = loss_similarity.cuda()

for p in my_net.parameters():
    p.requires_grad = True

#############################
# training network          #
#############################
 MNIST数据重建/共有部分特征/私有数据特征可视化

 MNIST_m数据重建/共有部分特征/私有数据特征可视化

代码获取

相关问题和项目开发,欢迎私信交流和沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/534343.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Springboot项目的测试类书写(速通)

目录 前言1. 单元测试的测试类2. 框架测试的测试类 前言 在实际开发中,如果只是做一个简单的单元测试(不涉及端到端、数据库交互、API调用、消息队列处理等),我为了方便一般都是找块儿地方写一个main方法来跑一下就行了&#xff…

CSS-文字环绕浮动、行内块分页、三角强化妙用、伪元素选择器

文字环绕浮动 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>文字环绕浮动效果</title><s…

[leetcode] all-nodes-distance-k-in-binary-tree 二叉树中所有距离为 K 的结点

. - 力扣&#xff08;LeetCode&#xff09; 给定一个二叉树&#xff08;具有根结点 root&#xff09;&#xff0c; 一个目标结点 target &#xff0c;和一个整数值 k 。 返回到目标结点 target 距离为 k 的所有结点的值的列表。 答案可以以 任何顺序 返回。 示例 1&#xff1a…

玩转公众号|掌握公众号运营技巧,让账号脱颖而出

随着互联网的普及&#xff0c;微信公众号已经成为了企业进行品牌宣传、产品推广和客户服务的重要渠道。而且&#xff0c;企业微信公众号是可以进行二次开发的&#xff0c;这样就能够满足企业的私域运营的需求。然而&#xff0c;对于许多企业来说&#xff0c;运营公众号和二次开…

LLM 构建Data Multi-Agents 赋能数据分析平台的实践之②:数据治理之二(自动处理)

前述 在前文的multi Agents for Data Analysis的设计说起&#xff0c;本文将继续探索和测试借助llm实现基于私有知识库的数据治理全自动化及智能化。整体设计如下&#xff1a; 整个体系设计了3个Agent以及一个Planer&Execute Agent&#xff0c;第一个Agent用于从企业数据…

结合ArcGIS+SWAT模型+Century模型:流域生态系统水-碳-氮耦合过程模拟

原文链接&#xff1a;结合ArcGISSWAT模型Century模型&#xff1a;流域生态系统水-碳-氮耦合过程模拟https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&tempkeyMTI2NV9sMGRZNUJoVkNVc1ZzSzRuMl9XXzhqX0R3cXpESWFwM1E4cFY4ejNqWFh3VUl0dlZkNWk4b20ydFdFTy1xS2ZObGN0Z0ZXSjly…

大话设计模式——9.单例模式(Singleton Pattern)

简介 确保一个类只有一个实例&#xff0c;并提供全局访问点来获取该实例&#xff0c;是最简单的设计模式。 UML图&#xff1a; 单例模式共有两种创建方式&#xff1a; 饿汉式&#xff08;线程安全&#xff09; 提前创建实例&#xff0c;好处在于该实例全局唯一&#xff0c;不…

c++之旅第九弹——模版

大家好啊&#xff0c;这里是c之旅第九弹&#xff0c;跟随我的步伐来开始这一篇的学习吧&#xff01; 如果有知识性错误&#xff0c;欢迎各位指正&#xff01;&#xff01;一起加油&#xff01;&#xff01; 创作不易&#xff0c;希望大家多多支持哦&#xff01; 一.模版的概念…

改进的注意力机制的yolov8和UCMCTrackerDeepSort的多目标跟踪系统

基于yolov8和UCMCTracker/DeepSort的注意力机制多目标跟踪系统 本项目是一个强大的多目标跟踪系统&#xff0c;基于[yolov8]链接和[UCMCTracker/DeepSot]/链接构建。 &#x1f3af; 功能 多目标跟踪&#xff1a;可以实现对视频中的多目标进行跟踪。目标检测&#xff1a;可以实…

2023年上半年信息系统项目管理师综合知识真题与答案解释(2)

2023年上半年信息系统项目管理师综合知识真题与答案解释(2) And Her Name Is? 她的名字是&#xff1f; During my second month of college, our professor gave us a pop quiz. 在我上大学的第二个月&#xff0c;我们的教授给了我们一个流行测验。 I was a conscientio…

自然语言控制机械臂:ChatGPT与机器人技术的融合创新(上)

引言&#xff1a; 自OpenAI发布ChatGPT以来&#xff0c;世界正迅速朝着更广泛地将AI技术融合到机器人设备中的趋势发展。机械手臂&#xff0c;作为自动化与智能化技术的重要组成部分&#xff0c;在制造业、医疗、服务业等领域的应用日益广泛。随着AI技术的进步&#xff0c;机械…

开源大数据集群部署(二十)Trino部署

作者&#xff1a;櫰木 1 解压trino的包到opt目录 cd /root/bigdata tar -xzvf trino-server-389.tar.gz -C /opt/ ln -s /opt/trino-server-389 /opt/trino2 创建trino用户&#xff0c;并配置专属jdk11 useradd trino su – trino chown -R trino:hadoop /opt/trino-server-…

async+await——用法——基础积累

对于asyncawait&#xff0c;我一直都不太会用。。。。 今天记录一下asyncawait的实际用法&#xff1a; 下面是一个实际的使用场景&#xff1a; 上面的代码如下&#xff1a; async fnConfirmCR(){let type this.crType;let crId this.crId;if(typeof crId object){let ne…

一起学习python——基础篇(13)

前言&#xff0c;python编程语言对于我个人来说学习的目的是为了测试。我主要做的是移动端的开发工作&#xff0c;常见的测试主要分为两块&#xff0c;一块为移动端独立的页面功能&#xff0c;另外一块就是和其他人对接工作。 对接内容主要有硬件通信协议、软件接口文档。而涉…

andorid 矢量图fillColor设置无效

问题&#xff1a;andorid 矢量图fillColor设置无效 解决&#xff1a;去掉如下 android:tint一行

股票手续费怎么降下来?这些技巧帮你省钱!

在股票交易中&#xff0c;手续费是每个投资者都必须面对的成本。降低手续费可以有效地增加投资回报。以下是一些降低股票手续费的方法&#xff1a; 1. 选择低佣金的券商&#xff1a;不同的证券公司提供的佣金费率不同&#xff0c;选择佣金较低的券商可以直接减少交易成本 2. 增…

antd+vue——datepicker日期控件——禁用日期功能

需求&#xff1a;今天之前的日期禁用 <a-date-pickerv-model.trim"formNE.deliveryTime":disabled-date"disabledDate"valueFormat"YYYY-MM-DD"allowClearstyle"width: 100%" />禁用日期的范围&#xff1a; //时间范围 disab…

【C语言】C语言题库【附源码+持续更新】

欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 目录 1、练习2-1 Programming in C is fun! 2、练习2-3 输出倒三角图案 3、练习2-4 温度转换 4、练习2-6 计算物体自由下落的距离 5、练习2-8 计算摄氏温度 6、练习2-9 整数四则运算 7、练习2-10 计算分段函数[1…

3D目标检测跟踪 | 基于kitti+waymo数据集的自动驾驶场景的3D目标检测+跟踪渲染可视化

项目应用场景 面向自动驾驶场景的 3D 目标检测目标跟踪&#xff0c;基于kittiwaymo数据集的自动驾驶场景的3D目标检测跟踪渲染可视化查看。 项目效果 项目细节 > 具体参见项目 README.md (1) Kitti detection 数据集结构 # For Kitti Detection Dataset └── k…

力扣347. 前 K 个高频元素

思路&#xff1a;记录元素出现的次数用map&#xff1b; 要维护前k个元素&#xff0c;不至于把所有元素都排序再取前k个&#xff0c;而是新建一个堆&#xff0c;用小根堆存放前k个最大的数。 为什么是小根堆&#xff1f;因为堆每次出数据时只出堆顶&#xff0c;每次把当前最小的…