论文链接:arxiv.org/abs/2304.02…

代码链接:github.com/lllyasviel/…

Demo链接:segment-anything.com/demo

SAM从使命、模型、数据三部分展开写作,和模型的立异比较起来,使命界说和数据的作业愈加出彩,官网也给出了demo,能直观感受SAM的作用,这篇blog也会环绕这几部分展开。

demo

demo中有敞开point, box, everything三种办法。由于text prompt作用不太稳定,demo和代码中都没有该部分。

  • 鼠标悬停: 显示的是悬停位置的切割成果,例如下图中将鼠标放到手的位置.

    一篇文章搞懂Segment Anything(SAM)

  • 点击: 切割包括该点的物体,会按最小切割的成果展现出来,假如想切割的物体大于展现的成果,能够在物体的其他部分也点击下。

一篇文章搞懂Segment Anything(SAM)

一篇文章搞懂Segment Anything(SAM)

  • box: 框定一个box,切割box中的物体

    一篇文章搞懂Segment Anything(SAM)

  • everything: 将图片中全部物体的切割都展现出来

    一篇文章搞懂Segment Anything(SAM)

使命

一篇文章搞懂Segment Anything(SAM)

使命的规划灵感来自于NLP范畴,例如NLP中能够通过预测next token作为预练习使命,而在下流使命中能够运用prompt engineering做运用。因而,为了建立切割的根底模型,使命的规划方针是也需求具有相似的才能。 这儿作者扩展了下NLP里prompt在图画切割里的用法, prompt能够是以下几种类型:

  • point
  • box
  • mask
  • 恣意格式的文本

为了支撑以下的几种输入prompt格式,要求模型能够区分具有混杂含义的prompt,例如下图中,一个point的prompt可能有多种切割办法.这多种切割办法关于模型来说都是有用的。

一篇文章搞懂Segment Anything(SAM)

预练习: 将上面说到的多种sequence的prompt告知模型,练习方针是让模型输出对应promt的切割成果,而且希望模型输出的成果和GT尽可能共同。区别于之前的交互式切割算法,SAM基本能治通过一次交互就能得到很合理的切割成果。要到达这个意图,需求规划十分共同的模型结构和loss。

zero-shot transfer:需求模型对任何prompt,得到适宜的切割成果。例如,假如要做实例切割,能够把检测得到的box作为prompt,SAM就能去做实例切割

related tasks: 切割里有许多子使命,例如边际切割,语义切割等,SAM能完结全部已知的切割使命和还没有作为一个方向的切割使命。之前现已有相似的能够做多种切割的模型(solo), 可是这些模型有多个子子输出,然后做排列组合能够得到多种切割成果。而SAM通过prompt将多个切割使命合并在一起。

总而言之,作者是希望SAM能够切割全部,而且能相CLIP相同,能运用到最开端没有想到的范畴。

模型

一篇文章搞懂Segment Anything(SAM)

一篇文章搞懂Segment Anything(SAM)

模型的结构如上图所示. prompt会通过prompt encoder, 图画会通过image encoder。然后将两部分embedding通过一个轻量化mask decoder得到融合后的特征。encoder部分运用的都是已有模型,decoder运用transformer。这部分论文中介绍的相比照较少,下面会结合代码一起整理下:

  • image encoder: 运用的是用ViT走位backbone的MAE模型。在交互式切割的展现中,image encoder只会运行一次。在试验中,别离有用到ViT-H, ViT-L, ViT-B三种巨细的模型作为image encoder。代码如下,build_sam#L47
sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}
  • prompt encoder: prompt一共有point,box, mask, text四种,会将其分为三类。pint和box能够作为一类运用position encodings, text能够运用CLIP作为encoder, 而mask是一种密布型的prompt,能够运用卷积作为encoder.prompt_encoder.py#LL128C5-L128C5 prompt_encoder的代码如下所示,其间用position embedding别离完结了point和box query两种稀疏embedding,用卷积完结了mask query密布embedding.,
def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.
        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed
        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))     # position embedding
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)   # position embedding
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
        if masks is not None:
            dense_embeddings = self._embed_masks(masks)    # conv embedding
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        return sparse_embeddings, dense_embeddings
  • mask decoder:: 运用一个transformer将image embedding和prompt embedding做双向的cross-attention;而且也有prompt embedding的self-attention。也有MLP和linear classifier分类切割区域。mask decoder, transformer.py#L151这儿的queries是query embedding,keys是image embedding,query_pe和queries相同,key_pe是需求加到image embedding上的位置编码。query embedding会通过self attention。query embedding和image embedding会做双向的cross-attention, 具体完结办法是如上代码所示,image embedding会作为query,query embedding会作为key和value;相同的,query embedding会作为query,image embedding会作为key和value。
