- 给定输入图像 I ∈ R 3 × H × W I \in R^{3 \times H \times W} I∈R3×H×W。
- 给定需要的prompts:
- M ∈ R 1 × H × W M \in R^{1 \times H \times W} M∈R1×H×W,代表图片的前背景信息。
- P ∈ R N × 2 P \in R^{N \times 2} P∈RN×2,其中 N N N 是点的个数,2 代表坐标。
- B ∈ R 4 B \in R^{4} B∈R4,4 代表左上角与右下角点的坐标。
- T T T 代表一段文本,暂时还未开放。
- I I I 输入到image encoder中提取特征,得到image embeddings: f I = V I T ( I ) , f I ∈ R c × h × w f^{I}=VIT(I),f^{I} \in R^{c \times h \times w} fI=VIT(I),fI∈Rc×h×w c , h , w c,h,w c,h,w 分别是特征维度与特征的空间高,宽。
- 得到稠密编码 f D ∈ R c × h × w f^{D} \in R^{c \times h \times w} fD∈Rc×h×w。如果有 M M M,将其输入到卷积网络中卷它,如果没有的话,直接复制no_mask_embed向量填充。
- 得到稀疏编码
f
S
∈
R
K
×
c
f^{S} \in R^{K \times c}
fS∈RK×c。
- 对于点 P P P,进行位置编码,得到 f P ∈ R N × c f^P \in R^{N \times c} fP∈RN×c (每个点映射为一个 c c c 维向量),并且 f P f^P fP 中不同区域(填充部分,前景,背景)要添加对应的编码加以区分。
- 对于框 B B B,首先重塑为两个点,然后使用与点相同的方式进行点编码,最后两个点加上对应的坐上角与右下角的编码,最终得到 f B ∈ R 2 × c f^B \in R^{2 \times c} fB∈R2×c。
- 最后将 f P f^P fP 与 f B f^B fB 拼接起来作为稀疏编码,最后的稀疏编码可能只包含点编码或框编码,但实质都是点编码,只是框编码会额外加两个可学习编码加以区分,即三种情况: K = N ∣ K = 2 ∣ K = N + 2 K =N|K=2|K=N+2 K=N∣K=2∣K=N+2
- f k e y = f I + f D , f k e y ∈ R c × h × w f^{key}=f^{I}+f^{D},f^{key} \in R^{c \times h \times w} fkey=fI+fD,fkey∈Rc×h×w 作为mask decoder的 key
- 加入各种token输入到mask decoder中,作为 query。iou_token: f i o u ∈ R 1 × c f^{iou} \in R^{1 \times c} fiou∈R1×c,mask_tokens: f m a s k ∈ R 4 × c f^{mask} \in R^{4 \times c} fmask∈R4×c (3个mask+1个背景)。 f q u e r y = C a t ( f i o u , f m a s k , f S ) , f q u e r y ∈ R ( 5 + K ) × c f k e y , f q u e r y = M a s k D e c o d e r ( f k e y , f q u e r y , f p e ) f^{query}=Cat(f^{iou},f^{mask},f^S),f^{query} \in R^{(5 + K) \times c}\\ f^{key},f^{query}=MaskDecoder(f^{key},f^{query},f^{pe}) fquery=Cat(fiou,fmask,fS),fquery∈R(5+K)×cfkey,fquery=MaskDecoder(fkey,fquery,fpe) f p e f^{pe} fpe是位置编码
- 最终得到
f
k
e
y
∈
R
c
×
h
×
w
f^{key} \in R^{c \times h \times w}
fkey∈Rc×h×w,
f
q
u
e
r
y
∈
R
(
5
+
K
)
×
c
f^{query} \in R^{(5 + K) \times c}
fquery∈R(5+K)×c。
- 随后 f k e y f^{key} fkey 进行反卷积,还原到图像尺寸 H H H, W W W(实际会进行一些采样)。
- f q u e r y f^{query} fquery 的第一个表示iou,后三个表示mask,对后三个进行线性映射。
- 前两步结果求向量积,得到mask预测。 f i o u = f q u e r y [ : , 0 , : ] f^{iou}=f^{query}[:,0,:] fiou=fquery[:,0,:] f m a s k = f q u e r y [ : , 1 : 4 , : ] f^{mask}=f^{query}[:,1:4,:] fmask=fquery[:,1:4,:] f m a s k = M L P ( f m a s k ) , f m a s k ∈ R 3 × c f^{mask}=MLP(f^{mask}),f^{mask} \in R^{3 \times c} fmask=MLP(fmask),fmask∈R3×c f m a s k = M a t M u l ( f m a s k , f k e y ) , f m a s k ∈ R 3 × H × W f^{mask}=MatMul(f^{mask}, f^{key}),f^{mask} \in R^{3 \times H \times W} fmask=MatMul(fmask,fkey),fmask∈R3×H×W f i o u = M L P ( f i o u ) , f i o u ∈ R 3 f^{iou}=MLP(f^{iou}),f^{iou} \in R^{3} fiou=MLP(fiou),fiou∈R3
- 最终模型得到 3 个 mask 以及 3 个置信度。