【多任务学习】Multi-task Learning 手把手编码带数据集, 一文吃透多任务学习

文章目录

  • 前言
  • 1.多任务学习
    • 1.1 定义
    • 1.2 原理
  • 2. 多任务学习code
    • 2.1 数据集初探
    • 2.2 预处理
    • 2.3 网络结构
    • 2.4 训练
  • 3. 总结


前言

我们之前讲过的模型通常聚焦单个任务,比如预测图片的类别等,在训练的时候,我们会关注某一个特定指标的优化.
但是有时候,我们需要知道一个图片,从它身上知道新闻的类型(政治/体育/娱乐)和是男性的新闻还是女性的.
我们关注某一个特定指标的优化,可能忽略了对有关注的指标的有用信息.具体来说就是训练相关任务所带来的额外信息,通过在多个相关任务中共享表示,我们可以使得模型在我们原本任务上获得更好的泛化能力.这种方法就叫做多任务学习.


1.多任务学习

1.1 定义

同时完成多个预测,共享表示,共享特征提取.使得模型关注到一些特有的特征.其实一套提取特征的网络,配合多个损失函数,就是多任务损失.
图像定位是单任务,若还需要知道类别,就变成了多任务学习.
在这里插入图片描述

1.2 原理

多任务学习的模型通常通过所有任务重共用隐藏层(特征提取层),而针对不同任务使用多个输出层来实现.自动学习到的任务越多,模型就能获得捕捉所有任务的表示,而原本任务上过拟合的风险更小.
多任务学习中,针对一个任务的特征提取,由于其它任务也能对提取的特征做出筛选,所以可以帮助模型将注意力集中到那些真正起作用的特征上.
模型会学习那些尽量表达多个任务的特征,而这些特征泛化能力会很好.

2. 多任务学习code

同时预测一个物品的颜色和类别.

2.1 数据集初探

一个分支用于分类给定输入图像的服装种类(比如衬衫、裙子、牛仔裤、鞋子等);
另一个分支负责分类该服装的颜色(黑色、红色、蓝色等)。
总体而言,我们的数据集由 2525 张图像构成,分为 7 种「颜色+类别」组合,包括:

黑色牛仔裤(344 张图像)
黑色鞋子(358 张图像)
蓝色裙子(386 张图像)
蓝色牛仔裤(356 张图像)
蓝色衬衫(369 张图像)
红色裙子(380 张图像)
红色衬衫(332 张图像)
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc

2.2 预处理

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
import glob
from torchvision import transforms
from torch.utils import data
from PIL import Image

img_paths = glob.glob(r"F:\multi-output-classification\dataset\*\*.jpg")
img_paths[:5]

在这里插入图片描述
路径文件夹就表示了标签,所以要获取其标签:

label_names = [img_path.split("\\")[-2] for img_path in img_paths]
label_names[:5]

在这里插入图片描述

label_array = np.array([la.split("_") for la in label_names])
label_array

在这里插入图片描述

label_color = label_array[:,0]
label_color

在这里插入图片描述

label_item = label_array[:,1]
label_item


吧他们转成index,因为torch中只认数字

unique_color = np.unique(label_color)
unique_color
unique_item = np.unique(label_item)
unique_item
item_to_idx = dict((v,k) for k, v in enumerate(unique_item))
item_to_idx
color_to_idx = dict((v,k) for k, v in enumerate(unique_color))
color_to_idx
label_item = [item_to_idx.get(k) for k in label_item]
label_color = [color_to_idx.get(k) for k in label_color ]
transform = transforms.Compose([
    transforms.Resize((96,96)),
    transforms.ToTensor(),
])

自定义数据集

