找回密码
 会员注册
查看: 26|回复: 0

YOLOv8改进:添加Swin Transformer

[复制链接]

2万

主题

0

回帖

6万

积分

超级版主

积分
69864
发表于 2024-9-10 14:57:11 | 显示全部楼层 |阅读模式
最近在做实验,需要改进YOLOv8,在网上找了很多教程,但大多需要付费才能查看,这对一个一餐只能吃两个菜的大学生来说实在太痛苦了,所以我自己找代码手动改了一下,成功在YOLOv8中添加了Swin Transformer模块(C3STR)。本人水平有限,如有错误之处请自行改正。第一步,在ultralytics\nn\modules\block.py文件的末尾添加Swin Transformer相关代码,代码如下:#----------swintf----C3STR---------------------------------classSwinTransformerBlock(nn.Module):def__init__(self,c1,c2,num_heads,num_layers,window_size=8):super().__init__()self.conv=Noneifc1!=c2:self.conv=Conv(c1,c2)#removeinput_resolutionself.blocks=nn.Sequential(*[SwinTransformerLayer(dim=c2,num_heads=num_heads,window_size=window_size,shift_size=0if(i%2==0)elsewindow_size//2)foriinrange(num_layers)])defforward(self,x):ifself.convisnotNone:x=self.conv(x)x=self.blocks(x)returnxclassWindowAttention(nn.Module):def__init__(self,dim,window_size,num_heads,qkv_bias=True,qk_scale=None,attn_drop=0.,proj_drop=0.):super().__init__()self.dim=dimself.window_size=window_size#Wh,Wwself.num_heads=num_headshead_dim=dim//num_headsself.scale=qk_scaleorhead_dim**-0.5#defineaparametertableofrelativepositionbiasself.relative_position_bias_table=nn.Parameter(torch.zeros((2*window_size[0]-1)*(2*window_size[1]-1),num_heads))#2*Wh-1*2*Ww-1,nH#getpair-wiserelativepositionindexforeachtokeninsidethewindowcoords_h=torch.arange(self.window_size[0])coords_w=torch.arange(self.window_size[1])coords=torch.stack(torch.meshgrid([coords_h,coords_w]))#2,Wh,Wwcoords_flatten=torch.flatten(coords,1)#2,Wh*Wwrelative_coords=coords_flatten[:,:,None]-coords_flatten[:,None,:]#2,Wh*Ww,Wh*Wwrelative_coords=relative_coords.permute(1,2,0).contiguous()#Wh*Ww,Wh*Ww,2relative_coords[:,:,0]+=self.window_size[0]-1#shifttostartfrom0relative_coords[:,:,1]+=self.window_size[1]-1relative_coords[:,:,0]*=2*self.window_size[1]-1relative_position_index=relative_coords.sum(-1)#Wh*Ww,Wh*Wwself.register_buffer("relative_position_index",relative_position_index)self.qkv=nn.Linear(dim,dim*3,bias=qkv_bias)self.attn_drop=nn.Dropout(attn_drop)self.proj=nn.Linear(dim,dim)self.proj_drop=nn.Dropout(proj_drop)nn.init.normal_(self.relative_posit
ion_bias_table,std=.02)self.softmax=nn.Softmax(dim=-1)defforward(self,x,mask=None):B_,N,C=x.shapeqkv=self.qkv(x).reshape(B_,N,3,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)q,k,v=qkv[0],qkv[1],qkv[2]#maketorchscripthappy(cannotusetensorastuple)q=q*self.scaleattn=(q@k.transpose(-2,-1))relative_position_bias=self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0]*self.window_size[1],self.window_size[0]*self.window_size[1],-1)#Wh*Ww,Wh*Ww,nHrelative_position_bias=relative_position_bias.permute(2,0,1).contiguous()#nH,Wh*Ww,Wh*Wwattn=attn+relative_position_bias.unsqueeze(0)ifmaskisnotNone:nW=mask.shape[0]attn=attn.view(B_//nW,nW,self.num_heads,N,N)+mask.unsqueeze(1).unsqueeze(0)attn=attn.view(-1,self.num_heads,N,N)attn=self.softmax(attn)else:attn=self.softmax(attn)attn=self.attn_drop(attn)#print(attn.dtype,v.dtype)try:x=(attn@v).transpose(1,2).reshape(B_,N,C)except:#print(attn.dtype,v.dtype)x=(attn.half()@v).transpose(1,2).reshape(B_,N,C)x=self.proj(x)x=self.proj_drop(x)returnxclassMlp(nn.Module):def__init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.SiLU,drop=0.):super().__init__()out_features=out_featuresorin_featureshidden_features=hidden_featuresorin_featuresself.fc1=nn.Linear(in_features,hidden_features)self.act=act_layer()self.fc2=nn.Linear(hidden_features,out_features)self.drop=nn.Dropout(drop)defforward(self,x):x=self.fc1(x)x=self.act(x)x=self.drop(x)x=self.fc2(x)x=self.drop(x)returnxclassSwinTransformerLayer(nn.Module):def__init__(self,dim,num_heads,window_size=8,shift_size=0,mlp_ratio=4.,qkv_bias=True,qk_scale=None,drop=0.,attn_drop=0.,drop_path=0.,act_layer=nn.SiLU,norm_layer=nn.LayerNorm):super().__init__()self.dim=dimself.num_heads=num_headsself.window_size=window_sizeself.shift_size=shift_sizeself.mlp_ratio=mlp_ratio#ifmin(self.input_resolution)0.elsenn.Identity()self.norm2=norm_layer(dim)mlp_hidden_dim=int(dim*mlp_ratio)self.mlp=Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act
_layer=act_layer,drop=drop)defcreate_mask(self,H,W):#calculateattentionmaskforSW-MSAimg_mask=torch.zeros((1,H,W,1))#1HW1h_slices=(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))w_slices=(slice(0,-self.window_size),slice(-self.window_size,-self.shift_size),slice(-self.shift_size,None))cnt=0forhinh_slices:forwinw_slices:img_mask[:,h,w,:]=cntcnt+=1defwindow_partition(x,window_size):"""Args:xB,H,W,C)window_size(int):windowsizeReturns:windowsnum_windows*B,window_size,window_size,C)"""B,H,W,C=x.shapex=x.view(B,H//window_size,window_size,W//window_size,window_size,C)windows=x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)returnwindowsdefwindow_reverse(windows,window_size,H,W):"""Args:windowsnum_windows*B,window_size,window_size,C)window_size(int):WindowsizeH(int):HeightofimageW(int):WidthofimageReturns:xB,H,W,C)"""B=int(windows.shape[0]/(H*W/window_size/window_size))x=windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)x=x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)returnxmask_windows=window_partition(img_mask,self.window_size)#nW,window_size,window_size,1mask_windows=mask_windows.view(-1,self.window_size*self.window_size)attn_mask=mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)attn_mask=attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))returnattn_maskdefforward(self,x):#reshapex[bchw]tox[blc]_,_,H_,W_=x.shapePadding=Falseifmin(H_,W_)0:attn_mask=self.create_mask(H,W).to(x.device)else:attn_mask=Noneshortcut=xx=self.norm1(x)x=x.view(B,H,W,C)#cyclicshiftifself.shift_size>0:shifted_x=torch.roll(x,shifts=(-self.shift_size,-self.shift_size),dims=(1,2))else:shifted_x=xdefwindow_partition(x,window_size):"""Args:xB,H,W,C)window_size(int):windowsizeReturns:windowsnum_windows*B,window_size,window_size,C)"""B,H,W,C=x.shapex=x.view(B,H//window_size,window_size,W//window_size,window_size,C)windows=x.permute(0,1,3,2,4,5).contiguous().view(-1,window_siz
e,window_size,C)returnwindowsdefwindow_reverse(windows,window_size,H,W):"""Args:windowsnum_windows*B,window_size,window_size,C)window_size(int):WindowsizeH(int):HeightofimageW(int):WidthofimageReturns:xB,H,W,C)"""B=int(windows.shape[0]/(H*W/window_size/window_size))x=windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)x=x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)returnx#partitionwindowsx_windows=window_partition(shifted_x,self.window_size)#nW*B,window_size,window_size,Cx_windows=x_windows.view(-1,self.window_size*self.window_size,C)#nW*B,window_size*window_size,C#W-MSA/SW-MSAattn_windows=self.attn(x_windows,mask=attn_mask)#nW*B,window_size*window_size,C#mergewindowsattn_windows=attn_windows.view(-1,self.window_size,self.window_size,C)shifted_x=window_reverse(attn_windows,self.window_size,H,W)#BH'W'C#reversecyclicshiftifself.shift_size>0:x=torch.roll(shifted_x,shifts=(self.shift_size,self.shift_size),dims=(1,2))else:x=shifted_xx=x.view(B,H*W,C)#FFNx=shortcut+self.drop_path(x)x=x+self.drop_path(self.mlp(self.norm2(x)))x=x.permute(0,2,1).contiguous().view(-1,C,H,W)#bchwifPadding:x=x[:,:,:H_,:W_]#reversepaddingreturnxclassC3STR(C3):#C3modulewithSwinTransformerBlock()def__init__(self,c1,c2,n=1,shortcut=True,g=1,e=0.5):super().__init__(c1,c2,n,shortcut,g,e)c_=int(c2*e)num_heads=c_//32self.m=SwinTransformerBlock(c_,c_,num_heads,n)1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452
46247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288在开头引入代码:fromtimm.models.layersimportDropPath,to_2tuple,trunc_normal_一般是没有的,先在自己的环境下输入:pipinstalltimm-ihttps://mirrors.bfsu.edu.cn/pypi/web/simple下载第二步:ultralytics\nn\modules_init_.py中添加C3STR,首先在from.blockimport中添加如下:from.blockimport(C1,C2,C3,C3TR,DFL,SPP,SPPF,Bottleneck,BottleneckCSP,C2f,C2fAttn,ImagePoolingAttn,C3Ghost,C3x,GhostBottleneck,HGBlock,HGStem,Proto,RepC3,ResNetLayer,ContrastiveHead,BNContrastiveHead,RepNCSPELAN4,ADown,SPPELAN,CBFuse,CBLinear,Silence,C3STR,#添加swin_transfomer)1234567891011121314151617181920212223242526272829303132添加后如下图:在__all__中添加如下代码:__all__=("Conv","Conv2","LightConv","RepConv","DWConv","DWConvTranspose2d","ConvTranspose","Focus","GhostConv","ChannelAttention","SpatialAttention","CBAM","Concat","TransformerLayer","TransformerBlock","MLPBlock","LayerNorm2d","DFL","HGBlock","HGStem","SPP","SPPF","C1","C2","C3","C2f","C2fAttn","C3x","C3TR","C3Ghost","GhostBottleneck","Bottleneck","BottleneckCSP","Proto","Detect","Segment","Pose","Classify","TransformerEncoderLayer","RepC3","RTDETRDecoder","AIFI","DeformableTransformerDecoder","DeformableTransformerDecoderLayer","MSDeformAttn","MLP","ResNetLayer","OBB","WorldDetect","ImagePoolingAttn","ContrastiveHead","BNContrastiveHead","RepNCSPELAN4","ADown","SPPELAN","CBFuse","CBLinear","Silence","GAMAttention",#修改添加GAM"C3STR",#添加swin_transfomer)1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465添加后如下图:第三步,在ultralytics\nn\tasks.py中添加C3STR,首先在fromultralytics.nn.modulesimport处添加,效果如下:fromultralytics.nn.modulesimport(AIFI,C1,C2,C3,C3TR,OBB,SPP,SPPF,Bottleneck,BottleneckCSP,C2f,C2fAttn,ImagePoolingAttn,C3Ghost,C3x,Classify,Concat,Conv,Conv2,ConvTranspose,Detect,DWConv,DWConvTranspose2d,Focus,GhostBottleneck,GhostConv,HGBlock,HGStem,Pose,RepC3,RepConv,ResNetLayer,RTDETRDecoder,Segment,WorldDetect,RepNCSPELAN4,ADow
n,SPPELAN,CBFuse,CBLinear,Silence,GAMAttention,C3STR,)注意:如果你没有添加过GAMAttention模块,请从上面的导入中去掉GAMAttention,否则会报ImportError。添加后如下图:其次:按Ctrl+F搜索 if m in ,在如下图所示位置(该条件判断的模块集合中)加入C3STR。第四步,在ultralytics\cfg\models\v8中复制yolov8.yaml文件,重命名为yolov8-Swin_transformer.yaml(与下文train.py中加载的文件名一致),并修改如下:#UltralyticsYOLO🚀,AGPL-3.0license#YOLOv8objectdetectionmodelwithP3-P5outputs.ForUsageexamplesseehttps://docs.ultralytics.com/tasks/detect#Parametersnc:1#numberofclassesscales:#modelcompoundscalingconstants,i.e.'model=yolov8n.yaml'willcallyolov8.yamlwithscale'n'#[depth,width,max_channels]n:[0.33,0.25,1024]#YOLOv8nsummary:225layers,3157200parameters,3157184gradients,8.9GFLOPss:[0.33,0.50,1024]#YOLOv8ssummary:225layers,11166560parameters,11166544gradients,28.8GFLOPsm:[0.67,0.75,768]#YOLOv8msummary:295layers,25902640parameters,25902624gradients,79.3GFLOPsl:[1.00,1.00,512]#YOLOv8lsummary:365layers,43691520parameters,43691504gradients,165.7GFLOPsx:[1.00,1.25,512]#YOLOv8xsummary:365layers,68229648parameters,68229632gradients,258.5GFLOPs#YOLOv8.0nbackbonebackbone:#[from,repeats,module,args]-[-1,1,Conv,[64,3,2]]#0-P1/2-[-1,1,Conv,[128,3,2]]#1-P2/4-[-1,3,C2f,[128,True]]-[-1,1,Conv,[256,3,2]]#3-P3/8-[-1,6,C2f,[256,True]]-[-1,1,Conv,[512,3,2]]#5-P4/16-[-1,6,C2f,[512,True]]-[-1,1,Conv,[1024,3,2]]#7-P5/32-[-1,3,C2f,[1024,True]]-[-1,3,C3STR,[1024]]-[-1,1,SPPF,[1024,5]]#10#YOLOv8.0nheadhead:-[-1,1,nn.Upsample,[None,2,"nearest"]]-[[-1,6],1,Concat,[1]]#catbackboneP4-[-1,3,C2f,[512]]#13-[-1,1,nn.Upsample,[None,2,"nearest"]]-[[-1,4],1,Concat,[1]]#catbackboneP3-[-1,3,C2f,[256]]#16(P3/8-small)-[-1,1,Conv,[256,3,2]]-[[-1,12],1,Concat,[1]]#catheadP4-[-1,3,C2f,[512]]#19(P4/16-medium)-[-1,1,Conv,[512,3,2]]-[[-1,9],1,Concat,[1]]#catheadP5-[-1,3,C2f,[1024]]#22(P5/32-large)-[[16,19,22],1,Detect,[nc]]#Detect(P3,P4,P5)注意:由于在第9层插入了C3STR,SPPF变为第10层,head部分的层号整体后移了一位(Detect已相应改为[16,19,22])。因此标注"#catheadP4"的Concat应改为[[-1,13],1,Concat,[1]](指向第13层C2f),标注"#catheadP5"的Concat应改为[[-1,10],1,Concat,[1]](指向第10层SPPF);上面沿用的12和9实际分别指向了Concat层和C3STR层,与原YOLOv8的结构不一致。在本项目中创建一个train.py文件,大致代码为:fromultralyticsimportYOLO#importos#os.environ['CUDA_VISIBLE_DEVICES']='1'if__name__=='__main__':#Loadamodelmodel=YOLO(r'E:\yolov8_car_deepsort\y
olov8_car\ultralytics\cfg\models\v8\yolov8-Swin_transformer.yaml')#不使用预训练权重训练#model=YOLO(r'yolov8p.yaml').load("yolov8n.pt")#使用预训练权重训练#Trainparameters----------------------------------------------------------------------------------------------model.train(data=r'E:\yolov8_car_deepsort\yolov8_car\ultralytics\cfg\datasets\mycar.yaml',epochs=300,#(int)numberofepochstotrainforbatch=16,#(int)numberofimagesperbatch(-1forAutoBatch)imgsz=640,#(int)sizeofinputimagesasintegerorw,hsave=True,#(bool)savetraincheckpointsandpredictresultssave_period=-1,#(int)Savecheckpointeveryxepochs(disabledif
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 会员注册

本版积分规则

QQ|手机版|心飞设计-版权所有:微度网络信息技术服务中心 ( 鲁ICP备17032091号-12 )|网站地图

GMT+8, 2025-1-7 07:08 , Processed in 0.655138 second(s), 25 queries .

Powered by Discuz! X3.5

© 2001-2025 Discuz! Team.

快速回复 返回顶部 返回列表