迁移学习案例-python代码

在这里插入图片描述

大白话

迁移学习就是用不太相同但又有一些联系的A和B数据,训练同一个网络。比如,先用A数据训练一下网络,然后再用B数据训练一下网络,那么就说最后的模型是从A迁移到B的。

迁移学习的具体形式是多种多样的,比如先用A训练好一个网络,然后复制这个网络的某几个层的参数到一个新的网络作为初始化的参数,然后用B数据去训练这个新网络。又或者,面对中文翻译的问题,中文翻译成英文和中文翻译成火星文,前几层在提取特征,可以共享参数层,后面几层由于任务不同就可以各自私有训练。

案例来源:李宏毅课程-机器学习-迁移学习

A数据:是源数据,量大效果好,并且有标签。
在这里插入图片描述
B数据:量少,没标签。
在这里插入图片描述

目的:希望用A数据先训练网络提取到关键特征,然后预测B数据的标签。但是把他们当作两个任务效果不佳,于是以一种迁移的方法解决--域对抗(先用A训练好模型,再直接用B测试,这样效果不佳;而是希望以一种“迁移”的方法,把A数据的知识拿到B上面用)

直接上代码

import matplotlib.pyplot as plt

def no_axis_show(img, title='', cmap=None):
    # imshow, 缩放模式为nearest。
    fig = plt.imshow(img, interpolation='nearest', cmap=cmap)
    # 不要显示axis。
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(title)

titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i+1)
    fig = no_axis_show(plt.imread(f'work/real_or_drawing/train_data/{i}/{500*i}.bmp'), title=titles[i])
plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i+1)
    fig = no_axis_show(plt.imread(f'work/real_or_drawing/test_data/0/' + str(i).rjust(5, '0') + '.bmp'))
import cv2
import matplotlib.pyplot as plt
titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))

original_img = plt.imread(f'work/real_or_drawing/train_data/0/0.bmp')
plt.subplot(1, 5, 1)
no_axis_show(original_img, title='original')

gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')

gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')

canny_50100 = cv2.Canny(gray_img, 50, 100)
plt.subplot(1, 5, 3)
no_axis_show(canny_50100, title='Canny(50, 100)', cmap='gray')

canny_150200 = cv2.Canny(gray_img, 150, 200)
plt.subplot(1, 5, 4)
no_axis_show(canny_150200, title='Canny(150, 200)', cmap='gray')

canny_250300 = cv2.Canny(gray_img, 250, 300)
plt.subplot(1, 5, 5)
no_axis_show(canny_250300, title='Canny(250, 300)', cmap='gray')
import cv2
import numpy as np
import paddle

import paddle.optimizer as optim
from paddle.io import DataLoader
from paddle.vision.datasets import DatasetFolder
from paddle.nn import Sequential, Conv2D, BatchNorm1D, BatchNorm2D, ReLU, MaxPool2D, Linear
from paddle.vision.transforms import Compose, Grayscale, Transpose, RandomHorizontalFlip, RandomRotation, Resize, ToTensor
class Canny(paddle.vision.transforms.transforms.BaseTransform):
    def __init__(self, low, high, keys=None):
        super(Canny, self).__init__(keys)
        self.low = low
        self.high = high

    def _apply_image(self, img):
        Canny = lambda img: cv2.Canny(np.array(img), self.low, self.high)
        return Canny(img)
source_transform = Compose([
    RandomHorizontalFlip(),
    RandomRotation(15),
    Grayscale(),
    Canny(low=170, high=300),
    # Transpose(),
    ToTensor()
    ])
target_transform = Compose([
    Grayscale(),
    Resize((32, 32)),
    RandomHorizontalFlip(),
    RandomRotation(15, fill=(0,)),
    ToTensor()
    ])