class Multi_dataset(data.Dataset):
    def __init__(self,imgs_path, label_color, label_item) -> None:
        super().__init__()
        self.imgs_path = imgs_path
        self.label_color = label_color
        self.label_item = label_item
    
    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        pil_img = Image.open(img_path)
        # 防止有图片有黑白图
        pil_img = pil_img.convert('RGB')
        pil_img = transform(pil_img)
        label_c = self.label_color[index]
        label_i = self.label_item[index]
        return pil_img, (label_c,label_i)
    def __len__(self):
        return len(self.imgs_path)

划分训练集

count = len(multi_dataset)
count
# 划分训练集 测试集
train_count = int(count*0.8)
test_count =  count - train_count
train_ds, test_ds = data.random_split(multi_dataset,[train_count, test_count])
len(train_ds),len(test_ds)
BATCHSIZE = 32
train_dl = data.DataLoader(train_ds,batch_size=BATCHSIZE,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=BATCHSIZE)

在这里插入图片描述
在这里插入图片描述

2.3 网络结构

2.4 训练

3. 总结

未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/17788.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PostgreSQL 基础知识:psql 提示和技巧

对于积极使用和连接到 PostgreSQL 数据库的任何开发人员或 DBA 来说,能够访问psql命令行工具是必不可少的。在我们的第一篇文章中,我们讨论了 psql的简要历史,并演示了如何在您选择的平台上安装它并连接到 PostgreSQL 数据库。 在本文中&…

HTTPS协议介绍

文章目录 一、HTTPS协议的认识二、常见的加密方式1.对称加密2.非对称加密 三、数据摘要四、HTTPS的工作过程探究1.只使用对称加密2.只使用非对称加密3.双方都使用非对称加密4.非对称加密对称加密5.中间人攻击6.引入证书7.非对称加密对称加密证书认证 一、HTTPS协议的认识 HTTP…

HTTP的method方法 GET POST PUT DELETE HEAD OPTIONS CONNECT PATCH TRACE

HTTP的method方法 GET POST PUT DELETE HEAD OPTIONS CONNECT PATCH TRACE GET 向指定的资源发出“显示”请求。使用GET方法应该只用在读取数据,而不应当被用于产生“副作用”的操作中,例如在Web Application中。其中一个原因是GET可能会被网络蜘蛛等随意…

Docker 持久化存储 Bind mounts

Docker 持久化存储 Bind mounts Bind mounts 的 -v 与 --mount 区别启动容器基于bind mount挂载到容器中的非空目录只读 bind mountcompose 中使用 bind mount 官方文档:https://docs.docker.com/storage/bind-mounts/ Bind mounts 的 -v 与 --mount 区别 如果使用…

ePWM模块(1)

ePWM模块 ePWM模块内部包含有7个子模块,分别是时间基准子模块TB、比较功能子模块CC,动作限定子模块AQ、死区控制子模块DB、斩波控制子模块PC、事件触发子模块ET和故障捕获子模块TZ。 每个ePWM模块都具有以下功能: 可以输出两路PWM,EPWMxA和EPWMxB两路PWM可以独立输出,也可…

大二一个学期学这么点内容,没有概念,只有实操

如何查看所有的数据库: Show databases; 如何进入某个数据库: use xxx; 如何新进数据库: Create database jx; 如何删除数据库: Drop database jx; 如何查看所有的表格: Show tables; 如何创建数据表&#xf…

【Flink】DataStream API使用之执行环境

1. 执行环境 Flink 程序可以在各种上下文环境中运行:我们可以在本地 JVM 中执行程序,也可以提交 到远程集群上运行。不同的环境,代码的提交运行的过程会有所不同。这就要求我们在提交作业执行计算时,首先必须获取当前 Flink 的运…

(异或相消)猫猫数字异或和

E - Red Scarf (atcoder.jp) 刚入坑写的一道题被我拉出来对比分析了 我的思路: 垃圾运气选手凭借直觉乱搞猜出来的,没有思路。 题解思路: 由问题陈述中XOR的定义,我们可以看出计算3个或更多整数的XOR可以以任意顺序进行&#…