def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.
        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)
        # Prepare queries
        queries = point_embedding
        keys = image_embedding
        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )
        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)
        return queries, keys
  • 处理混杂的输入: 关于一个prompt,模型会输出3个mask,实际上也能够输出更多的切割成果,3个能够看作一个物体的整体、部分、子部分,基本能满足大多数状况。运用IOU的办法,排序mask。在反向传达时,参与核算的只要loss最小的mask相关的参数.

  • 高效: 这儿首要指的是prompt encodermask decoder。在web浏览器上,CPU核算只用约50ms

  • loss和练习细节: 首要运用的是focal loss和dice loss。每一个mask,会随机发生11种prompt与之配对。

数据

数据引擎

不像CLIP中图画文本对通过互联网容易获取,切割的数据获取本钱巨大。SAM开源了一个10亿张图片的切割数据集。在SAM中规划了一个数据引擎用于获取切割的数据,数据引擎首要分为以下三部分:

  • 辅佐标示: 简略来说便是用能够获取到的开源切割数据练习一个初始的SAM模型V0版别,再用V0在没有切割标示的数据上生成预标示,人工check模型的成果并作修改和承认。得到新的数据后,再将新的数据参加到练习集从头练习SAM得到V1版别,再循环标示数据和迭代模型。一共进行6次练习。开端的时分数据集比较少,运用的ViT-B模型,最终会运用ViT-H模型。 这儿面还有一些功率提高的数据,例如跟着模型的迭代,每个mask的标示耗时从34s到14s。SAM模型在每张图片上生成的mask从20到44个。在该阶段数据集最终有12万张图片,430万个mask

  • 半主动化标示: 通过第一阶段后,现已有一个不错的SAM模型能生成切割成果。半主动化标示的意图是增加mask的多样性。具体做法是练习一个检测模型,用于检测SAM生成的mask成果是否可信,只保留可信的mask成果,然后将图片给人工标示。人工标示会在可信的mask根底上标示出其他的切割框。通过5次的迭代后,数据集新增了18万张图片,590万mask。

主动标示: 通过前面两个阶段后,SAM有了较好的成果,能切割出图片中的方针,而且关于混杂的prompt也有了较好的输出。这个模型能够主动的对一些图片做标示。主动标示的时分需求有一些挑选策略,模型输出的成果可能还是会呈现一些过错。首要有以下三种办法做挑选

  • SAM模型有一个IOU prediction的模块能输出mask的confidence,如下图所示

    一篇文章搞懂Segment Anything(SAM)

  • stable mask的判断,具体的办法是在得到切割成果前对logit加正向和负向的扰动,假如两次扰动生成的切割成果IOU大于0.95,则以为生成的mask是可靠的

  • NMS过滤掉重复的mask

一篇文章搞懂Segment Anything(SAM)

数据质量

图画: 包括11M高分辨率(33004950)的图画,其他的一些开源数据集,例如COCO分辨率较低(480640) Mask: 包括1.1B的mask,99.1%都是模型生成的。作者试验了下,只运用模型生成的mask和即运用模型生成也运用人工标示的mask,模型的作用是适当的。因而发布的数据集里只包括模型生成的mask Mask 质量: 抽取了一部分mask数据做人工的精标,精标前后有94%的mask具有90%以上的IOU。而其他的一些开源数据集只要85-91%的IOU

下面也从mask的数量,每种mask尺寸的占比和mask占外接矩形比例等多方面和其他数据集做了比照

一篇文章搞懂Segment Anything(SAM)

数据来源分布

一篇文章搞懂Segment Anything(SAM)

一篇文章搞懂Segment Anything(SAM)

不同性别,肤色,年龄人群切割作用的差异比照

一篇文章搞懂Segment Anything(SAM)

zero-short Transfer试验

评价的数据集都是SAM模型练习时的不同,而且包括水下,第一视角等没有在SAM中呈现过场景的图片

point mask

这儿比照的是用point作为prompt比照切割的成果, 在绝大部分数据集中都优于RITM(当前的SOTA)

一篇文章搞懂Segment Anything(SAM)

边际检测

SAM在练习的时分便是选用的包括point prompt的办法,作者这儿还比照了一些在练习时没有包括的办法,边界检测便是其间一种。SAM在运用边界检测时,运用办法是在图片上铺上16*16均匀的point prompt,每个prompt发生3个mask,再通过NMS后。通过Sobel filtering得到边际检测的成果。SAM的成果倾向于提取更丰厚的边际,因而在方针上recall和专门做边际检测的模型适当,precision会低些。

