Segment-Anything模型(SAM)基本原理与代码分析

Segment Anything简介

SAM即Segment Anything,其目的在于构建一个分割领域的foundation model,用于zero-shot任务,以及交互式prompt分割任务。其结构如下所示。SAM具有大模型的encoder提图像特征,然后通过轻量的prompt encoder和mask decoder将提示转为embedding,并与图像的特征embedding进行交互(仍然是基于transformer),最后输出mask。

SAM模型整体结构

SAM作为foundation model,其输出的mask并非为某类或者某些特定类别的物体,而是一个广义的“object”或者“stuff”的概念,因此会有一定的歧义性(比如,将一个人的上半身的一个点作prompt,我们无法判断想要分割的是这个人的衣服,还是整个人等等),因此SAM可以支持multimask输出,并根据需要选择合适的mask。

SAM模型的整体结构与基本风格

设计原则:

  • 遵循scaling law的思路,希望通过大模型、大数据量(data-driven),使得模型学习到(涌现出)类别/任务无关的“object”概念,成为视觉领域的通用模型,而不是针对某类任务或者某种特定类别训练模型

  • 可prompt,可以交互输入一定先验信息,利用SAM的语义能力,完成开集分割(SAM无需知道分割的对象的各种类别信息,关于对象的信息被编码到prompt中)

  • prompt的多模态性:支持多种prompt类型,比如point、box、dense mask以及text(需要图文多模module如CLIP支持,将text的隐空间关联到图像的隐空间)

  • zero-shot能力/OpenSet能力:不指定任务,可以直接迁移到众多下游子任务中(当然后续研究发现对于特殊场景,比如医学图像、隐藏目标分割等,还是无法较好处理,但是语义能力仍可用,因此出现了SAMed、Adapter-SAM等方案将其迁移到特殊任务中)

整体训练过程:

SAM的训练和标注流程比较特殊,它并非通常的先标注-再训练-最后得到结果的这种范式,而是将标注(annotation)和模型训练(training)闭环,形成一个data engine的策略,即一边用标注数据训网络,一边用网络产生标注,从而可以整体运转,获得更多的标注数据训练模型,提高模型的效果。SAM论文中将训练过程主要分为三个阶段:

  • 手工协助阶段(assisted-manual stage)

  • 半自动阶段(semi-automatic stage)

  • 全自动阶段(full-automatic stage)

SAM的任务、结构与训练过程

代码分析

Encoder细节

分为image encoder和prompt encoder,重点关注prompt encoder。

ImageEncoderViT

采用ViT架构,对图像做encode,图像固定尺寸输入(1024, 1024) ,最后得到的隐编码尺寸为(64, 64) (patch size 为 16)。在SAM的训练阶段,采用了MAE预训练好的ViT做监督训练。在预测阶段,首先将读入的图像的长边resize到指定尺寸(比如1024),同时保持其aspect ratio。然后,沿着短边方向进行填充,得到正方形图像输入。

PromptEncoder

  • sparse prompt

为了适应不同的prompt类型,采用了多种不同的encoder。需要注意的是,即使point和bbox的prompt对于SAM也是进行编码送入网络处理的,因此不一定完全将point分割为指定类别,也不一定bbox prompt后的结果完全在bbox中,在SAM中这两者都只有提示大致位置的功能。

首先,PrompEncoder中设置了四种不同的embedding层:

1
2
3
4
self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)

其中,point_embeddings 共有4个,分别表示四种不同的点,即:正样本、负样本、bbox的左上角点、bbox的右下角点。nn.Embedding(1, embed_dim)中的输入为vocab长度为1,即只有一个的向量,其中d为embed_dim。这个向量是nn.Embedding的weight,因此随着训练更新。另外,对于没有point prompt的情况,也给一个embedding向量,即not_a_point_embed,这个embedding对应label既不是前景fg也不是背景bg时的占位向量。上述的embedding训练好后,可以用来指示点的类别。

