【Segment Anything Model】四:预处理自己的数据集接入SAM

文章目录

  • 1️⃣预备知识
  • 2️⃣实现思路
  • 🔸脚本预处理得到包含embedd和GT的npz
  • 🔸编写Dataset类
  • 3️⃣代码
  • 🔸实现脚本预处理得到包含embedd和GT的npz代码
  • 🔸实现Dataset的代码

1️⃣预备知识

欢迎订阅本专栏(为爱发电,限时免费),联系前三篇一起食用哈!上一篇讲了如何使用SAM接口完成一个训练流程,本篇只专注于如何处理包装自己的数据集。

流程如下:
在这里插入图片描述
直接将图像编码器编码得到的embedding存入npz代表原始图像,是因为,我们有很多种训练策略,但每一次的编码过程是一摸一摸的,并且也是最耗时的一部分,所以,将其静态化,每次用的时候拿来解压。

由于比较粗糙并且没有做交叉验证,所以这里在原始图像存放路径的时候就划分好了训练测试,但一般自己的数据集还是做个交叉验证,在得到npz之后划分训练测试。

2️⃣实现思路

🔸脚本预处理得到包含embedd和GT的npz

embedding步骤:
1.归一化
2.ResizeLongestSide到1024*1024
3.sam_model.preprocess预处理
4.sam_model.image_encoder编码

GD步骤:1.校验GT是否是2D 2.校验是否和img尺寸大小相同 3.uint8到255

🔸编写Dataset类

init:解压npz,读取数据放入self变量

getitem:根据GT获得边界框当作框提示,在GT内随机选择点当作点提示, 将embedd,box,point,GT,组装torch.tensor

len:返回图片个数就好啦

在这里插入图片描述

3️⃣代码

🔸实现脚本预处理得到包含embedd和GT的npz代码

注释都在代码里吗,按行注释,我真贴心💓

import numpy as np
import os
join = os.path.join
from skimage import transform, io
from tqdm import tqdm
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

# GT存放路径,到文件夹
gt_path = "./"
# 组装好npz的保存路径
save_path = "./"
# 获取所有GT图像名称
names = sorted(os.listdir(gt_path))
os.makedirs(save_path, exist_ok=True)
model_type = 'vit_b'
checkpoint = 'xx/sam_vit_b_01ec64.pth'
device = 'cuda:0'
sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
imgs = []
gts = []
img_embeddings = []
# image路径 到最后一层文件夹
img_path=""
for gt_name in tqdm(names):
    # 如果你是jpg改一下后缀
    image_name = gt_name.split('.')[0] + "png"
    # 读取GT
    gt_data = io.imread(join(gt_path, gt_name))
    # GT必须是2D,如果是3D就取前两通道
    if len(gt_data.shape) == 3:
        gt_data = gt_data[:, :, 0]
    assert len(gt_data.shape) == 2, 'GT must be 2D'
    # 尺寸转256数值转255
    gt_data = transform.resize(gt_data == 255, (256, 256), order=0,
                               preserve_range=True, mode='constant')
    gt_data = np.uint8(gt_data)

    # 排除GT特别小的情况,这条可以不加
    if np.sum(gt_data) > 100:
        assert np.max(gt_data) == 1 and np.unique(gt_data).shape[0] == 2, 'GT must be 2D'
        image_data = io.imread(join(img_path, image_name))
        # 计算最大值最小值
        lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(image_data, 99.5)
        # 排除特别特殊的像素
        image_data_pre = np.clip(image_data, lower_bound, upper_bound)
        # 归一化
        image_data_pre = (image_data_pre - np.min(image_data_pre)) / (
                np.max(image_data_pre) - np.min(image_data_pre)) * 255.0
        image_data_pre[image_data == 0] = 0
        # 归一化
        image_data_pre = transform.resize(image_data_pre, (256, 256), order=3,
                                          preserve_range=True, mode='constant', anti_aliasing=True)
        image_data_pre = np.uint8(image_data_pre)

        imgs.append(image_data_pre)
        gts.append(gt_data)
        # SAM提供的resize到1024
        sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        resize_img = sam_transform.apply_image(image_data_pre)
        # resize_img是通道在后,sam要求通道在前,transposehi是对resize_img数组进行维度重排(dimension reordering)的操作。
        resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)
        # 增加一个channel假装当作有一个batchsize输入到sam_model.image_encoder
        input_image = sam_model.preprocess(resize_img_tensor[None, :, :, :])  # (1, 3, 1024, 1024)
        # 提前计算图像embedding
        with torch.no_grad():
            embedding = sam_model.image_encoder(input_image)
            img_embeddings.append(embedding.cpu().numpy()[0])

