【SVG 生成系列论文(九)】如何通过文本生成 svg logo?IconShop 模型推理代码详解

  • SVG 生成系列论文(一) 和 SVG 生成系列论文(二) 分别介绍了 StarVector 的大致背景和详细的模型细节。
  • SVG 生成系列论文(三)和 SVG 生成系列论文(四)则分别介绍实验、数据集和数据增强细节。
  • SVG 生成系列论文(五)介绍了从光栅图像(如 PNG、JPG 格式)转换为矢量图形(如 SVG、EPS 格式)的关键技术-像素预过滤(pixel prefiltering), Diffvg 这篇论文也是 SVG 生成与编辑领域中 “基于优化”方法的开创性研究。
  • SVG 生成系列论文(六) 和 SVG 生成系列论文(七) 简要介绍了 IconShop 的背景、应用和部分细节。
  • SVG 生成系列论文(八)则介绍了模型架构和具体的训练技巧。

本文将详细拆解 IconShop(论文原文以及代码🔗)的模型结构和对应开源代码。上篇有提到过模型架构如下所示,本篇则从代码的逻辑进行解释,主要是 /path/to/IconShop/model/decoder.py 中的 sample 以及 forward 两个函数。

模型架构

架构整体分为 4 个部分:SVG 图标嵌入(SVG Icon Embedding),文本嵌入(Text Embedding),输入准备和输出生成。

在这里插入图片描述

sample 函数

在原项目中,调用模型推理是从 sample_pixels = sketch_decoder.sample(n_samples=BS, text=tokenized_text) 中进行的,因此重点解读这部分代码。

  1. 输入部分:
    self: 指代调用该方法的对象,意味着可以访问类的属性和方法。
    n_samples: 要生成的样本数量,这里是 batchsize。
    text: 输入的文本数据,用于条件生成。形状为[batch_size, text_len],其中text_len是文本序列的长度,默认是 50。
    pixel_seq: 已有的像素序列,初始化为None,形状为[batch_size, max_len]。
    xy_seq: 已有的坐标序列,初始化为None,形状为[batch_size, max_len, 2]。