另一方面,点的位置信息则通过 位置编码(Position Encoding,PE) 来实现。PE的目的是将连续的位置信息映射到高频,从而增加不同位置的区分度。这里采用了PositionEmbeddingRandom实现。对正负样本point prompt和bbox的角点prompt进行PE编码后,与上面的点的类型编码结合(相加),得到最终的prompt encoding结果。

  1. 对point prompt(正负样本点)的编码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
  1. 对bbox prompt的编码(top-left和bottom-right两角点)
1
2
3
4
5
6
7
8
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
  • dense prompt

SAM中的dense prompt即mask输入。对于mask的处理,SAM通过一个下采样卷积层将其resize到与image embedding同等大小,然后相加。(由于mask_downscaling下采样了4倍,因此mask_input_size 需要被设置为image_embedding_size的4倍)

1
2
3
4
5
6
7
8
9
10
11
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)

同时,对于没有dense mask的情况,也训练一个embedding作为占位符,该向量在spatial(即H和W)方向上进行复制,得到和dense mask embedding相同的张量,与image embedding相加。

  • 各种prompt的合并组装

在forward中,对各个prompt进行组装,基本原则是sparse prompt(即各种点)得到的embedding向量沿着dim=1进行cat,即得到 的结果。其中表示点的数量。而dense mask编码成dense embedding,无dense mask时直接复制no_mask_embed中的权重向量。注意,这里的各种prompt都是可选的(optional)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.

Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed

Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)

return sparse_embeddings, dense_embeddings

Decoder细节

SAM中的MaskDecoder 是一个轻量化的结构,便于交互式分割响应。其基本结构如图所示:

MaskDecoder结构示意图

首先,如前所述,图像被编码成为大小的feature,每个feature维度为256。上面的各种prompt最终也被编码成为同样长度(256)的特征向量,作为token与图像进行attention的交互。

具体实现如下:

  1. 先考虑需要输出的mask和iou的形式,这里:
1
self.num_mask_tokens = num_multimask_outputs + 1

num_multimask_outputs 这里设定为3。由于SAM支持两种模式:multimask和直接输出单一mask,因此需要+1。除了mask的token之外,还需要预测每个mask的iou,因此还需要一个iou token,经过Transformer处理后的iou token接入一个MLP网络,网络输出维度等于mask个数,因此为每个mask分配了一个iou分数。mask token处理后与一个经过Transformer + upscaling(2x transposed conv实现)的处理后的image embedding进行逐像素点乘,最终得到对应的mask。

  1. 下面来看MaskDecoder输入的形式,以及如何将prompt融合进来的。首先,mask token 和 iou token 也被设置为nn.Embedding层的权重,即:
1
2
3
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

由于我们还有 PromptEncoder 编码来的稀疏和稠密prompt embeding,其中代表mask的稠密embedding比较简单,直接将其与图像embedding相加即可;对于稀疏prompt,将它们与前面的mask token和iou token在通道维度上拼接,得到输入Transformer的tokens:

1
2
3
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

接下来,tokens将与image_embeddings(包括dense mask embedding 和 image pe)一起输入到一个TwoWayTransformer,即图中的主体部分。该模块可以对tokens→image embedding和image embeddings → tokens两个方向进行cross attention操作,从而使得image embedding也接受tokens的修正。这个过程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)

# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)

# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)

# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)

return queries, keys

可以看出,两次cross attention中,q和k互换,并且k和v来源相同。最开始是一个self-attention,第一个cross attention后还有一个MLP对attention后的token进行更新。

  1. 最后,mask的输出方式如下:
1
2
3
4
5
6
7
8
9
10
11
12
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)

然后,可以根据是否设置输出 multimask,选择合适的输出mask:

1
2
3
4
5
6
7
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]

Reference

Segment Anything

https://github.com/facebookresearch/segment-anything

https://segment-anything.com/