# 上面数据已经处理好并存在数组了,需要数据字典存在npz中
# 沿着纵轴堆砌,每一个都是(256, 256, 3),堆起来是(n, 256, 256, 3)
imgs = np.stack(imgs, axis=0)  # (n, 256, 256, 3)
gts = np.stack(gts, axis=0)  # (n, 256, 256)
img_embeddings = np.stack(img_embeddings, axis=0)  # (n, 1, 256, 64, 64)
# np的保存npz操作
np.savez_compressed(join(save_path, '.npz'), imgs=imgs, gts=gts, img_embeddings=img_embeddings)

🔸实现Dataset的代码

import numpy as np
import matplotlib.pyplot as plt
import os

join = os.path.join
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import monai
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import random

torch.manual_seed(2023)

# 构造自己的Dataset继承Dataset类
class MyselfDataset(Dataset):
    def __init__(self, data_root):
        print("into init")
        self.data_root = data_root
        # 访问npz文件
        self.npz_files = sorted(os.listdir(self.data_root))
        # 去除npz里的数据
        self.npz_data = [np.load(join(data_root, f)) for f in self.npz_files]
        # 将取出来的数据放在变量保存
        self.ori_gts = np.vstack([d['gts'] for d in self.npz_data])
        self.img_embeddings = np.vstack([d['img_embeddings'] for d in self.npz_data])

    def __len__(self):
        return self.ori_gts.shape[0]

    def __getitem__(self, index):
        img_embed = self.img_embeddings[index]
        gt2D = self.ori_gts[index]
        # 获取非零点坐标
        y_indices, x_indices = np.where(gt2D > 0)
        # 获取GT坐标框
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # 在GT框加扰动
        H, W = gt2D.shape
        x_min = max(0, x_min - np.random.randint(0, 10))
        x_max = min(W, x_max + np.random.randint(0, 10))
        y_min = max(0, y_min - np.random.randint(0, 10))
        y_max = min(H, y_max + np.random.randint(0, 10))
        bboxes = np.array([x_min, y_min, x_max, y_max])

        # 在GT在5像素以内的地方随机选择两个背景点
        y_zero, x_zero = np.where(gt2D == 0)
        y_zero = np.unique(y_zero)
        x_zero = np.unique(x_zero)
        y_list = y_zero[(y_min - 5 < y_zero) & (y_zero < y_max + 5)]
        x_list = x_zero[(x_min - 5 < x_zero) & (x_zero < x_max + 5)]
        y1, y2 = random.choices(y_list, k=2)
        x1, x2 = random.choices(x_list, k=2)
        background_index1 = [x1, y1]
        background_index2 = [x2, y2]
        # 在GT内随机选择前景点
        foreground_index1, foreground_index2, foreground_index3 = random.choices(
            np.argwhere(gt2D == 1), k=3)
        # 将所有选择好的点添加到list,如果是单点,不需要直接返回点的index就好。
        pt_list_s = []
        pt_list_s.append(background_index1)
        pt_list_s.append(background_index2)
        pt_list_s.append(foreground_index1)
        pt_list_s.append(foreground_index2)
        pt_list_s.append(foreground_index3)
        points = pt_list_s
        # 0是背景1是前景
        points_labels = [0, 0, 1, 1, 1]
        return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :, :]).long(), torch.tensor(
            bboxes).float(), torch.tensor(points).float(), torch.tensor(points_labels).float()