source_dataset = DatasetFolder('work/real_or_drawing/train_data', transform=source_transform)
target_dataset = DatasetFolder('work/real_or_drawing/test_data', transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)
class FeatureExtractor(paddle.nn.Layer):

    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.conv = Sequential(
            Conv2D(1, 64, 3, 1, 1),
            BatchNorm2D(64),
            ReLU(),
            MaxPool2D(2),

            Conv2D(64, 128, 3, 1, 1),
            BatchNorm2D(128),
            ReLU(),
            MaxPool2D(2),

            Conv2D(128, 256, 3, 1, 1),
            BatchNorm2D(256),
            ReLU(),
            MaxPool2D(2),

            Conv2D(256, 256, 3, 1, 1),
            BatchNorm2D(256),
            ReLU(),
            MaxPool2D(2),

            Conv2D(256, 512, 3, 1, 1),
            BatchNorm2D(512),
            ReLU(),
            MaxPool2D(2)
        )
        
    def forward(self, x):
        x = self.conv(x).squeeze()
        return x

class LabelPredictor(paddle.nn.Layer):

    def __init__(self):
        super(LabelPredictor, self).__init__()

        self.layer = Sequential(
            Linear(512, 512),
            ReLU(),

            Linear(512, 512),
            ReLU(),

            Linear(512, 10),
        )

    def forward(self, h):
        c = self.layer(h)
        return c

class DomainClassifier(paddle.nn.Layer):

    def __init__(self):
        super(DomainClassifier, self).__init__()

        self.layer = Sequential(
            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 1),
        )

    def forward(self, h):
        y = self.layer(h)
        return y
feature_extractor = FeatureExtractor()
label_predictor = LabelPredictor()
domain_classifier = DomainClassifier()

class_criterion = paddle.nn.loss.CrossEntropyLoss()
domain_criterion = paddle.nn.BCEWithLogitsLoss()

