SAM2 Code Implementation Explained: ImageEncoder Part (FpnNeck Chapter)
Posted on 2024-9-10 10:02:35
Overall configuration: the YAML file, OmegaConf and Hydra

The official SAM2 implementation uses YAML files to configure the overall model structure and its parameters. The key code is:

```python
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf


def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
):
    if apply_postprocessing:
        # copy first so the caller's list (and the shared default) is never mutated
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)  # helper defined next to build_sam2 in sam2/build_sam.py
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model
```

Everything up to and including the compose call configures model parameters. Both compose and instantiate are functions from the Hydra library. Hydra is an open-source Python framework, also developed by Meta, that simplifies the development of research code and other complex applications. Its main feature is the ability to build a hierarchical configuration dynamically by composition and to override it through config files and the command line. Hydra reads and writes YAML files through the OmegaConf library.

Back to our code: compose reads the YAML file named by config_name, produces a Python object that can be accessed much like a dict, and overrides part of the loaded values with the entries in overrides (in Hydra's override syntax, the ++ prefix sets a key, adding it if it does not exist yet). instantiate then actually builds the network from the configuration in the YAML file. This is hard to follow from prose alone, so here is an example.

Example YAML file:

```yaml
optimizer:
  _target_: my_app.Optimizer
  algo: SGD
  lr: 0.01
```

Example class file (module my_app):

```python
class Optimizer:
    algo: str
    lr: float

    def __init__(self, algo: str, lr: float) -> None:
        self.algo = algo
        self.lr = lr

    def __repr__(self) -> str:
        # produces the printed form shown in the instantiation example
        return f"Optimizer(algo={self.algo},lr={self.lr})"
```

Example instantiation:

```python
from hydra.utils import instantiate

# cfg is the config loaded from the example YAML above
opt = instantiate(cfg.optimizer)
print(opt)
# Optimizer(algo=SGD,lr=0.01)

# override parameters on the call-site
opt = instantiate(cfg.optimizer, lr=0.2)
print(opt)
# Optimizer(algo=SGD,lr=0.2)
```

Now let's look at the concrete SAM2 structure (using the tiny version as an example):

```yaml
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 7, 2]
      global_att_blocks: [5, 7, 9]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [768, 384, 192, 96]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  # SAM decoder
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: false
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  # HieraT does not currently support compilation, should always be set to False
  compile_image_encoder: False
```

As covered in our earlier SAM2 article, the SAM2 model is built from image_encoder, memory_attention and memory_encoder (see the keys of the same names under model in the YAML).

ImageEncoder

The YAML file shows clearly that ImageEncoder consists of two parts: a Hiera model as the trunk and an FpnNeck as the neck. Hiera is a hierarchical vision transformer pretrained as a masked autoencoder (MAE), proposed in "Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles", ICML 2023. The Hiera encoder extracts multi-scale features, and a feature pyramid network (FPN, here FpnNeck) fuses them. Next, let's look at the ImageEncoder code:

```python
import torch
from torch import nn


class ImageEncoder(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        super().__init__()
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        # Forward through backbone
        features, pos = self.neck(self.trunk(sample))
        if self.scalp > 0:
            # Discard the lowest resolution features
            features, pos = features[: -self.scalp], pos[: -self.scalp]

        src = features[-1]
        output = {
            "vision_features": src,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
        return output
```

The key line is features, pos = self.neck(self.trunk(sample)): inside ImageEncoder the sample passes through the trunk first and then through the neck, i.e. it is processed by Hiera and the result is then processed by FpnNeck. The feature lists are ordered from high to low resolution, so features[: -self.scalp] drops the scalp lowest-resolution levels; with scalp: 1 in the tiny config, vision_features is the last level that remains.

FPN is a fairly old technique in computer vision and, as its name suggests, it is easy to follow, so I will only explain it briefly. For example, the position_encoding in the module does not transform x at all; it only uses the shape of x to produce pos.

Neck: FpnNeck

```python
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import nn


class FpnNeck(nn.Module):
    """
    Values from the YAML config:
    d_model=256, backbone_channel_list=[768, 384, 192, 96],
    fpn_top_down_levels=[2, 3], fpn_interp_model=nearest
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        for dim in backbone_channel_list:
            current = nn.Sequential()
            current.add_module(
                # 1x1 conv of the lateral (skip) connection: dim -> d_model channels
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )
            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        # levels to have top-down features in its outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        out = [None] * len(self.convs)
        pos = [None] * len(self.convs)
        assert len(xs) == len(self.convs)
        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        prev_features = None
        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            x = xs[i]
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)

        return out, pos
```
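To trace the top-down loop above on concrete shapes, here is a small NumPy re-creation. It is a sketch, not the SAM2 code: the 1×1 convolutions become channel matmuls with random weights, nearest 2× upsampling becomes np.repeat, and the positional encoding is left out. The helper names (conv1x1, upsample_nearest_2x, fpn_neck) are hypothetical, and the input is assumed to be 256×256 instead of 1024×1024 to keep the arrays small, so the four Hiera levels at strides 4, 8, 16 and 32 are assumed to be 64, 32, 16 and 8 pixels wide.

```python
import numpy as np

RNG = np.random.default_rng(0)
D_MODEL = 256
BACKBONE_CHANNEL_LIST = [768, 384, 192, 96]  # low -> high resolution, as in the YAML
TOP_DOWN_LEVELS = [2, 3]


def conv1x1(x, weight):
    """1x1 convolution of a (B, C_in, H, W) array: a matmul over the channel axis."""
    b, c, h, w = x.shape
    return np.matmul(weight, x.reshape(b, c, h * w)).reshape(b, -1, h, w)


def upsample_nearest_2x(x):
    """Nearest-neighbour 2x upsampling of a (B, C, H, W) array."""
    return x.repeat(2, axis=2).repeat(2, axis=3)


def fpn_neck(xs, weights, top_down_levels, fuse_type="sum"):
    """Hypothetical NumPy version of the FpnNeck top-down loop (no positional encoding)."""
    n = len(xs) - 1
    out = [None] * len(xs)
    prev = None
    for i in range(n, -1, -1):  # from the lowest to the highest resolution
        lateral = conv1x1(xs[i], weights[n - i])
        if i in top_down_levels and prev is not None:
            prev = lateral + upsample_nearest_2x(prev)
            if fuse_type == "avg":
                prev = prev / 2
        else:
            prev = lateral
        out[i] = prev
    return out


# Backbone features ordered high -> low resolution, batch size 1.
sizes = [64, 32, 16, 8]
channels = BACKBONE_CHANNEL_LIST[::-1]  # [96, 192, 384, 768]
xs = [RNG.standard_normal((1, c, s, s)).astype(np.float32) for c, s in zip(channels, sizes)]
# weights[k] projects BACKBONE_CHANNEL_LIST[k] channels to D_MODEL, like self.convs[k]
weights = [(RNG.standard_normal((D_MODEL, c)) * 0.01).astype(np.float32) for c in BACKBONE_CHANNEL_LIST]

out = fpn_neck(xs, weights, TOP_DOWN_LEVELS)
for i, o in enumerate(out):
    print(i, o.shape)
# 0 (1, 256, 64, 64)
# 1 (1, 256, 32, 32)
# 2 (1, 256, 16, 16)
# 3 (1, 256, 8, 8)

# ImageEncoder with scalp=1 then drops the lowest-resolution level.
scalp = 1
features = out[:-scalp]
vision_features = features[-1]
print(vision_features.shape)  # (1, 256, 16, 16)
```

Level 2 receives the upsampled level-3 map, while levels 0 and 1 are pure lateral projections of their own backbone level, which is what fpn_top_down_levels = [2, 3] asks for. The last three lines imitate the scalp step of ImageEncoder.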
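F.interpolate performs the upsampling, and the 1×1 convolutions map every feature level to the same dimension, d_model. The data flow matches the figure above. The condition if i in self.fpn_top_down_levels and prev_features is not None shows that the model applies top-down FPN fusion only to the levels listed in fpn_top_down_levels; with [2, 3] in the tiny config, levels 0 and 1 keep only their lateral (1×1-projected) backbone features. The output is a tuple (out, pos). Looking at out first: it is a list of tensors, and out[i] has shape (B, d_model, H_i, W_i), where H_i × W_i is the spatial size of the input level xs[i]. pos has the same layout and holds the positional encoding of each out[i].

The position_encoding instance passed into FpnNeck is an object of this class:

```python
from typing import Optional

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
    ):
        # Body elided in this walkthrough. It stores self.num_pos_feats = num_pos_feats // 2
        # (channels per axis), self.temperature, self.normalize, self.scale (2 * math.pi
        # when None) and an empty dict self.cache.
        ...

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        # The encoding depends only on the spatial size, so it is cached per (H, W).
        cache_key = (x.shape[-2], x.shape[-1])
        if cache_key in self.cache:
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
        # 1-based row and column indices, each broadcast to (B, H, W)
        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )
        if self.normalize:
            # rescale positions to (0, scale]
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # dim_t[j] = temperature ** (2 * (j // 2) / d): channels 2i and 2i+1 share one frequency
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # sin on even channels, cos on odd channels, interleaved back into d channels
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        # concatenate y first, then x, and move channels first: (B, 2d, H, W)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        self.cache[cache_key] = pos[0]
        return pos
```

As its official docstring says, it is very similar to the positional encoding in Attention Is All You Need:

$$p_{k,2i}=\sin\left(\frac{k}{10000^{2i/d}}\right),\qquad p_{k,2i+1}=\cos\left(\frac{k}{10000^{2i/d}}\right)$$

The two dim_t lines compute the denominator $10000^{2i/d}$, where d is self.num_pos_feats, i.e. half of the configured num_pos_feats. The pos_x = x_embed[...] / dim_t and pos_y = y_embed[...] / dim_t lines compute $\frac{k}{10000^{2i/d}}$ for pos_x and pos_y respectively, and the two torch.stack(...).flatten(3) statements complete the positional encoding of pos_x and pos_y by taking sin on the even channels and cos on the odd ones.

Note: similar, not identical. The code encodes the two axes separately, each with d channels, and concatenates them. For pos_x:

$$p^{x}_{x,y,2i}=\sin\left(\frac{x}{10000^{2i/d}}\right),\qquad p^{x}_{x,y,2i+1}=\cos\left(\frac{x}{10000^{2i/d}}\right)$$

For pos_y:

$$p^{y}_{x,y,2i}=\sin\left(\frac{y}{10000^{2i/d}}\right),\qquad p^{y}_{x,y,2i+1}=\cos\left(\frac{y}{10000^{2i/d}}\right)$$

Here x and y are the 1-based column and row indices (rescaled to (0, 2π] when normalize is true and scale is null, as in this config), and the final encoding is the channel-wise concatenation $[p^{y}, p^{x}]$ with $2d$ = num_pos_feats channels.

Closing thoughts

I wonder whether video would be a better format for code-walkthrough posts 🤔. If you have suggestions about the format or style of these articles, or questions about the content, feel free to leave a comment 😁. If there are papers or code you would like me to read and share, comments and discussion are welcome too!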
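As a cross-check of the PositionEmbeddingSine formulas above, here is a NumPy sketch of what its forward pass computes for a single H×W grid, with the batch dimension and the cache dropped. The function name sine_position_encoding is hypothetical and this is not the SAM2 code; it assumes, as described above, that d = num_pos_feats // 2 channels encode each axis and that normalize rescales the 1-based positions to (0, 2π].

```python
import math

import numpy as np


def sine_position_encoding(h, w, num_pos_feats=256, temperature=10000, normalize=True, scale=None):
    """Hypothetical NumPy sketch of PositionEmbeddingSine for one (H, W) grid.

    Returns an array of shape (num_pos_feats, H, W): the first half encodes y,
    the second half encodes x, mirroring torch.cat((pos_y, pos_x), dim=3).
    """
    d = num_pos_feats // 2  # channels per axis
    scale = 2 * math.pi if scale is None else scale
    y_embed = np.arange(1, h + 1, dtype=np.float32)[:, None].repeat(w, axis=1)  # (H, W), row index
    x_embed = np.arange(1, w + 1, dtype=np.float32)[None, :].repeat(h, axis=0)  # (H, W), column index
    if normalize:
        eps = 1e-6
        y_embed = y_embed / (y_embed[-1:, :] + eps) * scale
        x_embed = x_embed / (x_embed[:, -1:] + eps) * scale
    dim_t = np.arange(d, dtype=np.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / d)  # 10000^(2i/d), shared by channels 2i and 2i+1
    pos_x = x_embed[:, :, None] / dim_t  # (H, W, d)
    pos_y = y_embed[:, :, None] / dim_t
    # sin on even channels, cos on odd channels; stack + reshape interleaves them again
    pos_x = np.stack((np.sin(pos_x[:, :, 0::2]), np.cos(pos_x[:, :, 1::2])), axis=3).reshape(h, w, d)
    pos_y = np.stack((np.sin(pos_y[:, :, 0::2]), np.cos(pos_y[:, :, 1::2])), axis=3).reshape(h, w, d)
    return np.concatenate((pos_y, pos_x), axis=2).transpose(2, 0, 1)


pe = sine_position_encoding(64, 64)
print(pe.shape)  # (256, 64, 64)
```

With normalize=False the channels follow the formulas literally: for example, channel 2 of the x half at column index x equals sin(x / 10000^(2/d)).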