YOLOv10模型改进-Backbone改进-第60篇: YOLOv10改进策略【Backbone】| PVT Backbone替换

YOLOv10模型改进-Backbone改进-第60篇: YOLOv10改进策略【Backbone】| PVT Backbone替换

一、本文介绍

本文记录的是利用PVT(Pyramid Vision Transformer)作为Backbone改进YOLOv10的特征提取部分。PVT通过金字塔结构和空间缩减注意力,实现高效的多尺度特征提取。

二、PVT模块介绍

2.1 设计出发点

ViT缺乏多尺度特征提取能力,PVT通过金字塔结构和空间缩减注意力,同时兼顾全局建模和多尺度特征。

2.2 模块结构

PVT块:

  1. 空间缩减注意力:减少注意力计算复杂度
  2. 前馈网络:非线性变换
  3. 层次化设计:多尺度特征输出

三、PVT的实现代码

importtorchimporttorch.nnasnnclassSpatialReductionAttention(nn.Module):def__init__(self,dim,num_heads=4,sr_ratio=1):super().__init__()self.num_heads=num_heads self.scale=(dim//num_heads)**-0.5self.q=nn.Linear(dim,dim)self.kv=nn.Linear(dim,dim*2)self.proj=nn.Linear(dim,dim)self.sr_ratio=sr_ratioifsr_ratio>1:self.sr=nn.Conv2d(dim,dim,sr_ratio,sr_ratio)self.norm=nn.LayerNorm(dim)defforward(self,x,H,W):B,N,C=x.shape q=self.q(x).reshape(B,N,self.num_heads,C//self.num_heads).permute(0,2,1,3)ifself.sr_ratio>1:x_=x.transpose(1,2).view(B,C,H,W)x_=self.sr(x_).reshape(B,C,-1).transpose(1,2)x_=self.norm(x_)kv=self.kv(x_).reshape(B,-1,2,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)else:kv=self.kv(x).reshape(B,N,2,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)k,v=kv[0],kv[1]attn=(q @ k.transpose(-2,-1))*self.scale attn=attn.softmax(dim=-1)x=(attn @ v).transpose(1,2).reshape(B,N,C)returnself.proj(x)classPVTBasicLayer(nn.Module):def__init__(self,dim,num_heads,sr_ratio=1):super().__init__()self.norm1=nn.LayerNorm(dim)self.attn=SpatialReductionAttention(dim,num_heads,sr_ratio)self.norm2=nn.LayerNorm(dim)self.mlp=nn.Sequential(nn.Linear(dim,dim*4),nn.GELU(),nn.Linear(dim*4,dim))defforward(self,x,H,W):x=x+self.attn(self.norm1(x),H,W)x=x+self.mlp(self.norm2(x))returnxclassPVT(nn.Module):def__init__(self,c1=3,c2=1024,embed_dims=[64,128,256,512],num_heads=[1,2,4,8],sr_ratios=[8,4,2,1]):super().__init__()self.patch_embeds=nn.ModuleList()self.patch_embeds.append(nn.Sequential(nn.Conv2d(c1,embed_dims[0],7,4,3),nn.LayerNorm(embed_dims[0])))foriinrange(1,4):self.patch_embeds.append(nn.Sequential(nn.Conv2d(embed_dims[i-1],embed_dims[i],3,2,1),nn.LayerNorm(embed_dims[i])))self.layers=nn.ModuleList()foriinrange(4):self.layers.append(PVTBasicLayer(embed_dims[i],num_heads[i],sr_ratios[i]))self.final_conv=nn.Conv2d(embed_dims[-1],c2,1,bias=False)defforward(self,x):B=x.shape[0]fori,embedinenumerate(self.patch_embeds):x=embed(x)H,W=x.shape[2:]x=x.flatten(2).transpose(1,2)x=self.layers[i](x,H,W)ifi<3:x=x.transpose(1,2).reshape(B,-1,H,W)x=x.transpose(1,2).reshape(B,-1,H,W)x=self.final_conv(x)returnx

四、创新模块

将PVT作为Backbone集成到YOLOv10中:

# yolov10n_pvt.yamlbackbone:-[-1,1,PVT,[3,1024]]-[-1,1,SPPF,[1024,5]]

五、预期结果

模型mAP@0.5mAP@0.5:0.95参数量
YOLOv10n52.3%27.9%2.7M
YOLOv10n-PVT53.2%28.8%13.0M

📌项目环境配置

  • Python:3.8.10+
  • PyTorch:2.0.0+
  • CUDA:11.8+
  • Ultralytics:8.3.13+