增加感知损失(perceptual loss)
- 新建models/perceptual_loss.py
import torch
import torch.nn as nn
import torchvision.models as models


class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 features of a generated and a target image.

    Both images are expected in the [-1, 1] range. They are mapped to [0, 1],
    normalized with the ImageNet statistics the pretrained VGG16 weights were
    trained with, passed through the first ``num_layers`` modules of
    ``vgg16().features`` and compared with an L1 loss.

    Args:
        device: Device the VGG feature extractor (and normalization buffers)
            are placed on.
        num_layers: Number of leading modules of ``vgg16().features`` to keep.
            The default 4 (conv1_1, relu1_1, conv1_2, relu1_2) keeps GPU memory
            use low; larger values compare deeper, more semantic features.
        imagenet_norm: Apply ImageNet mean/std normalization before VGG, as the
            pretrained weights expect. ``False`` reproduces the old
            un-normalized behavior.
    """

    # ImageNet channel statistics used to train torchvision's VGG16 weights.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, device, num_layers=4, imagenet_norm=True):
        super().__init__()
        vgg = self._load_vgg16_features()
        # Keep only the first few layers (important: saves a lot of GPU memory).
        # eval() is harmless for these layers and guards against train-mode
        # behavior if deeper layers are ever selected.
        self.vgg = nn.Sequential(*list(vgg[:num_layers])).to(device).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.imagenet_norm = imagenet_norm
        # Non-persistent buffers: they follow .to()/.cuda()/.half() with the
        # module but are not added to state_dict (keeps checkpoint keys unchanged).
        self.register_buffer(
            "mean",
            torch.tensor(self._MEAN, device=device).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.tensor(self._STD, device=device).view(1, 3, 1, 1),
            persistent=False,
        )
        self.criterion = nn.L1Loss()

    @staticmethod
    def _load_vgg16_features():
        """Return ImageNet-pretrained ``vgg16().features`` across torchvision versions."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy ``pretrained`` flag.
            return models.vgg16(pretrained=True).features
        # ``pretrained=True`` is deprecated since 0.13; IMAGENET1K_V1 is the
        # same checkpoint it used to load.
        return models.vgg16(weights=weights_enum.IMAGENET1K_V1).features

    def _prepare(self, x):
        """Convert a [-1, 1] image batch into the input the pretrained VGG16 expects."""
        x = (x + 1) / 2.0  # [-1, 1] -> [0, 1]
        if x.dim() == 4 and x.size(1) == 1:
            # Grayscale batches (N, 1, H, W): VGG's first conv needs 3 channels.
            x = x.repeat(1, 3, 1, 1)
        if self.imagenet_norm:
            x = (x - self.mean) / self.std
        return x

    def forward(self, fake, real):
        """Return the L1 distance between VGG features of ``fake`` and ``real``.

        Args:
            fake: Generated images in [-1, 1]; assumed (N, C, H, W) with
                C == 1 or 3 — confirm against the generator's output_nc.
            real: Target images with the same layout and range as ``fake``.

        Returns:
            Scalar loss tensor. Gradients flow back into ``fake`` only.
        """
        fake_f = self.vgg(self._prepare(fake))
        # The target is a fixed reference: skip storing its activations.
        with torch.no_grad():
            real_f = self.vgg(self._prepare(real))
        return self.criterion(fake_f, real_f)
- 修改models/pix2pix_model.py
## 在文件顶部增加
from models.perceptual_loss import PerceptualLoss
## 初始化 loss
# NOTE(review): this snippet uses `self`, so it presumably goes inside the
# pix2pix model's initialisation; confirm placement. The snippet does not show
# the perceptual loss being added to the generator objective — that still has
# to be wired in separately.
# Presumably the weight multiplying the perceptual term in the generator loss
# (not shown here) — confirm where it is used.
self.lambda_perceptual = 10 # tunable (5~10 recommended)
# Frozen-VGG perceptual criterion; its feature extractor is moved to self.device.
self.criterionPerceptual = PerceptualLoss(self.device)
本文来自博客园,作者:jsqup,转载请注明原文链接:https://www.cnblogs.com/jsqup/p/20676333