def sample(self, n_samples, text, pixel_seq=None, xy_seq=None):
    """ sample from distribution (top-k, top-p) """
    pix_samples = []
    xy_samples = []
    # latent_ext_samples = []
    top_k = 0
    top_p = 0.5
  1. 初始化:
    定义了空列表pix_samplesxy_samples 以存储采样的像素和坐标序列。
    同时定义了 top_k 和 top_p 作为采样策略的参数,默认使用“Top-P”策略(top_p=0.5),不使用“Top-K”。了解更多,可参见大模型推理常见采样策略(Top-k, Top-p)
    # Mapping from pixel index to xy coordiante
    pixel2xy = {}
    x=np.linspace(0, BBOX-1, BBOX)
    y=np.linspace(0, BBOX-1, BBOX)
    xx,yy=np.meshgrid(x,y)
    xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
    for pixel, xy in enumerate(xy_grid):
      pixel2xy[pixel] = xy+COORD_PAD+SVG_END
  1. 像素到坐标映射:
    创建一个从像素索引到坐标(x, y)的映射。这里假设有一个固定大小的边界框(BBOX=200),通过网格生成所有可能的坐标组合,并将这些坐标与像素索引关联起来。其中,COORD_PAD = NUM_END_TOKEN + NUM_MASK_AND_EOM(NUM_END_TOKEN = 3,NUM_MASK_AND_EOM = 2),SVG_END = 1

  2. 处理文本输入: 确保文本序列的长度不超过预设的最大长度self.text_len

  3. 循环采样:

  • for k in range(text.shape[1] + pixlen, self.total_seq_len): 是开始对文本之后的 SVG 序列进行采样(预测)。
  • 初始化或更新像素pixel_seq和坐标序列xy_seq的长度。
  • 使用模型的forward方法计算当前状态下下一个token的概率分布。
  • 应用“Top-K”和“Top-P”过滤策略(top_k_top_p_filtering)到概率分布上,仅保留最可能的token。
  • 从过滤后的分布中多分类采样得到下一个像素值。
  • 将采样的像素索引转换为实际的坐标。其中 PIX_PAD = NUM_END_TOKEN + NUM_MASK_AND_EOM, SVG_END = 1
  1. 序列更新和早停条件:
  • 将生成的像素和坐标序列加入现有序列中。
  • 检查生成是否完成(例如生成到结束符),如果有完成的样本则记录并移除。
  • 如果所有样本都生成完毕,提前停止循环。
  1. 返回结果:
    返回生成的坐标序列xy_samples
    # Sample per token
    text = text[:, :self.text_len]
    pixlen = 0 if pixel_seq is None else pixel_seq.shape[1]
    for k in range(text.shape[1] + pixlen, self.total_seq_len):
      if k == text.shape[1]:
        pixel_seq = [None] * n_samples
        xy_seq = [None, None] * n_samples
      
      # pass through model
      with torch.no_grad():
        p_pred = self.forward(pixel_seq, xy_seq, None, text)
        p_logits = p_pred[:, -1, :]

      next_pixels = []
      # Top-p sampling of next pixel
      for logit in p_logits: 
        filtered_logits = top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p)
        next_pixel = torch.multinomial(F.softmax(filtered_logits, dim=-1), 1)
        next_pixel -= self.num_text_token
        next_pixels.append(next_pixel.item())

      # Convert pixel index to xy coordinate
      next_xys = []
      for pixel in next_pixels:
        if pixel >= PIX_PAD+SVG_END:
          xy = pixel2xy[pixel-PIX_PAD-SVG_END]
        else:
          xy = np.array([pixel, pixel]).astype(int)
        next_xys.append(xy)
      next_xys = np.vstack(next_xys)  # [BS, 2]
      next_pixels = np.vstack(next_pixels)  # [BS, 1]
        
      # Add next tokens
      nextp_seq = torch.LongTensor(next_pixels).view(len(next_pixels), 1).cuda()
      nextxy_seq = torch.LongTensor(next_xys).unsqueeze(1).cuda()
      
      if pixel_seq[0] is None:
        pixel_seq = nextp_seq
        xy_seq = nextxy_seq
      else:
        pixel_seq = torch.cat([pixel_seq, nextp_seq], 1)
        xy_seq = torch.cat([xy_seq, nextxy_seq], 1)
      
      # Early stopping
      done_idx = np.where(next_pixels==0)[0]
      if len(done_idx) > 0:
        done_pixs = pixel_seq[done_idx] 
        done_xys = xy_seq[done_idx]
        # done_ext = latent_ext[done_idx]
       
        # for pix, xy, ext in zip(done_pixs, done_xys, done_ext):
        for pix, xy in zip(done_pixs, done_xys):
          pix = pix.detach().cpu().numpy()
          xy = xy.detach().cpu().numpy()
          pix_samples.append(pix)
          xy_samples.append(xy)
          # latent_ext_samples.append(ext.unsqueeze(0))
  
      left_idx = np.where(next_pixels!=0)[0]
      if len(left_idx) == 0:
        break # no more jobs to do
      else:
        pixel_seq = pixel_seq[left_idx]
        xy_seq = xy_seq[left_idx]
        text = text[left_idx]
    
    # return pix_samples, latent_ext_samples
    return xy_samples

forward

这个函数的主要作用是前向传播(或称推理),根据输入的像素序列、坐标序列、掩码和文本,计算模型的输出。具体过程如下:

  1. 准备输入:
  • 如果需要计算损失return_loss,则去掉最后一个时间步的数据。
  • 计算上下文序列的长度c_seqlen,包括文本和像素序列的长度。
