Segment Anything简介
SAM即Segment Anything,其目的在于构建一个分割领域的foundation model,用于zero-shot任务,以及交互式prompt分割任务。其结构如下所示。SAM具有大模型的encoder提图像特征,然后通过轻量的prompt encoder和mask decoder将提示转为embedding,并与图像的特征embedding进行交互(仍然是基于transformer),最后输出mask。
SAM模型整体结构
SAM作为foundation model,其输出的mask并非为某类或者某些特定类别的物体,而是一个广义的“object”或者“stuff”的概念,因此会有一定的歧义性(比如,将一个人的上半身的一个点作prompt,我们无法判断想要分割的是这个人的衣服,还是整个人等等),因此SAM可以支持multimask输出,并根据需要选择合适的mask。
SAM模型的整体结构与基本风格
设计原则:
遵循scaling law的思路,希望通过大模型、大数据量(data-driven),使得模型学习到(涌现出)类别/任务无关的“object”概念,成为视觉领域的通用模型,而不是针对某类任务或者某种特定类别训练模型
可prompt,可以交互输入一定先验信息,利用SAM的语义能力,完成开集分割(SAM无需知道分割的对象的各种类别信息,关于对象的信息被编码到prompt中)
prompt的多模态性:支持多种prompt类型,比如point、box、dense mask以及text(需要图文多模module如CLIP支持,将text的隐空间关联到图像的隐空间)
zero-shot能力/OpenSet能力:不指定任务,可以直接迁移到众多下游子任务中(当然后续研究发现对于特殊场景,比如医学图像、隐藏目标分割等,还是无法较好处理,但是语义能力仍可用,因此出现了SAMed、Adapter-SAM等方案将其迁移到特殊任务中)
整体训练过程:
SAM的训练和标注流程比较特殊,它并非通常的先标注-再训练-最后得到结果的这种范式,而是将标注(annotation)和模型训练(training)闭环,形成一个data engine的策略,即一边用标注数据训网络,一边用网络产生标注,从而可以整体运转,获得更多的标注数据训练模型,提高模型的效果。SAM论文中将训练过程主要分为三个阶段:
手工协助阶段(assisted-manual stage)
半自动阶段(semi-automatic stage)
全自动阶段(full-automatic stage)
SAM的任务、结构与训练过程
代码分析
Encoder细节
分为image encoder和prompt encoder,重点关注prompt encoder。
ImageEncoderViT
采用ViT架构,对图像做encode,图像固定尺寸输入(1024, 1024)
,最后得到的隐编码尺寸为(64, 64)
(patch size 为 16)。在SAM的训练阶段,采用了MAE预训练好的ViT做监督训练。在预测阶段,首先将读入的图像的长边resize到指定尺寸(比如1024),同时保持其aspect ratio。然后,沿着短边方向进行填充,得到正方形图像输入。
PromptEncoder
为了适应不同的prompt类型,采用了多种不同的encoder。需要注意的是,即使point和bbox的prompt对于SAM也是进行编码送入网络处理的,因此不一定完全将point分割为指定类别,也不一定bbox prompt后的结果完全在bbox中,在SAM中这两者都只有提示大致位置的功能。
首先,PrompEncoder中设置了四种不同的embedding层:
1 2 3 4 self.num_point_embeddings: int = 4 point_embeddings = [nn.Embedding(1 , embed_dim) for i in range (self.num_point_embeddings)] self.point_embeddings = nn.ModuleList(point_embeddings) self.not_a_point_embed = nn.Embedding(1 , embed_dim)
其中,point_embeddings
共有4个,分别表示四种不同的点,即:正样本、负样本、bbox的左上角点、bbox的右下角点。nn.Embedding(1, embed_dim)
中的输入为vocab长度为1,即只有一个 的向量,其中d为embed_dim。这个向量是nn.Embedding的weight,因此随着训练更新。另外,对于没有point prompt的情况,也给一个embedding向量,即not_a_point_embed
,这个embedding对应label既不是前景fg也不是背景bg时的占位向量。上述的embedding训练好后,可以用来指示点的类别。
另一方面,点的位置信息则通过 位置编码(Position Encoding,PE) 来实现。PE的目的是将连续的位置信息映射到高频,从而增加不同位置的区分度。这里采用了PositionEmbeddingRandom
实现。对正负样本point prompt和bbox的角点prompt进行PE编码后,与上面的点的类型编码结合(相加),得到最终的prompt encoding结果。
对point prompt(正负样本点)的编码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 def _embed_points ( self, points: torch.Tensor, labels: torch.Tensor, pad: bool , ) -> torch.Tensor: """Embeds point prompts.""" points = points + 0.5 if pad: padding_point = torch.zeros((points.shape[0 ], 1 , 2 ), device=points.device) padding_label = -torch.ones((labels.shape[0 ], 1 ), device=labels.device) points = torch.cat([points, padding_point], dim=1 ) labels = torch.cat([labels, padding_label], dim=1 ) point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) point_embedding[labels == -1 ] = 0.0 point_embedding[labels == -1 ] += self.not_a_point_embed.weight point_embedding[labels == 0 ] += self.point_embeddings[0 ].weight point_embedding[labels == 1 ] += self.point_embeddings[1 ].weight return point_embedding
对bbox prompt的编码(top-left和bottom-right两角点)
1 2 3 4 5 6 7 8 def _embed_boxes (self, boxes: torch.Tensor ) -> torch.Tensor: """Embeds box prompts.""" boxes = boxes + 0.5 coords = boxes.reshape(-1 , 2 , 2 ) corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) corner_embedding[:, 0 , :] += self.point_embeddings[2 ].weight corner_embedding[:, 1 , :] += self.point_embeddings[3 ].weight return corner_embedding
SAM中的dense prompt即mask输入。对于mask的处理,SAM通过一个下采样卷积层将其resize到与image embedding同等大小,然后相加。(由于mask_downscaling
下采样了4倍,因此mask_input_size
需要被设置为image_embedding_size
的4倍)
1 2 3 4 5 6 7 8 9 10 11 self.mask_input_size = (4 * image_embedding_size[0 ], 4 * image_embedding_size[1 ]) self.mask_downscaling = nn.Sequential( nn.Conv2d(1 , mask_in_chans // 4 , kernel_size=2 , stride=2 ), LayerNorm2d(mask_in_chans // 4 ), activation(), nn.Conv2d(mask_in_chans // 4 , mask_in_chans, kernel_size=2 , stride=2 ), LayerNorm2d(mask_in_chans), activation(), nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1 ), ) self.no_mask_embed = nn.Embedding(1 , embed_dim)
同时,对于没有dense mask的情况,也训练一个embedding作为占位符,该向量在spatial(即H和W)方向上进行复制,得到和dense mask embedding相同的张量,与image embedding相加。
在forward中,对各个prompt进行组装,基本原则是sparse prompt(即各种点)得到的embedding向量沿着dim=1进行cat,即得到 的结果。其中 表示点的数量。而dense mask编码成dense embedding,无dense mask时直接复制no_mask_embed
中的权重向量。注意,这里的各种prompt都是可选的(optional)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 def forward ( self, points: Optional [Tuple [torch.Tensor, torch.Tensor]], boxes: Optional [torch.Tensor], masks: Optional [torch.Tensor], ) -> Tuple [torch.Tensor, torch.Tensor]: """ Embeds different types of prompts, returning both sparse and dense embeddings. Arguments: points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates and labels to embed. boxes (torch.Tensor or none): boxes to embed masks (torch.Tensor or none): masks to embed Returns: torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined by the number of input points and boxes. torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) """ bs = self._get_batch_size(points, boxes, masks) sparse_embeddings = torch.empty((bs, 0 , self.embed_dim), device=self._get_device()) if points is not None : coords, labels = points point_embeddings = self._embed_points(coords, labels, pad=(boxes is None )) sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1 ) if boxes is not None : box_embeddings = self._embed_boxes(boxes) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1 ) if masks is not None : dense_embeddings = self._embed_masks(masks) else : dense_embeddings = self.no_mask_embed.weight.reshape(1 , -1 , 1 , 1 ).expand( bs, -1 , self.image_embedding_size[0 ], self.image_embedding_size[1 ] ) return sparse_embeddings, dense_embeddings
Decoder细节
SAM中的MaskDecoder
是一个轻量化的结构,便于交互式分割响应。其基本结构如图所示:
MaskDecoder结构示意图
首先,如前所述,图像被编码成为 大小的feature,每个feature维度为256。上面的各种prompt最终也被编码成为同样长度(256)的特征向量,作为token与图像进行attention的交互。
具体实现如下:
先考虑需要输出的mask和iou的形式,这里:
1 self.num_mask_tokens = num_multimask_outputs + 1
num_multimask_outputs
这里设定为3。由于SAM支持两种模式:multimask和直接输出单一mask,因此需要+1。除了mask的token之外,还需要预测每个mask的iou,因此还需要一个iou token,经过Transformer处理后的iou token接入一个MLP网络,网络输出维度等于mask个数,因此为每个mask分配了一个iou分数。mask token处理后与一个经过Transformer + upscaling(2x transposed conv实现)的处理后的image embedding进行逐像素点乘,最终得到对应的mask。
下面来看MaskDecoder
输入的形式,以及如何将prompt融合进来的。首先,mask token 和 iou token 也被设置为nn.Embedding
层的权重,即:
1 2 3 self.iou_token = nn.Embedding(1 , transformer_dim) self.num_mask_tokens = num_multimask_outputs + 1 self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
由于我们还有 PromptEncoder
编码来的稀疏和稠密prompt embeding,其中代表mask的稠密embedding比较简单,直接将其与图像embedding相加即可;对于稀疏prompt,将它们与前面的mask token和iou token在通道维度上拼接,得到输入Transformer的tokens:
1 2 3 output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0 ) output_tokens = output_tokens.unsqueeze(0 ).expand(sparse_prompt_embeddings.size(0 ), -1 , -1 ) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1 )
接下来,tokens
将与image_embeddings
(包括dense mask embedding 和 image pe)一起输入到一个TwoWayTransformer
,即图中的主体部分。该模块可以对tokens→image embedding和image embeddings → tokens两个方向进行cross attention操作,从而使得image embedding也接受tokens的修正。这个过程如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 def forward (self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor ) -> Tuple [Tensor, Tensor]:if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else : q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries) q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries) mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries) q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys) return queries, keys
可以看出,两次cross attention中,q和k互换,并且k和v来源相同。最开始是一个self-attention,第一个cross attention后还有一个MLP对attention后的token进行更新。
最后,mask的输出方式如下:
1 2 3 4 5 6 7 8 9 10 11 12 src = src.transpose(1 , 2 ).view(b, c, h, w) upscaled_embedding = self.output_upscaling(src) hyper_in_list: List [torch.Tensor] = [] for i in range (self.num_mask_tokens): hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) hyper_in = torch.stack(hyper_in_list, dim=1 ) b, c, h, w = upscaled_embedding.shape masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1 , h, w) iou_pred = self.iou_prediction_head(iou_token_out)
然后,可以根据是否设置输出 multimask,选择合适的输出mask:
1 2 3 4 5 6 7 if multimask_output: mask_slice = slice (1 , None ) else : mask_slice = slice (0 , 1 ) masks = masks[:, mask_slice, :, :] iou_pred = iou_pred[:, mask_slice]
Reference
Segment Anything
https://github.com/facebookresearch/segment-anything
https://segment-anything.com/