之后连系上篇 【Segment Anything Model】SAM模型微调自定义数据集,更改混合提示方式:点,框,点框混合
在这里取值训练就好啦
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/62576.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于SPSSPRO实现层次分析法(AHP)

层次分析法&#xff0c;简称AHP&#xff0c;是指将与决策总是有关的元素分解成目标、准则、方案等层次&#xff0c;在此基础之上进行定性和定量分析的决策方法。&#xff08;摘自百度百科&#xff09; 层次分析法有着广泛使用&#xff0c;涉及到的平台也多种多样&#xff0c;今…

Android Studio新版本logcat过滤说明

按包名过滤 //输入package:&#xff08;输入一个p就会有提示的&#xff09; &#xff0c;后面加上包名 比如: package:com.xal.runcontrol package:包名可以完整或者输部分包名即可 package:包名需要输完整准确 package~:正则表达式过滤 不了解正则表达式的可以参考&#…

·[K8S:使用calico网络插件]:解决集群节点NotReady问题

文章目录 一&#xff1a;安装calico&#xff1a;1.1&#xff1a;weget安装Colico网络通信插件&#xff1a;1.2&#xff1a;修改calico.yaml网卡相关配置&#xff1a;1.2.1&#xff1a;查看本机ip 网卡相关信息&#xff1a;1.2.2&#xff1a;修改calico.yaml网卡interface相关信…

普及100Hz高刷+1ms响应 微星发布27寸显示器:仅售799元

不论办公还是游戏&#xff0c;高刷及低响应时间都很重要&#xff0c;微星现在推出了一款27寸显示器PRO MP273A&#xff0c; 售价只有799元&#xff0c;但支持100Hz高刷、1ms响应时间&#xff0c;还有FreeSync技术减少撕裂。 PRO MP273A的100Hz高刷新率是其最大的卖点之一&#…

使用node.js 搭建一个简单的HelloWorld Web项目

文档结构 config.ini #将本文件放置于natapp同级目录 程序将读取 [default] 段 #在命令行参数模式如 natapp -authtokenxxx 等相同参数将会覆盖掉此配置 #命令行参数 -config 可以指定任意config.ini文件 [default] authtokencc83c08d73357802 #对应一条隧…

Python绘制箭头向量图,并绘制三体引力场

文章目录 简介箭头设置三维场图 简介 箭头向量图十分常见&#xff0c;比如天气预报在显示风场的时候&#xff0c;就会贴心地用箭头指明风的方向。在matplotlib中&#xff0c;用quiver函数来绘制箭头向量图&#xff0c;示例如下 import matplotlib.pyplot as plt import numpy…

使用hexo进行博客迁移

本文不会从0开始介绍如何通过hexo去搭建一个github page。因为最近折腾了下&#xff0c;发现这玩意儿确实写个博客很费劲&#xff0c;打算把他拖管到github当作我的知识库网站&#xff0c;我的主要文章还是通过mweb写完一键发布到博客园&#xff0c;然后csdn记录一些杂文和思考…

PHP8的程序结构-PHP8知识详解

在做任何事情之前&#xff0c;都需要遵循一定的规则。在PHP8中&#xff0c;程序能够安照人们的意愿执行程序&#xff0c;主要依靠程序的流程控制语句。 不管多复杂的程序&#xff0c;都是由这些基本的语句组成的。语句是构造程序的基本单位。程序执行的过程就是执行程序语句的…

JVM 学习—— 类加载机制

前言 在上一篇文章中&#xff0c;荔枝梳理了有关Java中JVM体系架构的相关知识&#xff0c;其中涉及到的有关Java类加载机制的相关知识并没有过多描述。那么在这篇文章中&#xff0c;荔枝会详细梳理一下有关JVM的类加载机制和双亲委派模型的知识&#xff0c;希望能够帮助到有需要…

Android应用开发(6)TextView进阶用法