def forward(self, pix, xy, mask, text, return_loss=False):
    '''
    pix.shape  [batch_size, max_len]
    xy.shape   [batch_size, max_len, 2]
    mask.shape [batch_size, max_len]
    text.shape [batch_size, text_len]
    '''
    pixel_v = pix[:, :-1] if return_loss else pix
    xy_v = xy[:, :-1] if return_loss else xy
    pixel_mask = mask[:, :-1] if return_loss else mask

    c_bs, c_seqlen, device = text.shape[0], text.shape[1], text.device
    if pixel_v[0] is not None:
      c_seqlen += pixel_v.shape[1]  
  1. 嵌入计算:

对输入的文本进行嵌入计算text_emb
如果有像素和坐标序列,则分别对它们进行嵌入计算(coord_embed_x, pixel_embed),并与文本嵌入拼接。


    # Context embedding values
    context_embedding = torch.zeros((1, c_bs, self.embed_dim)).to(device) # [1, bs, dim]

    # tokens.shape [batch_size, text_len, emb_dim]
    tokens = self.text_emb(text)

    # Data input embedding
    if pixel_v[0] is not None:
      # coord_embed.shape [batch_size, max_len-1, emb_dim]
      # pixel_embed.shape [batch_size, max_len-1, emb_dim] 
      coord_embed = self.coord_embed_x(xy_v[...,0]) + self.coord_embed_y(xy_v[...,1]) # [bs, vlen, dim]
      pixel_embed = self.pixel_embed(pixel_v)
      embed_inputs = pixel_embed + coord_embed

      # tokens.shape [batch_size, text_len+max_len-1, emb_dim]
      tokens = torch.cat((tokens, embed_inputs), dim=1)
  1. 位置编码和掩码计算:

计算位置编码,并与嵌入后的输入序列拼接。
生成用于Transformer的掩码nopeak_mask,确保模型不会看到未来的时间步。
如果有像素掩码pixel_mask,则将其扩展到与嵌入序列匹配的形状。

    # nopeak_mask.shape [c_seqlen+1, c_seqlen+1]
    nopeak_mask = torch.nn.Transformer.generate_square_subsequent_mask(c_seqlen+1).to(device)  # masked with -inf
    if pixel_mask is not None:
      # pixel_mask.shape [batch_size, text_len+max_len]
      pixel_mask = torch.cat([(torch.zeros([c_bs, context_embedding.shape[0]+self.text_len])==1).to(device), pixel_mask], axis=1)  

  1. Transformer解码:

将处理后的输入序列送入Transformer解码器,得到输出序列。

    decoder_out = self.decoder(tgt=decoder_inputs, memory=memory_encode, memory_key_padding_mask=None,tgt_mask=nopeak_mask, tgt_key_padding_mask=pixel_mask)
  1. 计算logits和损失:

通过全连接层logit_fc计算输出的logits
如果需要计算损失,则分别计算文本和像素的损失cross_entropy,并返回总损失。
如果不需要计算损失,则直接返回logits。

    # Logits fc
    logits = self.logit_fc(decoder_out)  # [seqlen, bs, dim] 
    logits = logits.transpose(1,0)  # [bs, textlen+seqlen, total_token] 

    logits_mask = self.logits_mask[:, :c_seqlen+1]
    max_neg_value = -torch.finfo(logits.dtype).max
    logits.masked_fill_(logits_mask, max_neg_value)

    if return_loss:
      logits = rearrange(logits, 'b n c -> b c n')
      text_logits = logits[:, :, :self.text_len]
      pix_logits = logits[:, :, self.text_len:]

      pix_logits = rearrange(pix_logits, 'b c n -> (b n) c')
      pix_mask = ~mask.reshape(-1)
      pix_target = pix.reshape(-1) + self.num_text_token

      text_loss = F.cross_entropy(text_logits, text)
      pix_loss = F.cross_entropy(pix_logits[pix_mask], pix_target[pix_mask], ignore_index=MASK+self.num_text_token)
      loss = (text_loss + self.loss_img_weight * pix_loss) / (self.loss_img_weight + 1)
      return loss, pix_loss, text_loss
    else:
      return logits

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/677600.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

