M2m中的采样

 采样的完整代码

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, SubsetRandomSampler

def get_oversampled_data(dataset, num_sample_per_class):
    """ Generate a list of indices that represents oversampling of the dataset. """
    targets = np.array(dataset.targets)
    class_sample_count = np.array([num_sample_per_class[target] for target in targets])
    weight = 1. / class_sample_count
    samples_weight = torch.from_numpy(weight)
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return sampler

def get_val_test_data(dataset, num_test_samples):
    """ Split dataset into validation and test indices. """
    num_classes = 10
    targets = dataset.targets
    test_indices = []
    val_indices = []

    for i in range(num_classes):
        indices = [j for j, x in enumerate(targets) if x == i]
        np.random.shuffle(indices)
        val_indices.extend(indices[:num_test_samples])
        test_indices.extend(indices[num_test_samples:num_test_samples*2])

    return val_indices, test_indices

def get_oversampled(dataset_name, num_sample_per_class, batch_size, transform_train, transform_test):
    """ Create training and testing loaders with oversampling for imbalance. """
    dataset_class = datasets.__dict__[dataset_presets[dataset_name]['class']]
    dataset_train = dataset_class(root='./data', train=True, download=True, transform=transform_train)
    dataset_test = dataset_class(root='./data', train=False, download=True, transform=transform_test)

    # Oversampling
    sampler = get_oversampled_data(dataset_train, num_sample_per_class)
    train_loader = DataLoader(dataset_train, batch_size=batch_size, sampler=sampler)

    # Validation and Test split
    val_idx, test_idx = get_val_test_data(dataset_test, 1000)
    val_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(val_idx))
    test_loader = DataLoader(dataset_test, batch_size=batch_size, sampler=SubsetRandomSampler(test_idx))

    return train_loader, val_loader, test_loader

# Configuration and run
dataset_presets = {
    'cifar10': {'class': 'CIFAR10', 'num_classes': 10}
}
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
num_sample_per_class = [500] * 10  # Pretend we want equal class distribution

train_loader, val_loader, test_loader = get_oversampled('cifar10', num_sample_per_class, 64, transform, transform)

# Print out some info from loaders
for i, (inputs, targets) in enumerate(train_loader):
    print(f'Batch {i}, Targets Counts: {torch.bincount(targets)}')
    if i == 1:  # Just show first two batches for demonstration
        break

WeightedRandomSampler类的__iter__

def __iter__(self) -> Iterator[int]:
    rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
    return iter(rand_tensor.tolist())
  • 方法功能:此方法实现了迭代器协议,允许WeightedRandomSampler对象在迭代中返回一系列随机选择的索引。

过采样的效果

get_oversampled函数中,使用了WeightedRandomSampler来实现过采样的逻辑。这个过程虽然看起来是通过权重调整样本的选取概率,但实际上,通过这种方式也可以达到过采样的效果,尤其是当设置replacement=True时。让我们更详细地分析一下这一点:

权重的分配

权重是根据num_sample_per_class数组分配的,这个数组定义了每个类别希望被采样到的频率。在数据加载过程中,每个类别的样本将根据其在num_sample_per_class中对应的值获得一个权重。权重越大的类别在每次迭代中

被选中的概率也越大。这样,通过调整这些权重,我们可以控制模型在训练过程中看到的每个类别样本的频率,实现对类别不平衡的处理。

过采样的实现

在使用WeightedRandomSampler时,关键的参数是replacement

  • 如果replacement=True:这允许同一个样本在一次抽样中被多次选择,即进行了过采样。对于少数类的样本来说,即使它们在数据集中的绝对数量不多,也可以通过这种方式增加它们在每个训练批次中出现的次数,从而让模型更频繁地从这些少数类样本学习。

  • 如果replacement=False:则每个样本只能被抽样一次,这通常用于不放回的抽样。在这种模式下,WeightedRandomSampler不会直接导致过采样,但可以用来确保每个类别在数据批次中都有均等的代表性,从而帮助模型学习到更平衡的特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/652063.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Brewer Science将在CS Mantech进行展示

在风景如画的亚利桑那州图森市举办的CS Mantech盛会上(2024年5月20日至23日),杰出化合物半导体材料企业Brewer Science,将带来一场名为“化合物半导体制造的创新材料解决方案”的演讲盛宴。这一演讲,定于五月二十一日星…

宝塔:如何在宝塔面板做301重定向

如何在宝塔面板做301重定向?301重定向对于网站来说非常重要。如果你的网站以www开头,我们应该把没有www的域名重定向到有www的域名,反之亦然。 1、我们进入宝塔管理后台 2、登录面板并单击添加站点。既然要把xxx.com 301发到www.xxx.com,我…

【设计模式深度剖析】【5】【结构型】【桥接模式】| 以电视和遥控器为例加深理解

👈️上一篇:组合模式 设计模式-专栏👈️ 目 录 桥接模式(Bridge Pattern)定义英文原话是:直译理解 4个角色UML类图代码示例 应用优点缺点使用场景 示例解析:电视和遥控器UML类图 桥接模式(Bridge Pattern) 定义 英文原话是&am…

最新!!2024年上半年软考【中级】网络工程师 综合知识真题解析

2024上半年软考考试已经结束了,为大家整理了网友回忆版的网络工程师真题及答案,总共41道题。 上半年考试的宝子们可以对答案预估分数!准备下半年考的宝子可以提前把握考试知识点和出题方向,说不定会遇到相同考点的题目&#xff01…

