009、ESRGAN改进:RRDB残差密集块与相对对抗损失的实战优化

009、ESRGAN改进:RRDB残差密集块与相对对抗损失的实战优化

009、ESRGAN改进:RRDB残差密集块与相对对抗损失的实战优化

上周帮师弟调一个老照片修复项目,他直接套用SRGAN的生成器,结果纹理细节糊成一团,边缘还有诡异的伪影。我盯着loss曲线看了半天,发现判别器收敛太快,生成器根本没学到高频信息。这让我想起两年前自己踩过的坑——ESRGAN的论文读了三遍,代码跑通却出不来效果,最后发现是残差块结构写错了。今天把RRDB和相对对抗损失这两个核心改进的实战细节掰开揉碎,全是血泪教训。

从SRGAN到ESRGAN:那个让纹理起死回生的RRDB

先说说为什么SRGAN在真实场景下容易翻车。原始SRGAN用残差块堆叠,每个块里两个卷积加BN层。问题出在BN上——训练时BN统计量依赖batch内其他样本,超分任务里不同图像的高频特征差异极大,BN反而把纹理差异抹平了。我试过把batch size调到16,结果生成器输出像蒙了一层雾。

ESRGAN的RRDB(Residual-in-Residual Dense Block)直接砍掉BN,改用密集连接。这里有个容易写错的地方:RRDB内部是三个密集块串联,每个密集块里四个卷积层,每层输出都拼接到后续层输入。别写成普通残差块那种加法,要像DenseNet那样在通道维度拼接。我最初写代码时把torch.cat写成了torch.add,跑了三天才发现PSNR不升反降。

classDenseBlock(nn.Module):def__init__(self,in_channels,growth_rate=32):super().__init__()self.conv1=nn.Conv2d(in_channels,growth_rate,3,1,1)self.conv2=nn.Conv2d(in_channels+growth_rate,growth_rate,3,1,1)self.conv3=nn.Conv2d(in_channels+2*growth_rate,growth_rate,3,1,1)self.conv4=nn.Conv2d(in_channels+3*growth_rate,growth_rate,3,1,1)# 这里踩过坑:growth_rate别设太大,显存会爆,32够用self.lrelu=nn.LeakyReLU(0.2,inplace=True)defforward(self,x):x1=self.lrelu(self.conv1(x))x2=self.lrelu(self.conv2(torch.cat([x,x1],1)))x3=self.lrelu(self.conv3(torch.cat([x,x1,x2],1)))x4=self.conv4(torch.cat([x,x1,x2,x3],1))# 别这样写:直接return x4,要加上残差连接returntorch.cat([x,x1,x2,x3,x4],1)# 输出通道数 = in_channels + 4*growth_rate

RRDB的“残差中的残差”体现在:每个DenseBlock输出后要乘以残差缩放因子(通常0.2),再与输入相加。三个DenseBlock串联后,整体再加一次残差。这个缩放因子很关键,我试过0.1和0.5,0.2最稳,太大容易梯度爆炸,太小收敛慢。

相对对抗损失:让判别器学会“挑刺”

SRGAN用的标准GAN损失有个致命问题:判别器只判断输入是真实还是生成,相当于二分类。超分任务里,生成的高频细节哪怕有一丁点不自然,判别器都能轻松区分,导致生成器梯度消失。ESRGAN提出的RaGAN(Relativistic average GAN)改成了“真实图像比生成图像更真实”的相对判断。

具体实现时,判别器输出不再是单个概率,而是计算真实图像和生成图像特征之间的相对差异。我写代码时卡在怎么把标准判别器改成相对形式,后来发现其实就改一行loss计算:

defragan_loss(discriminator,real_imgs,fake_imgs):# 这里踩过坑:别直接用sigmoid,要计算相对概率real_logits=discriminator(real_imgs)fake_logits=discriminator(fake_imgs)# 相对判别器:真实图像比生成图像更真实的概率real_rel=real_logits-torch.mean(fake_logits,dim=0,keepdim=True)fake_rel=fake_logits-torch.mean(real_logits,dim=0,keepdim=True)# 别这样写:直接用BCEWithLogitsLoss,要手动算sigmoidd_loss_real=-torch.mean(torch.log(torch.sigmoid(real_rel)+1e-12))d_loss_fake=-torch.mean(torch.log(1-torch.sigmoid(fake_rel)+1e-12))d_loss=d_loss_real+d_loss_fake# 生成器损失:希望生成图像比真实图像更真实(反向)g_loss=-torch.mean(torch.log(torch.sigmoid(fake_rel)+1e-12))returnd_loss,g_loss

