Modifying attn_processor in the UNet (for designing new attention modules)


Table of Contents

  • Data layout of some variables in the UNet
    • attn_processor
    • unet.config
    • unet_sd
  • Defining your own attn processor by modifying the original one

How IP-Adapter sets up the attention processors
Reference code: Tencent AI Lab's official IP-Adapter training code

Data layout of some variables in the UNet

    # init adapter modules
    # dict to hold the rebuilt attention processors
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        # self-attention (attn1) gets no cross-attention dim; cross-attention (attn2) uses the UNet's
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        # record the channel width (hidden size) of the block this processor lives in
        if name.startswith("mid_block"):
            # 'block_out_channels': [320, 640, 1280, 1280]
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # the character right after "up_blocks." is the index of the up block
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                # copy this cross-attention layer's original K/V weights out of unet_sd
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            # replace the entry in the rebuilt dict with our own IPAttnProcessor
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            # initialize the new processor from the original SD UNet cross-attention weights
            attn_procs[name].load_state_dict(weights)
    # finally, install the rebuilt processor dict on the UNet
    unet.set_attn_processor(attn_procs)

attn_processor

unet.attn_processors is a dict that stores all of the UNet's attention processors (while unet.state_dict() stores the weights of every submodule).
To modify them, we rebuild a dict of the same shape, replace the processor type for the modules we want to change, and install it back with unet.set_attn_processor.

Let's look at what unet.attn_processors contains.
unet.attn_processors is a dict with 32 entries.
Its keys are the paths of the processors. Combined with the UNet structure and the number of transformer blocks that contain attention (2, 2, 2 in the down blocks, 1 in the mid block, 3, 3, 3 in the up blocks: 16 blocks in total, each with one self-attention and one cross-attention module, hence 32 attention processors), the parts of each key read as follows.
For example:

'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor'
down_blocks.1. (can be 0, 1, or 2; only the first three down blocks contain attention)
identifies which down block this is (here the second one)

attentions.0. (can be 0 or 1; each down block has 2 attention/transformer blocks)
identifies the first transformer block within that down block

transformer_blocks.0.
always 0 here (SD 1.5 uses a single transformer layer per block)

attn1.processor (each transformer block has 2 attention modules, one self-attention and one cross-attention)
indicates whether this is the self-attention or the cross-attention layer (attn1 is self-attention, attn2 is cross-attention)
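
A minimal sketch of how to confirm this, assuming an SD 1.5 UNet loaded with diffusers (the checkpoint path is illustrative):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

keys = list(unet.attn_processors.keys())
print(len(keys))   # 32: 16 transformer blocks x (attn1 + attn2)
print(keys[0])     # down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor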

unet.config

unet.config holds the UNet's configuration parameters. The full dump is shown below; a short sketch after it pulls out the two entries the replacement loop actually uses.

FrozenDict([('sample_size', 64),
 ('in_channels', 4), 
('out_channels', 4), 
('center_input_sample', False), ('flip_sin_to_cos', True), ('freq_shift', 0), ('down_block_types', ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']), ('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D']), 
('only_cross_attention', False), 
('block_out_channels', [320, 640, 1280, 1280]),
 ('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('dropout', 0.0), ('act_fn', 'silu'), ('norm_num_groups', 32), 
 ('norm_eps', 1e-05), ('cross_attention_dim', 768), ('transformer_layers_per_block', 1), ('reverse_transformer_layers_per_block', None),
  ('encoder_hid_dim', None), ('encoder_hid_dim_type', None), ('attention_head_dim', 8), ('num_attention_heads', None), ('dual_cross_attention', False), ('use_linear_projection', False), ('class_embed_type', None), ('addition_embed_type', None), ('addition_time_embed_dim', None), ('num_class_embeds', None), ('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('resnet_skip_time_act', False), ('resnet_out_scale_factor', 1.0), ('time_embedding_type', 'positional'), ('time_embedding_dim', None), ('time_embedding_act_fn', None), ('timestep_post_act', None), ('time_cond_proj_dim', None), ('conv_in_kernel', 3), ('conv_out_kernel', 3), ('projection_class_embeddings_input_dim', None), ('attention_type', 'default'), ('class_embeddings_concat', False), ('mid_block_only_cross_attention', None), ('cross_attention_norm', None), ('addition_embed_type_num_heads', 64), ('_use_default_values', ['addition_embed_type', 'encoder_hid_dim', 'transformer_layers_per_block', 'addition_embed_type_num_heads', 'upcast_attention', 'conv_in_kernel', 'attention_type', 'resnet_out_scale_factor', 'time_embedding_dim', 'time_embedding_act_fn', 'conv_out_kernel', 'reverse_transformer_layers_per_block', 'mid_block_type', 'class_embeddings_concat', 'time_embedding_type', 'use_linear_projection', 'class_embed_type', 'only_cross_attention', 'resnet_time_scale_shift', 'encoder_hid_dim_type', 'projection_class_embeddings_input_dim', 'dual_cross_attention', 'addition_time_embed_dim', 'cross_attention_norm', 'dropout', 'timestep_post_act', 'resnet_skip_time_act', 'num_attention_heads', 'time_cond_proj_dim', 'mid_block_only_cross_attention', 'num_class_embeds']), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.6.0'), ('_name_or_path', '/media/dell/DATA/RK/pretrained_model/stable-diffusion-v1-5')])
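
As a quick sketch, the two config entries the replacement loop relies on can be read straight off unet.config (values taken from the dump above):

print(unet.config.cross_attention_dim)   # 768 -> in_features of to_k_ip / to_v_ip
print(unet.config.block_out_channels)    # [320, 640, 1280, 1280]

# hidden_size for an up-block attention layer, mirroring the loop above:
block_id = 1  # e.g. a processor whose name starts with "up_blocks.1"
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
print(hidden_size)  # 1280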

unet_sd

unet_sd (= unet.state_dict()) is a dict containing the weights of every submodule in every layer.
Here the original K/V weights of each cross-attention layer are copied out of unet_sd and used to initialize our own attention processor:

weights = {
    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
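
For example (a sketch; this is one concrete cross-attention layer, and the shapes follow from cross_attention_dim=768 and hidden_size=320):

layer_name = "down_blocks.0.attentions.0.transformer_blocks.0.attn2"
print(unet_sd[layer_name + ".to_k.weight"].shape)  # torch.Size([320, 768])
print(unet_sd[layer_name + ".to_v.weight"].shape)  # torch.Size([320, 768])
# nn.Linear stores weights as [out_features, in_features], so these tensors can
# initialize to_k_ip / to_v_ip (= nn.Linear(768, 320, bias=False)) directly.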

Inspecting the UNet's attn_processors after the replacement

Here all the values() of unet.attn_processors are collected into a torch.nn.ModuleList:

adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
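
The point of the ModuleList is that the new to_k_ip / to_v_ip layers are now registered parameters and can be handed to an optimizer while the original UNet stays frozen. A minimal sketch (the learning rate is illustrative, not from the original code):

import torch

unet.requires_grad_(False)  # freeze the original UNet weights
optimizer = torch.optim.AdamW(adapter_modules.parameters(), lr=1e-4)
# the plain AttnProcessor2_0 entries have no parameters, so only the
# IPAttnProcessor2_0 to_k_ip / to_v_ip weights end up in the optimizer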

This is what the attn processors look like after the IP-Adapter replacement:

ModuleList(
  (0): AttnProcessor2_0()
  (1): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (2): AttnProcessor2_0()
  (3): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (4): AttnProcessor2_0()
  (5): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (6): AttnProcessor2_0()
  (7): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (8): AttnProcessor2_0()
  (9): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (10): AttnProcessor2_0()
  (11): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (12): AttnProcessor2_0()
  (13): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (14): AttnProcessor2_0()
  (15): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (16): AttnProcessor2_0()
  (17): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
  (18): AttnProcessor2_0()
  (19): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (20): AttnProcessor2_0()
  (21): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (22): AttnProcessor2_0()
  (23): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
  )
  (24): AttnProcessor2_0()
  (25): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (26): AttnProcessor2_0()
  (27): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (28): AttnProcessor2_0()
  (29): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
  )
  (30): AttnProcessor2_0()
  (31): IPAttnProcessor2_0(
    (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
    (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
  )
)

Defining your own attn processor by modifying the original one

Define the new attn processor classes in the original attention_processor.py file.

The original AttnProcessor in attention_processor.py:

import torch
import torch.nn as nn


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

# the new attention processor defined by IP-Adapter
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
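
The end_pos split above assumes that, upstream, the text tokens and the projected image-prompt tokens were concatenated along the sequence dimension before being passed in as encoder_hidden_states. A sketch with made-up tensors (77 CLIP text tokens, num_tokens=4 image tokens, cross_attention_dim=768):

import torch

text_embeds  = torch.randn(2, 77, 768)   # CLIP text encoder output
image_embeds = torch.randn(2, 4, 768)    # image-prompt tokens after projection

encoder_hidden_states = torch.cat([text_embeds, image_embeds], dim=1)  # (2, 81, 768)

# inside IPAttnProcessor.__call__ this is split back apart:
end_pos = encoder_hidden_states.shape[1] - 4
text_part = encoder_hidden_states[:, :end_pos, :]   # (2, 77, 768)
ip_part   = encoder_hidden_states[:, end_pos:, :]   # (2, 4, 768)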

Now define two new processors of your own and put them into the same file. For now, both are verbatim copies of the default AttnProcessor.

class StyleAttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class LayoutAttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

Then import these attn processors:

from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor, \
    LayoutAttnProcessor, StyleAttnProcessor

The result after the replacement is shown below.

Here the cross-attention layers of the third down block (down_blocks.2) are replaced with LayoutAttnProcessor, and those of the first up block that contains attention (up_blocks.1) with StyleAttnProcessor:

    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        # processor names in the third down block start with this prefix
        elif name.startswith("down_blocks.2.attentions"):
            attn_procs[name] = LayoutAttnProcessor()
        # processor names in the first up block with attention (up_blocks.1) start with this prefix
        elif name.startswith("up_blocks.1.attentions"):
            attn_procs[name] = StyleAttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            attn_procs[name].load_state_dict(weights)
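
As in the first block, finish by installing the rebuilt dict on the UNet (the same calls shown earlier):

    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())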

The modified attn_processors look like this:

{
'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
),

## here the cross-attention processors have been replaced with our own LayoutAttnProcessor
 'down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(), 'down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': LayoutAttnProcessor(), 

## and here they have been replaced with our own StyleAttnProcessor
'up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 'up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 'up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor': StyleAttnProcessor(), 



'up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=640, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=640, bias=False)
), 'up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=320, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=320, bias=False)
), 'mid_block.attentions.0.transformer_blocks.0.attn1.processor': AttnProcessor2_0(), 'mid_block.attentions.0.transformer_blocks.0.attn2.processor': IPAttnProcessor2_0(
  (to_k_ip): Linear(in_features=768, out_features=1280, bias=False)
  (to_v_ip): Linear(in_features=768, out_features=1280, bias=False)
)
}