基于ssm+vue图书管理系统

基于ssmvue图书管理系统 ssm477图书管理系统 相关技术 javassmmysqlvueelementui

上海亚商投顾:沪指震荡反弹 半导体产业链午后爆发

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 沪指昨日震荡反弹,尾盘涨幅扩大至1%,深成指、创业板指同步上行,科创50指数…

搭载昇腾310NPU的Orange Pi AIpro开箱体验以及深度学习样例测试

Orange Pi AIpro开箱体验以及样例测试 随着人工智能和物联网技术的快速发展,单板计算机(Single Board Computer, SBC)在创客和开发者社区中越来越受到欢迎。我最近入手了一款高性能的单板计算机——Orange Pi AIpro。 在入手此款AI开发板之…

【三维重建】ePnP

PnP问题应用与一下场景: 已知三维点和对应二维点以及相机相机内参数,可以获取相机外参。 我们介绍其中的一种算法:ePnP 算法流程 1、ePnP算法首先在世界坐标系内寻找4个控制点,记作 C 1 w , C 2 w , C 3 w , C 4 w C_1^w,C_2^w,…

Laravel和ThinkPHP框架比较

一、开发体验与易用性比较 1. 代码可读性: - Laravel以其优雅的语法和良好的代码结构著称,使得代码更加易读易懂。 - 相比之下,ThinkPHP的代码可读性较为一般,在一些复杂业务场景下,可能会稍显混乱。 让您能够一站式…

每天写两道(一):无重复字符的最长子串、反转链表

3. 无重复字符的最长子串 3. 无重复字符的最长子串 给定一个字符串 s ,请你找出其中不含有重复字符的 最长子串的长度。 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc",所以其长度为 3。(1)滑动窗口 双…

Web安全:文件上传漏洞详解,文件上传漏洞原理、绕过方式和防御方案。

「作者简介」:2022年北京冬奥会网络安全中国代表队,CSDN Top100,就职奇安信多年,以实战工作为基础对安全知识体系进行总结与归纳,著作适用于快速入门的 《网络安全自学教程》,内容涵盖系统安全、信息收集等…

搭建服务器的主流中间件有哪些?如何在外网访问内网的服务?

计算机业内人士对于搭建服务器的中间件并不陌生,apache、tomcat、IIS、nginx 都是比较常用的搭建服务器的中间件,它们之间还是有一些区别差异的。今天就说说这些中间件之间有哪些区别,以及如何利用快解析实现内网主机应用让外网访问。 首先说…

c++ 将指针转换为 void* 后,转换为怎么判断原指针类型?

当将指针转换为void后,擦除了指针所指向对象的类型信息,因此无法通过void指针来判断原始指针的类型。我这里有一套编程入门教程,不仅包含了详细的视频讲解,项目实战。如果你渴望学习编程,不妨点个关注,给个…

【状态机动态规划】3129. 找出所有稳定的二进制数组 I

本文涉及知识点 动态规划汇总 LeetCode 3129. 找出所有稳定的二进制数组 I 给你 3 个正整数 zero ,one 和 limit 。 一个 二进制数组 arr 如果满足以下条件,那么我们称它是 稳定的 : 0 在 arr 中出现次数 恰好 为 zero 。 1 在 arr 中出现…

OpenWrt 23.05 安装之后默认空间小 磁盘扩容 教程 软路由实测 系列六

1 安装fdisk opkg update opkg install fdisk #查看磁盘 rootOpenWrt:~# fdisk -l GPT PMBR size mismatch (246303 ! 250069679) will be corrected by write. The backup GPT table is not on the end of the device. Disk /dev/sda: 119.24 GiB, 128035676160 bytes, 25006…

【leetcode 141】环形链表——快慢指针(龟兔赛跑)

给你一个链表的头节点 head ,判断链表中是否有环。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数 pos 来表示链表尾连接到链表中的位置(…

如何使用Spring Cache优化后端接口?

Spring Cache是Spring框架提供的一种缓存抽象,它可以很方便地集成到应用程序中,用于提高接口的性能和响应速度。使用Spring Cache可以避免重复执行耗时的方法,并且还可以提供一个统一的缓存管理机制,简化缓存的配置和管理。 本文将详细介绍如何使用Spring Cache来优化接口,…

【Java】JavaSE概述

1、简介 Java SE(Java Platform, Standard Edition)是Java技术的核心平台,它提供了Java编程语言、Java虚拟机(JVM)以及Java核心类库和API。Java SE主要用于开发和部署桌面应用程序、服务器应用程序、命令行工具和嵌入…

kkFileView——全能的在线文件预览解决方案

引言 在数字化办公日益普及的今天,文件的在线预览成为了一个不可或缺的功能。无论是个人还是企业,都希望能够在浏览器中直接打开并浏览各种格式的文档。今天,我们将探索一款国产开源免费的在线文件文档预览软件——kkFileView。 一、kkFile…

Pag格式在vue3中的简单使用方法

目前前端使用pag格式的方法比较少&#xff0c; 在这里我来简单实现一下pag格式在vue3中的使用方式。 第一步 先下载啦 npm i libpag 来对pag文件安装依赖 其次我们在自己想要引入的vue页面进行引入 <script setup> import { ref, computed, watchEffect, nextTick …