ChatGPT :十几个国内免费可用 ChatGPT 网页版

前言 ChatGPT(全名:Chat Generative Pre-trained Transformer),美国OpenAI 研发的聊天机器人程序 ,于2022年11月30日发布 。ChatGPT是人工智能技术驱动的自然语言处理工具,它能够通过理解和学习人类的语言…

树脂塞孔有哪些优缺点及应用?

树脂塞孔的概述 树脂塞孔就是利用导电或者非导电树脂,通过印刷,利用一切可能的方式,在机械通孔、机械盲埋孔等各种类型的孔内进行填充,实现塞孔的目的。 树脂塞孔的目的 1 树脂填充各种盲埋孔之后,利于层压的真空下…

盲目自学网络安全只会成为脚本小子?

前言:我们来看看怎么学才不会成为脚本小子 目录: 一,怎么入门? 1、Web 安全相关概念(2 周)2、熟悉渗透相关工具(3 周)3、渗透实战操作(5 周)4、关注安全圈动…

2.RabbitMQ

RabbitMQ 1.初识MQ 1.1.同步和异步通讯 微服务间通讯有同步和异步两种方式: 同步通讯:就像打电话,需要实时响应。 异步通讯:就像发邮件,不需要马上回复。 两种方式各有优劣,打电话可以立即得到响应&am…

springboot整合flowable工作流引擎的简单使用

内容来自网络整理,文章最下有引用地址,可跳转至相关资源页面。若有侵权请联系删除 环境: mysql5.7.2 springboot 2.3.9.RELEASE flowable 6.7.2 采坑: 1.当前flowable sql需要与引用的pom依赖一致,否则会报library…

进程/线程 状态模型详解

前言:最近操作系统复习到线程的状态模型(也可以说进程的状态模型,本文直接用线程来说)时候,网上查阅资料,发现很多文章都说的很不一样,有五状态模型、六状态模型、七状态模型.......虽然都是对的…

[Python]爬虫基础——urllib库

urllib目录 一、简介二、发送请求1、urlopen()函数2、Request()函数 三、异常处理四、解析URL五、分析Robots协议 一、简介 urllib库是Python内置的标准库。包含以下四个模块: 1、request:模拟发送HTTP请求; 2、error:处理HTTP请…

前端技术——css

1.CSS的引入 【1】为什么要学习CSS? 如果只用HEML画页面的话--->这个页面就是页面上需要的元素罗列起来,但是页面效果很差,不好看,为了让页面好看,为了修饰页面。所以我们需要用到CSS。 CSS的作用:修饰HTML页面…

linux操作手册

开机&关机 指令 shutdown -h now 立刻进行关机 shutdown -h num num分钟后执行关机 shutdown -r now 现在重启计算机 halt 关机 rebboot 重启计算机 sync 把内存的数据同步到磁盘 注意事项 无论是重启还是关闭系统,都必须先执行sync,将内存…

使用无标注的数据训练Bert

文章目录 1、准备用于训练的数据集2、处理数据集3、克隆代码4、运行代码5、将ckpt模型转为bin模型使其可在pytorch中运用 Bert官方仓库:https://github.com/google-research/bert 1、准备用于训练的数据集 此处准备的是BBC news的数据集,下载链接&…

camunda表达式如何使用

在Camunda中,表达式是一种灵活的方式,可以用于在流程定义和表单中计算和处理数据。表达式可以在Camunda的各个环节中使用,例如服务任务、网关、表单、条件等。 以下是Camunda表达式的一些常见用途: 1、计算值:表达式可…

卢北辰:数据点亮梦想,能力驱动人生 | 提升之路系列(九)

导读 为了发挥清华大学多学科优势,搭建跨学科交叉融合平台,创新跨学科交叉培养模式,培养具有大数据思维和应用创新的“π”型人才,由清华大学研究生院、清华大学大数据研究中心及相关院系共同设计组织的“清华大学大数据能力提升项…