Notes on the diffusers source code: details still to understand

1. Code details worth noting when training DreamBooth

**When `class_labels = timesteps`, how does the model's forward pass work? Needs a deeper look.**

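From a first read of `UNet2DConditionModel.forward`, the relevant branch appears to be the class-embedding path: when the UNet is configured with `class_embed_type="timestep"`, `class_labels` go through the same sinusoidal projection as the timesteps and are added onto the time embedding. A paraphrased sketch of that branch (an excerpt for illustration, not standalone code):

```python
# Paraphrased sketch of the class-embedding branch in UNet2DConditionModel.forward:
# 1. embed the diffusion timestep as usual
t_emb = self.time_proj(timesteps)    # sinusoidal projection
emb = self.time_embedding(t_emb)     # MLP -> time embedding

# 2. if a class embedding is configured, fold class_labels into `emb`
if self.class_embedding is not None:
    if class_labels is None:
        raise ValueError("class_labels should be provided when the UNet has a class embedding")
    if self.config.class_embed_type == "timestep":
        # class_labels are projected exactly like timesteps, which is why
        # passing `class_labels = timesteps` is well-formed
        class_labels = self.time_proj(class_labels)
    class_emb = self.class_embedding(class_labels)
    emb = emb + class_emb            # the conditioning is additive
```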

The class images are generated from `class_prompt`, not `instance_prompt` — see the sketch below.

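For prior preservation, `train_dreambooth.py` first fills `class_data_dir` by sampling the base model with the generic class prompt. A condensed sketch of that generation step (paths, model id, and batch size here are hypothetical choices, not the script's exact code):

```python
# Condensed sketch of the class-image generation step (not verbatim from the script):
from pathlib import Path

import torch
from diffusers import StableDiffusionPipeline

class_images_dir = Path("class_images")  # hypothetical --class_data_dir
class_images_dir.mkdir(exist_ok=True)
class_prompt = "a photo of a dog"        # generic --class_prompt, NOT the instance prompt

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

num_new_images, batch_size = 50, 4
for i in range(0, num_new_images, batch_size):
    images = pipeline([class_prompt] * batch_size).images
    for j, image in enumerate(images):
        image.save(class_images_dir / f"{i + j}.jpg")
```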

```python
# Imports needed by the snippets below (as in train_dreambooth.py):
from pathlib import Path

import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and the tokenizes prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        instance_image = exif_transpose(instance_image)

        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            text_inputs = tokenize_prompt(
                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
            )
            example["instance_prompt_ids"] = text_inputs.input_ids
            example["instance_attention_mask"] = text_inputs.attention_mask

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)

            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)

            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    return text_inputs


def collate_fn(examples, with_prior_preservation=False):
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

        if has_attention_mask:
            attention_mask += [example["class_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }

    if has_attention_mask:
        attention_mask = torch.cat(attention_mask, dim=0)
        batch["attention_mask"] = attention_mask

    return batch
```

How the Dataset and DataLoader fit together:
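Putting the pieces together, the script wires them up roughly like this (a sketch; `args` stands for the usual CLI flags and `tokenizer` is the loaded CLIP tokenizer):

```python
# Rough wiring of dataset + dataloader (sketch; `args` holds the usual CLI flags):
train_dataset = DreamBoothDataset(
    instance_data_root=args.instance_data_dir,
    instance_prompt=args.instance_prompt,
    class_data_root=args.class_data_dir if args.with_prior_preservation else None,
    class_prompt=args.class_prompt,
    tokenizer=tokenizer,
    size=args.resolution,
    center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)
```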
To keep the model from overfitting, i.e. to avoid language drift, the base model must first generate class samples from a generic prompt; the loss sketch below shows how those samples are then used.
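In the training step, the concatenated batch is split back into its instance and class halves and a weighted prior loss is added. A sketch following train_dreambooth.py (`model_pred`, `target`, and `args` are the usual training-loop names, shown here for illustration):

```python
# Prior-preservation loss (sketch following train_dreambooth.py):
import torch.nn.functional as F

if args.with_prior_preservation:
    # collate_fn stacked [instance; class] along the batch dim, so one UNet
    # forward pass covers both; chunk the prediction back apart here
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    loss = loss + args.prior_loss_weight * prior_loss
else:
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
```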

The text encoder can also be fine-tuned (`--train_text_encoder`), but this needs noticeably more GPU memory.
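Enabling it mainly changes which parameters reach the optimizer — a sketch (with `unet`, `text_encoder`, and `args` as the usual training-loop names):

```python
# Which parameters get optimized when --train_text_encoder is set (sketch):
import itertools

params_to_optimize = (
    itertools.chain(unet.parameters(), text_encoder.parameters())
    if args.train_text_encoder
    else unet.parameters()
)
optimizer = torch.optim.AdamW(params_to_optimize, lr=args.learning_rate)
```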

2. Code details worth noting when training text-to-image

**1. The DataLoader is built as follows — but why is there no `attention_mask` here, when DreamBooth training has one? (See the note after the code block.)**

**2. Training or fine-tuning needs image-text pairs; if the captions are missing, BLIP can be used to generate captions from the images, though the generated text is not always reliable (a minimal sketch follows).**
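On point 2, a minimal BLIP captioning sketch with transformers (the model id and generation settings are one reasonable choice, not a prescription):

```python
# Minimal BLIP captioning sketch (transformers; one reasonable configuration):
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
inputs = processor(images=image, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=30)
caption = processor.decode(out[0], skip_special_tokens=True)
print(caption)  # use as the caption column for training, after a sanity check
```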

```python
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            data_dir=args.train_data_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
    if args.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if args.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # Preprocessing the datasets.
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        # after preprocessing, each example carries 4 kinds of keys: image, text, pixel_values, input_ids
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
```
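On the `attention_mask` question: a plausible explanation is that Stable Diffusion's CLIP text encoder is always called on fixed max-length-padded inputs with causal masking, so the stock text-to-image script simply does `encoder_hidden_states = text_encoder(batch["input_ids"])[0]`. train_dreambooth.py instead makes the mask opt-in; its helper looks roughly like this (a sketch, not a verbatim copy):

```python
# encode_prompt helper, roughly as in train_dreambooth.py (sketch):
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
    text_input_ids = input_ids.to(text_encoder.device)

    if text_encoder_use_attention_mask:
        # only passed through when explicitly enabled via a CLI flag
        attention_mask = attention_mask.to(text_encoder.device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)
    return prompt_embeds[0]  # last_hidden_state
```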

3. Training ControlNet

The DataLoader is built with the code below:


1. A new `conditioning_pixel_values` image input is added to drive controllable generation (see the sketch after the code for how it is consumed).
2. The inputs still carry no `attention_mask` — the same open question as in section 2.


```python
def make_train_dataset(args, tokenizer, accelerator):
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        if args.train_data_dir is not None:
            dataset = load_dataset(
                args.train_data_dir,
                cache_dir=args.cache_dir,
            )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    if args.image_column is None:
        image_column = column_names[0]
        logger.info(f"image column defaulting to {image_column}")
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.caption_column is None:
        caption_column = column_names[1]
        logger.info(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.conditioning_image_column is None:
        conditioning_image_column = column_names[2]
        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            raise ValueError(
                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if random.random() < args.proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    conditioning_image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]

        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]

        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.stack([example["input_ids"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "input_ids": input_ids,
    }
```
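In the training step, `conditioning_pixel_values` becomes `controlnet_cond`, and the ControlNet's residuals are injected into the UNet. A sketch following train_controlnet.py (`noisy_latents`, `timesteps`, `weight_dtype`, etc. are the usual training-loop names):

```python
# How the batch is consumed in the training step (sketch following train_controlnet.py):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

down_block_res_samples, mid_block_res_sample = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_image,  # <- the conditioning image
    return_dict=False,
)

# the ControlNet residuals are added into the UNet's down/mid blocks
model_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    down_block_additional_residuals=[s.to(dtype=weight_dtype) for s in down_block_res_samples],
    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample
```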