注意这里生成器loss是让fake_rel变大,即生成图像在相对比较中胜出。我刚开始写反了,把g_loss写成了d_loss_fake的形式,结果生成器疯狂生成模糊图像,因为模糊图像更容易骗过判别器。

训练策略:那些论文没写的调参细节

RRDB和RaGAN组合后,训练稳定性比SRGAN好很多,但仍有几个坑:

学习率设置:生成器用1e-4,判别器用4e-4,这个比例是经验值。判别器学习率太低,生成器会过拟合;太高,判别器太强导致生成器梯度消失。我试过两个都用1e-4,结果判别器loss一直下不去。

预热训练:前10个epoch只训练生成器,用L1损失。这步很关键,让生成器先学会基本结构,再引入对抗损失。别上来就开GAN,否则生成器会陷入局部最优,输出全是噪点。我有个项目因为没做预热,跑了200个epoch还是糊的。

损失权重:感知损失(VGG特征)权重设为1e-2,对抗损失权重设为1e-3。这个比例要微调,感知损失太强会压制纹理多样性,对抗损失太强会产生伪影。我习惯先固定感知损失权重,调整对抗损失,观察生成图像的高频细节是否自然。

梯度惩罚:虽然RaGAN比标准GAN稳定,但判别器还是容易过拟合。我在判别器里加了个梯度惩罚项,权重设为10。别用WGAN-GP那种复杂的梯度惩罚,简单的L2正则化就够用。

实战踩坑:从PSNR到感知质量的转变

用ESRGAN跑完训练,PSNR可能比SRGAN还低0.5dB,但人眼看起来更清晰。这是因为PSNR只衡量像素级差异,而ESRGAN牺牲了像素精度换来了纹理真实感。我刚开始纠结PSNR,后来发现用户反馈更好,才明白超分任务的评价指标要转向感知质量。

有个实际案例:给一个老照片修复项目,原图是1920x1080的模糊视频帧。用SRGAN放大4倍,人脸五官清晰但皮肤像塑料;用ESRGAN,皮肤纹理自然,但边缘有轻微锯齿。后来我在RRDB里加了边缘增强模块(一个sobel滤波器分支),锯齿问题解决,但训练时间增加了30%。这个取舍要看应用场景——如果是医疗影像,宁可模糊也不能有伪影;如果是娱乐应用,纹理真实感更重要。

个人经验性建议

  1. 别迷信论文里的超参数:ESRGAN原论文用DIV2K数据集,batch size 16,训练60万步。你的数据集如果只有几千张,把batch size降到4,学习率减半,否则过拟合。我试过用原参数训练小数据集,生成器直接崩了。

  2. RRDB的深度要适配硬件:原论文用23个RRDB,显存占用约8GB。如果你只有4GB显存,降到8个RRDB,效果损失不大。别硬撑,OOM了还得重来。

  3. 相对对抗损失不是万能药:对于严重模糊的图像,RaGAN可能产生过度锐化的伪影。这时候可以降低对抗损失权重,或者先用L1损失预训练到PSNR 30以上再开GAN。

  4. 调试时盯着生成图像看:loss曲线只能告诉你收敛趋势,具体效果得看图像。我习惯每10个epoch保存一批生成结果,观察纹理细节是否自然。如果出现棋盘格伪影,检查反卷积层是否用了正确的上采样方式(推荐pixel shuffle)。

  5. 代码实现要逐行检查:RRDB的密集连接、残差缩放、RaGAN的损失计算,这三个地方最容易出错。我建议先在小数据集(比如Set5)上跑通,确认PSNR和论文一致,再上大数据集。

最后说句实在话:ESRGAN的改进思路很巧妙,但实战中90%的问题出在实现细节上。把RRDB和RaGAN写对,你的超分效果就能超过90%的现有方法。剩下的10%,靠的是对数据集的深入理解和耐心调参。