Android应用开发学习笔记——目录索引 上一章Android应用开发&#xff08;5&#xff09;文本视图&#xff08;TextView&#xff09;介绍了文本视图&#xff08;TextView&#xff09;设置文本内容、设置文本大小、设置文本显示颜色。 TextView是最基础的文本显示控件&#xff…

【STM32零基础入门教程03】GPIO输入输出之GPIO框图分析

本章节主要讲解点亮LED的基本原理&#xff0c;以及GPIO框图的讲解。 如何点亮LED&#xff08;输出&#xff09; 首先我们查看原理图&#xff0c;观察电路图中LED的连接情况&#xff0c;如下图可以看出我们的板子中LED一端通过限流电阻连接的PB0另一端连接的是高电平VCC&#xf…

召唤神龙打造自己的ChatGPT

在之前的两篇文章中&#xff0c;我介绍了GPT 1和2的模型&#xff0c;并分别用Tensorflow和Pytorch来实现了模型的训练。具体可以见以下文章链接&#xff1a; 1. 基于Tensorflow来重现GPT v1模型_gzroy的博客-CSDN博客 2. 花费7元训练自己的GPT 2模型_gzroy的博客-CSDN博客 有…

Rocketmq 定时消息源码分析

定时消息定义 生产者将消息投放到broker后&#xff0c;不会马上被消费者消费。需要等待到特定时间才会被消费。 调用链路 producer 将定时消息写入commitLog线程ReputThead 休息1毫秒&#xff0c;读取一次commitlog数据&#xff0c;写入ConsumeQueue和IndexFile线程Scheduled…

右值引用带来的效率提升(C++11)

文章目录 一.左值引用和右值引用二.C11区分左值和右值的语法设计意义--对象的移动构造和移动赋值场景分析1:C11之前C11之后 场景分析2:函数std::move右值引用的广泛使用 三.引用折叠 一.左值引用和右值引用 左值:可以取到地址的对象(可以出现在赋值符号的左边),对左值的引用称…

arcgis--网络分析(理论篇)

1、定义概念 &#xff08;1&#xff09;网络&#xff1a;由一系列相互联通的点和线组成&#xff0c;用来描述地理要素&#xff08;资源&#xff09;的流动情况。 &#xff08;2&#xff09;网络分析&#xff1a;对地理网络&#xff08;如交通网络、水系网络&#xff09;&…

【数据结构】排序算法系列

常见的排序如下&#xff1a; 一、比较类排序 1. 交换排序 &#xff08;1&#xff09; 冒泡排序 【数据结构】交换排序&#xff08;一&#xff09;——冒泡排序_Jacky_Feng的博客-CSDN博客 &#xff08;2&#xff09; 快速排序 【数据结构】交换排序&#xff08;二&#xf…

用于大型图像模型的 CNN 内核的最新内容

一、说明 由于OpenAI的ChatGPT的巨大成功引发了大语言模型的繁荣&#xff0c;许多人预见到大图像模型的下一个突破。在这个领域&#xff0c;可以提示视觉模型分析甚至生成图像和视频&#xff0c;其方式类似于我们目前提示 ChatGPT 的方式。 用于大型图像模型的最新深度学习方法…

【力扣每日一题】2023.8.7 反转字符串

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们一个字符数组形式的字符串&#xff0c;让我们直接原地修改反转字符串&#xff0c;不必返回。 给出的条件是使用O(1)的额外空间…

c语言——计算一串字符的长度

//计算一串字符的长度 //在main函数中输出一个字符&#xff0c;并且计算出该字符的长度。 #include<stdio.h> #include<stdlib.h> int length(char *s){int i0;while(*s!\0){i;s;}return i;} int main() {int len;char str[20];printf("输入字符串&#xff1a…

【JavaEE进阶】Spring核心与设计思想

文章目录 一. Spring框架概述1. 什么是Spring框架2. 为什么要学习框架?3. Spring框架学习的难点 二. Spring 核心与设计思想1. 什么是容器?2. 什么是IoC?3. Spring是IoC容器4. DI&#xff08;依赖注入&#xff09;5. DL&#xff08;依赖查找&#xff09; 一. Spring框架概述…