一篇文章搞懂Segment Anything(SAM)

一篇文章搞懂Segment Anything(SAM)

方针检测

切割的成果取bbox,就能做方针检测了.整体方针低于ViTDet,可是在中等常见和不太常见的方针上作用优于ViTDet

一篇文章搞懂Segment Anything(SAM)

实例切割

先用一个方针检测算法,用方针检测得到的box作为prompt输入到SAM,就能够做实例切割了。试验的成果分为了定量(用测试集的GT)和定性(人来评判好坏)两种。定量的方针不如BiTDet—H,定性的方针SAM优于ViTDet。作者给出的解说是COCO数据集标示作用一般(在人看来乃至不如SAM和ViTDet模型输出的成果),因而ViTDet在COCO上做练习时拟合到了一些过错的偏差,但过错的偏差和标示相似,因而定量的方针不如ViTDet

一篇文章搞懂Segment Anything(SAM)

一篇文章搞懂Segment Anything(SAM)

Text to Mask

这儿指的是用文本作为prompt,然后切割出文本说到的方针。作者在练习的时分取的是图片中方针尺寸大于100*100的方针,用CLIP提取image embedding(text embedding也行,由于CLIP的image embedding和text embedding是对齐的),作为prompt encoder模块的输出,用于练习SAM模型。这一部分没有和其他办法比照,也由于作用不太稳定,在官方的demo中没有展现

一篇文章搞懂Segment Anything(SAM)

消融试验

一篇文章搞懂Segment Anything(SAM)

有以下的定论: 左面的图,数据来源的影响:

  • 参加半主动标示的数据和主动标示的数据功能都有很大的提高
  • 只用模型生成的数据与额外加上人工标示的数据差异不大

中心的图, 数据量的影响:

  • 数据量从0.1M到1M,模型功能提高很大
  • 数据量从1M到11M,模型功能变化不明显,实际运用中1M差不多满足

右边的图, image encoder的影响:

  • ViT-B到ViT-L提高很大
  • ViT-L到ViT-H提高一般,实际运用ViT-L满足

总结

SAM的热度也十分高,相同作为FB的作业,SAM只是放出来两个月,github上star的数量现已超过了detectron2三年的总和。SAM的希望是能将该模型作为图画范畴的根底模型(foundation model),像CLIP那样能在各个范畴大放反常,或许像GPT相同能一致NLP范畴。SAM也确真实许多场景得到了运用,例如开源的SD中也融合了SAM,能够做许多风趣的运用,例如从假人模特身上用SAM得到衣服的mask,再结合ControlNet,就能够生成不同的人穿戴相同的衣服。

最开端自媒体宣扬的文章也是《CV范畴不存在了》,《CV界的GPT3》相似的标题,SAM确实是在一致上迈出了很大的一步,但实际上CV范畴的一致还有许多挑战。NLP范畴中的Bert用完型填空和GPT预测下一个token的预练习在十分多的使命上表现了很好的泛化性,乃至在一些没有练习过的使命上能取得比一些专家模型更好的作用。

  • 使命和数据上的不一致,CV范畴的分类是输出类别,检测输出bbox,切割输出mask。尽管单个使命能够复用,可是整体缺少一个通用的使命。使命上的不一致,数据上也很难做到一致,分类的使命有许多数据,可是检测和切割的数据就要少十分多,而且标示本钱巨大。单纯练习分类作为backbone也很难处理其他使命,检测和切割的算法依然需求做很多的优化

  • CV范畴的使命缺少孕育大模型的土壤,CV使命一直在考虑模型的核算量,显存占用。假如将每个像素看作一个token,一张512*512的图片就有26万个token。假如transformer最开端呈现在CV范畴,面对的问题是显存和核算量都比resnet差,而且作用也远不如resnet。假如没有transformer极大的促进了NLP范畴的开展,CV范畴可能也不会从头思考transfomer能增大感受野,能有更好的泛化才能。

  • 还没有找到CV范畴【高维】的使命。NLP范畴的完形填空和对话确实是一种很高维的使命。模型能完结这些使命,一些NER或许RE之类的底层使命也能很好的被处理。现在CV范畴有一些尝试做foundation model,例如比照学习或许像SAM,在一些使命上表现了不错的泛化性,可能是这些办法能一致其他使命,但现在的开展还不太够,也可能是其他一些还没呈现/开展起来的使命。但这种【高维】的使命一定能通过一些办法降维处理现在简直全部的CV根底使命。