生成式人工智能时代:职业任务的转型与技能需求的演变

近年来,人工智能的发展已经从决策式时代迈入了生成式时代。这一转变不仅改变了现有职业的性质,也催生了众多新兴职业。在这个过程中,劳动者所需具备的技能也在不断演变。本文将探讨生成式人工智能对职业任务的影响,以及劳动者在新…

[C/C++]_[初级]_[在Windows和macOS平台上导出动态库的一些思考]

场景 最近看了《COM本质论》里关于如何设计基于抽象基类作为二进制接口,把编译器和链接器的实现隐藏在这个二进制接口中,从而使用该DLL时不需要重新编译。在编译出C接口时,发现接口名直接是函数名,比如BindNativePort,怎么不是_BindNativePort?说明 VC++导出的函数默认是使…

mac电脑多协议远程管理软件:Termius 8.4.0激活版下载

Termius 是一款功能强大的跨平台远程访问工具,可用于管理和连接各种远程系统和服务器。它支持SSH、Telnet、SFTP和Serial协议,并提供了键盘快捷键、自动完成和多标签功能,使用户可以方便地控制和操作远程主机。 Termius 提供了端到端的加密保…

风电Weibull+随机出力!利用ARMA模型随机生成风速+风速Weibull分布程序代码!

前言 随着能源问题日益突出,风力发电等以可再生能源为基础的发电技术越来越受到关注。建立能够正确反映实际风速特性的风速模型是研究风力发电系统控制策略以及并网运行特性的重要基础叫。由于风速的随机性和波动性,系统中的机械设备和电气设备以及电网…

STM32H750外设ADC之外部触发和注入管理

目录 概述 1 外部触发转换和触发极性 1.1 外部触发条件 1.2 忽略硬件触发条件 1.3 触发框图 1.4 常规通道的外部触发 1.5 注入通道的外部触发 2 注入通道管理 2.1 触发注入模式 2.2 自动注入模式 2.3 注入转换延迟 概述 本文主要介绍STM32H750外设ADC之外部触发和注…

拿到Offer了才知道,这家公司年终奖只有几百块~

我也挠头了 最近又有不少粉丝上岸了,其中一位分享的事情比较有意思,和你分享一下: 以后你对比Offer的时候也可以多个经验。 事情是这样的: 他在经过2个多月空窗期之后终于拿到了Offer,月薪涨幅不大,但是…

Python用于存储和组织大型数据集的文件格式库之h5py使用详解

概要 在科学计算和数据分析中,大规模数据集的存储和管理是一个重要的问题。HDF5(Hierarchical Data Format version 5)是一种用于存储和组织大型数据集的文件格式。Python 的 h5py 库是一个用于与 HDF5 文件交互的接口,它结合了 HDF5 的强大功能和 Python 的易用性,使得处…

任务3.3 学生喂养三种宠物:猫、狗和鸟

本任务旨在通过Java面向对象编程中的多态性和方法重载概念,实现一个学生喂养三种不同宠物(猫、狗、鸟)的程序。 定义基类和派生类 创建一个Animal基类,包含所有动物共有的属性和方法,如name、age、speak()、move()和ea…

【全开源】Java同城服务同城信息同城任务发布平台小程序APP公众号源码

📢 连接你我,让任务触手可及 🌟 引言 在快节奏的现代生活中,我们时常需要寻找一些便捷的方式来处理生活中的琐事。同城任务发布平台系统应运而生,它为我们提供了一个高效、便捷的平台,让我们能够轻松发布…

解锁阿里巴巴API接口的无限可能:打造你的电商、物流、支付新纪元