optimizer_F = optim.Adam(parameters=feature_extractor.parameters())
optimizer_C = optim.Adam(parameters=label_predictor.parameters())
optimizer_D = optim.Adam(parameters=domain_classifier.parameters())
def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: source data的dataloader
        target_dataloader: target data的dataloader
        lamb: 调控adversarial的loss系数。
    '''

    # D loss: Domain Classifier的loss
    # F loss: Feature Extrator & Label Predictor的loss
    # total_hit: 计算目前对了几笔 total_num: 目前经过了几笔
    running_D_loss, running_F_loss = 0.0, 0.0
    total_hit, total_num = 0.0, 0.0

    for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):

        # source_data = source_data.cuda()
        # source_label = source_label.cuda()
        # target_data = target_data.cuda()
        
        # 我们把source data和target data混在一起,否则batch_norm可能会算错 (两边的data的mean/var不太一样)
        mixed_data = paddle.concat([source_data, target_data], axis=0)
        domain_label = paddle.zeros([source_data.shape[0] + target_data.shape[0], 1])
        # 设定source data的label为1
        domain_label[:source_data.shape[0]] = 1

        # Step 1 : 训练Domain Classifier
        feature = feature_extractor(mixed_data)
        # 因为我们在Step 1不需要训练Feature Extractor,所以把feature detach避免loss backprop上去。
        domain_logits = domain_classifier(feature.detach())
        # print('domain_logits.shape:', domain_logits.shape, 'domain_label.shape:', domain_label.shape)
        loss = domain_criterion(domain_logits, domain_label)
        # running_D_loss+= loss.numpy()[0]
        running_D_loss+= loss.numpy()
        # print('loss:', loss)
        loss.backward()
        optimizer_D.step()

        # Step 2 : 训练Feature Extractor和Domain Classifier
        class_logits = label_predictor(feature[:source_data.shape[0]])
        domain_logits = domain_classifier(feature)
        # loss为原本的class CE - lamb * domain BCE,相减的原因同GAN中的Discriminator中的G loss。
        loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)
        # running_F_loss+= loss.numpy()[0]
        running_F_loss+= loss.numpy()
        loss.backward()
        optimizer_F.step()
        optimizer_C.step()

        optimizer_D.clear_grad()
        optimizer_F.clear_grad()
        optimizer_C.clear_grad()
        # print('class_logits.shape:', class_logits.shape, 'source_label.shape:', source_label.shape)
        # print('class_logits[0]:', class_logits[0], 'source_label[0]:', source_label[0])
        total_hit += np.sum((paddle.argmax(class_logits, axis=1) == source_label).numpy())
        total_num += source_data.shape[0]
        print(i, end='\r')

    return running_D_loss / (i+1), running_F_loss / (i+1), total_hit / total_num

# 训练200 epochs
for epoch in range(200):
    train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, lamb=0.1)

    paddle.save(feature_extractor.state_dict(), f'extractor_model.pdparams')
    paddle.save(label_predictor.state_dict(), f'predictor_model.pdparams')

    print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

训练结束,预测一波

result = []
label_predictor.eval()
feature_extractor.eval()
for i, (test_data, _) in enumerate(test_dataloader):
    test_data = test_data

    class_logits = label_predictor(feature_extractor(test_data))

    x = paddle.argmax(class_logits, axis=1).detach().numpy()
    result.append(x)

import pandas as pd
result = np.concatenate(result)

# Generate your submission
df = pd.DataFrame({'id': np.arange(0,len(result)), 'label': result})
df.to_csv('work/DaNN_submission.csv',index=False)

训练比较慢,还得是把代码转到cuda上才行,demo可以把epoch减小一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/887907.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode讲解篇之300. 最长递增子序列

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 这题我们可以通过动态规划求解&#xff0c;使用一个数组f&#xff0c;数组f的i号元素表示[0, i]范围内最长递增子序列的长度 状态转移方程&#xff1a;f[i] max(f[j] 1)&#xff0c;其中 0 < j < i 题…

docker快速安装ELK

一、创建elk目录 创建/elk/elasticsearch/data/目录 mkdir -p /usr/local/share/elk/elasticsearch/data/ 创建/elk/logstash/pipeline/目录 mkdir -p /usr/local/share/elk/logstash/pipeline/ 创建/elk/kibana/conf/目录 mkdir -p /usr/local/share/elk/kibana/conf/ 二、创建…

大模型应用新领域:探寻终端侧 AI 竞争核心|智于终端

2024年过去2/3&#xff0c;大模型领域的一个共识开始愈加清晰&#xff1a; AI技术的真正价值在于其普惠性。没有应用&#xff0c;基础模型将无法发挥其价值。 于是乎&#xff0c;回顾这大半年&#xff0c;从互联网大厂到手机厂商&#xff0c;各路人马都在探索AI时代Killer AP…

【超级详细解释】力扣每日一题 134.加油站 48. 旋转图像

134.加油站 力扣 这是一个很好的问题。这个思路其实基于一种贪心策略。我们从整个路径的油量变化来理解它&#xff0c;结合一个直观的“最低点法则”&#xff0c;来确保找到正确的起点。 问题的核心&#xff1a;油量差值的累积 对于每个加油站&#xff0c;我们有两个数组&…

1、如何查看电脑已经连接上的wifi的密码?

在电脑桌面右下角的如下位置&#xff1a;双击打开查看当前连接上的wifi的名字&#xff1a;ZTE-kfdGYX-5G 按一下键盘上的win R 键, 输入【cmd】 然后&#xff0c;按一下【回车】。 输入netsh wlan show profile ”wifi名称” keyclear : 输入完成后&#xff0c;按一下回车&…

51单片机的水质检测系统【proteus仿真+程序+报告+原理图+演示视频】

1、主要功能 该系统由AT89C51/STC89C52单片机LCD1602显示模块温度传感器ph传感器浑浊度传感器蓝牙继电器LED、按键和蜂鸣器等模块构成。适用于水质监测系统&#xff0c;含检测和调整水温、浑浊度、ph等相似项目。 可实现功能: 1、LCD1602实时显示水温、水体ph和浑浊度 2、温…

Studying-多线程学习Part3 - condition_variable与其使用场景、C++11实现跨平台线程池

来源&#xff1a;多线程学习 目录 condition_variable与其使用场景 生产者与消费者模型 C11实现跨平台线程池 condition_variable与其使用场景 生产者与消费者模型 生产者-消费者模式是一种经典的多线程设计模式&#xff0c;用于解决多个线程之间的数据共享和协作问题。…

基于PHP的校园二手书交易管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的校园二手书交易管理系统 一 介绍 此二手书交易管理系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈&#xff1a;phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注…

k8s中pod的管理

一、资源管理 1.概述 说到k8s中的pod&#xff0c;即荚的意思&#xff0c;就不得不先提到k8s中的资源管理&#xff0c;k8s中可以用以下命令查看我们的资源&#xff1a; kubectl api-resources 比如我们现在需要使用k8s开启一个东西&#xff0c;那么k8s通过apiserver去对比etc…

《从零开始大模型开发与微调》真的把大模型说透了!零基础入门一定要看!

2022年底&#xff0c;ChatGPT震撼上线&#xff0c;大语言模型技术迅速“席卷”了整个社会&#xff0c;人工智能技术因此迎来了一次重要进展。与大语言模型相关的研发岗薪资更是水涨船高&#xff0c;基本都是5w月薪起。很多程序员也想跟上ChatGPT脚步&#xff0c;今天给大家带来…

51单片机系列-串口(UART)通信技术

&#x1f308;个人主页&#xff1a; 羽晨同学 &#x1f4ab;个人格言:“成为自己未来的主人~” 并行通信和串行通信 并行方式 并行方式&#xff1a;数据的各位用多条数据线同时发送或者同时接收 并行通信特点&#xff1a;传送速度快&#xff0c;但因需要多根传输线&#xf…

免杀对抗—GOC#反VT沙盒资源分离混淆加密

前言 今天的主要内容是反VT沙盒&#xff0c;我们都知道生成的后门会被杀软上穿到沙盒中去运行&#xff0c;去逆向。如此一来我们的后门就很容易被查杀掉&#xff0c;但如果我们对后门进行一些操作&#xff0c;让它在被逆向的时候&#xff0c;反编译出一堆乱码&#xff0c;或者…

【大语言模型-论文精读】用于医疗领域摘要任务的大型语言模型评估综述

【大语言模型-论文精读】用于医疗领域摘要任务的大型语言模型评估综述 论文信息&#xff1a; 用于医疗领域摘要任务的大型语言模型评估&#xff1a;一篇叙述性综述&#xff0c; 文章是由 Emma Croxford , Yanjun Gao 博士 , Nicholas Pellegrino , Karen K. Wong 等人近期合作…

STM32PWM应用

目录 一、输出比较(OC) 二、PWM&#xff1a; 1、简介 2、基本结构 3、参数计算 三、PWM驱动LED呼吸灯 1、代码 四、PWM驱动Sg90舵机 1、工作原理 2、完整代码 五、PWM驱动直流电机 1、TB6612芯片模块 2、完整代码&#xff1a; 一、输出比较(OC) OC&#xff08;Outp…

智能医疗:Spring Boot医院管理系统开发

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常适…

【Python】PDFMiner.six:高效处理PDF文档的Python工具

PDF是一种广泛使用的文件格式&#xff0c;特别适用于呈现固定布局的文档。然而&#xff0c;提取PDF文件中的文本和信息并不总是那么简单。幸好有许多Python库可以帮助我们&#xff0c;其中&#xff0c;PDFMiner.six 是一个功能强大、专门用于PDF文档解析的库。 ⭕️宇宙起点 &a…

cnn突破四(生成卷积核与固定核对比)

cnn突破三中生成四个卷积核&#xff0c;训练6万次&#xff0c;91分&#xff0c;再训练6万次&#xff0c;95分&#xff0c;不是很满意&#xff0c;但又找不到问题点&#xff0c;所以就想了个办法&#xff0c;使用三个固定核&#xff0c;加上三层bpnet神经网络&#xff0c;看看效…

王道-数据结构

1 设数组data[m]作为循环队列的存储空间,front为队头指针,rear为队尾指针,则执行出队操作后其头指针front值为____ 答案:D 解析:队列的头指针指向队首元素的实际位置,因此出队操作后,头指针需向上移动一个元素的位置。循环队列的容量为m,所以头指针front加1以后,需…

【d60】【Java】【力扣】509. 斐波那契数

思路 要做的问题&#xff1a;求F&#xff08;n&#xff09;, F&#xff08;n&#xff09;就等于F(n-1)F(n-2)&#xff0c;要把这个F(n-1)F(n-2)当作常量&#xff0c;已经得到的值&#xff0c; 结束条件&#xff1a;如果是第1 第2 个数字的时候&#xff0c;没有n-1和n-2,所以…

闯关训练三:Git 基础知识

任务1: 破冰活动&#xff1a;自我介绍 点击Fork目标项目&#xff0c;创建一个新的Fork 获取仓库链接 在连接好开发机的vscode终端中逐行执行以下代码&#xff1a; git clone https://github.com/KelvinIII/Tutorial.git # 修改为自己frok的仓库 cd Tutorial/ git branch -a g…