Alibaba API接口是Alibaba平台对外开放的一系列编程接口,开发者可以通过这些接口访问Alibaba平台的数据和功能,如商品搜索、订单管理、支付接口等。这些接口基于HTTP/HTTPS协议,支持多种编程语言和数据格式(如JSON、XML等&#xf…

[Algorithm][动态规划][子序列问题][最长等差数列][等差数列划分 Ⅱ - 子序列]详细讲解

目录 1.最长等差数列1.题目链接2.算法原理详解3.代码实现 2.[等差数列划分 II - 子序列]1.题目链接2.算法原理详解3.代码实现 1.最长等差数列 1.题目链接 最长等差数列 2.算法原理详解 思路: 确定状态表示 -> dp[i]的含义 dp[i]:以i位置元素为结尾…

碳微球是新型碳材料 在高科技领域应用价值极高

碳微球是新型碳材料 在高科技领域应用价值极高 碳微球是一种新型碳材料,由石墨片层在玻璃相的石墨结构间断分布而构成。   与碳纳米管、石墨烯等碳材料不同,碳微球具有独特的球形结构,这赋予了其高比表面、高堆积密度等特点及良好的导电性、…

PROFINET转CANOPEN(WL-ABC3033)连接台达伺服驱动器ASDA-B3

在工业自动化领域这片广阔天地中,通信协议的转换犹如一道横亘在工程师们面前的难题。特别是在将众多采用不同通信协议的设备汇聚一堂,共同协作完成任务的场景中,如何确保数据如丝般顺滑地穿梭于各个节点之间,确保每台设备都能心领…

Spring-DI入门案例

黑马程序员SSM框架教程 文章目录 一、DI入门案例思路分析二、实现步骤2.1 删除service中使用new形式创建的Dao对象2.2 提供以来对象对应的setter方法2.3 配置service与到之间的关系 一、DI入门案例思路分析 基于IoC管理bean(上个案例已经实现)service中…

进阶 RocketMQ - 消息存储-一张图掌握核心要点

看了很多遍源码整理的 一张图进阶 RocketMQ 图片,关于 RocketMQ 你只需要记住这张图! 消息传递责任已移交至Broker,接下来如何处理?首先,我们需要确保消息的持久化,避免因宕机导致的数据丢失。那么&#xf…

生活旅游数据恢复:全国违章查询

【步骤一:备份数据】 在开始数据恢复之前,首先要做的是备份现有的数据。虽然这一步不直接涉及到数据恢复,但万一在恢复过程中出现问题,您还可以回滚到备份,以避免数据丢失。 打开全国违章查询app。在主界面上找到并点…

下载视频怎么转换MP4?wmv转换mp4,推荐这3种方法

在数字化时代,我们经常需要从网上下载各种视频,但有时候下载的视频并不是我们想要的格式,比如WMV。为了能在更多的设备上播放或进行编辑,我们可能需要将其转换为更通用的MP4格式。 那么,下载的视频如何转换成MP4呢&am…

上位机图像处理和嵌入式模块部署(f407 mcu内部flash编程)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 对于f407这样的mcu来说,有的时候我们需要对mcu内部的flash进行编程处理。有两种情况需要对flash进行编程,一种情况是可能一…

Simulink中使用powergui做FFT分析

快速傅里叶变换(FFT)能更快的将信号从时域转换到频域进行表示,在频谱图上,可以直观的观察到信号的不同频率的大小和性质。实现信号的降噪、滤波等效果。 Simulink中的powergui模块 powergui其实是电力系统的图形化用户接口&…

【最新鸿蒙应用开发】——使用axios完成手机号注册业务

使用Axios请求实现目标效果图: 短信验证码登录 校验图形验证码,校验通过 发送短信验证码到用户手机上,可通过在线 WebSocket查看:wss://guardian-api.itheima.net/verifyCode 根据 手机号 短信验证码 实现登录 更新图形验证码…