当前位置: 首页 > news >正文

别再死磕公式了!用Python和PyTorch手把手复现DDPM图像去噪(附完整代码)

从零构建DDPMPython与PyTorch实战图像去噪在计算机视觉领域扩散模型正迅速成为生成高质量图像的主流方法。本文将带您从零开始使用PyTorch框架完整实现一个基础的Denoising Diffusion Probabilistic ModelDDPM无需深入复杂的数学推导通过代码直观理解这一强大模型的工作原理。1. 扩散模型基础概念扩散模型的核心思想是通过逐步添加噪声破坏图像再学习逆向去噪过程。想象一下把一杯清水慢慢滴入墨水的过程——扩散模型的正向过程就如同这个污染过程而逆向过程则是神奇的净化操作。与传统GAN或VAE不同DDPM具有几个独特优势训练稳定性不依赖对抗训练避免了模式坍塌问题生成质量逐步细化生成过程能产生更自然的高频细节理论优雅基于热力学的非平衡统计物理基础在技术实现层面DDPM主要包含两个关键阶段前向扩散过程Fixed Markov Chain逐步向数据添加高斯噪声逆向去噪过程Learned Transition训练神经网络逐步去噪# 基础配置 import torch import torch.nn as nn import numpy as np from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt device torch.device(cuda if torch.cuda.is_available() else cpu)2. 前向扩散过程实现前向过程定义为马尔可夫链逐步将数据转化为各向同性高斯分布。关键在于设计合理的噪声调度noise schedule控制不同时间步的噪声添加量。2.1 噪声调度设计我们采用线性噪声调度定义从β₁1e-4到β_T0.02的线性增长序列def linear_beta_schedule(timesteps, start1e-4, end0.02): return torch.linspace(start, end, timesteps) T 1000 # 总时间步数 betas linear_beta_schedule(T) alphas 1. - betas alphas_cumprod torch.cumprod(alphas, axis0) # α的连乘积 sqrt_alphas_cumprod torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod torch.sqrt(1. - alphas_cumprod)2.2 单步扩散实现给定原始图像x₀和时间步t计算加噪后的图像x_tdef q_sample(x_start, t, noiseNone): if noise is None: noise torch.randn_like(x_start) sqrt_alpha_cumprod_t sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1) sqrt_one_minus_alpha_cumprod_t sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1) return sqrt_alpha_cumprod_t * x_start sqrt_one_minus_alpha_cumprod_t * noise可视化不同时间步的加噪效果def plot_diffusion_process(image, num_steps5): plt.figure(figsize(15, 3)) plt.subplot(1, num_steps1, 1) plt.imshow(image.squeeze(), cmapgray) plt.title(Original) plt.axis(off) for i in range(1, num_steps1): t torch.tensor([i*(T//num_steps)-1]) noisy_image q_sample(image, t) plt.subplot(1, num_steps1, i1) plt.imshow(noisy_image.squeeze().cpu().numpy(), cmapgray) plt.title(fStep {t.item()1}) plt.axis(off) plt.show()3. 逆向去噪模型构建逆向过程的核心是训练一个噪声预测网络。我们采用改进的U-Net架构包含下采样和上采样路径并加入时间步嵌入。3.1 时间步嵌入将离散时间步转换为连续向量表示class SinusoidalPositionEmbeddings(nn.Module): def __init__(self, dim): super().__init__() self.dim dim def forward(self, t): device t.device half_dim self.dim // 2 embeddings torch.log(torch.tensor(10000.0)) / (half_dim - 1) embeddings torch.exp(torch.arange(half_dim, devicedevice) * -embeddings) embeddings t[:, None] * embeddings[None, :] embeddings torch.cat((embeddings.sin(), embeddings.cos()), dim-1) return embeddings3.2 基础残差块class Block(nn.Module): def __init__(self, in_ch, out_ch, time_emb_dim): super().__init__() self.time_mlp nn.Linear(time_emb_dim, out_ch) self.conv1 nn.Conv2d(in_ch, out_ch, 3, padding1) self.conv2 nn.Conv2d(out_ch, out_ch, 3, padding1) self.act nn.SiLU() self.bn nn.BatchNorm2d(out_ch) def forward(self, x, t): h self.bn(self.act(self.conv1(x))) time_emb self.act(self.time_mlp(t)) h h time_emb.reshape(-1, h.shape[1], 1, 1) return self.act(self.conv2(h))3.3 完整U-Net实现class UNet(nn.Module): def __init__(self, in_channels1, out_channels1, dim32, dim_mults(1, 2, 4, 8)): super().__init__() self.time_mlp nn.Sequential( SinusoidalPositionEmbeddings(dim), nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim) ) dims [in_channels] [dim * m for m in dim_mults] self.downs nn.ModuleList([]) self.ups nn.ModuleList([]) # 下采样路径 for i in range(len(dims)-1): self.downs.append(Block(dims[i], dims[i1], dim)) # 中间层 self.mid Block(dims[-1], dims[-1], dim) # 上采样路径 for i in reversed(range(len(dims)-1)): self.ups.append(nn.ConvTranspose2d(dims[i1], dims[i], 4, 2, 1)) self.ups.append(Block(dims[i]*2, dims[i], dim)) self.final nn.Conv2d(dim, out_channels, 1) def forward(self, x, t): t self.time_mlp(t) hs [] # 下采样 for block in self.downs: x block(x, t) hs.append(x) x nn.functional.avg_pool2d(x, 2) # 中间层 x self.mid(x, t) # 上采样 for i in range(0, len(self.ups), 2): x self.ups[i](x) skip hs.pop() x torch.cat([x, skip], dim1) x self.ups[i1](x, t) return self.final(x)4. 训练流程实现DDPM的训练目标是最小化预测噪声与真实噪声之间的L2距离。4.1 损失函数定义def p_losses(denoise_model, x_start, t, noiseNone): if noise is None: noise torch.randn_like(x_start) x_noisy q_sample(x_start, t, noise) predicted_noise denoise_model(x_noisy, t) return torch.mean((noise - predicted_noise)**2)4.2 训练循环def train(model, dataloader, epochs100, lr1e-3): optimizer torch.optim.Adam(model.parameters(), lrlr) model.train() for epoch in range(epochs): total_loss 0 for batch, _ in dataloader: batch batch.to(device) # 随机采样时间步 t torch.randint(0, T, (batch.size(0),), devicedevice) # 计算损失 loss p_losses(model, batch, t) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss loss.item() print(fEpoch {epoch1} | Loss: {total_loss/len(dataloader):.4f}) return model5. 采样生成图像训练完成后我们可以通过逐步去噪从随机噪声生成新图像。5.1 单步采样torch.no_grad() def p_sample(model, x, t, t_index): betas_t betas[t].reshape(-1, 1, 1, 1) sqrt_one_minus_alphas_cumprod_t sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1) sqrt_recip_alphas_t torch.sqrt(1.0 / alphas[t]).reshape(-1, 1, 1, 1) # 预测噪声 pred_noise model(x, t) # 计算均值 model_mean sqrt_recip_alphas_t * (x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t) if t_index 0: return model_mean else: posterior_variance_t (1 - alphas_cumprod[t-1]) / (1 - alphas_cumprod[t]) * betas[t] noise torch.randn_like(x) return model_mean torch.sqrt(posterior_variance_t).reshape(-1, 1, 1, 1) * noise5.2 完整采样流程torch.no_grad() def p_sample_loop(model, shape): # 从随机噪声开始 img torch.randn(shape, devicedevice) imgs [] for i in reversed(range(0, T)): t torch.full((shape[0],), i, devicedevice, dtypetorch.long) img p_sample(model, img, t, i) if i % (T//10) 0 or i T-1: imgs.append(img.cpu()) return imgs6. 实战演示与结果分析让我们在MNIST数据集上训练模型并观察生成效果。6.1 数据准备transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) dataset datasets.MNIST(./data, trainTrue, downloadTrue, transformtransform) dataloader DataLoader(dataset, batch_size128, shuffleTrue)6.2 模型训练model UNet().to(device) trained_model train(model, dataloader, epochs20)6.3 生成新图像sample_size 16 generated_images p_sample_loop(trained_model, (sample_size, 1, 28, 28)) # 可视化生成过程 plt.figure(figsize(15, 15)) for i in range(len(generated_images)): plt.subplot(1, len(generated_images), i1) plt.imshow(generated_images[i][0].squeeze(), cmapgray) plt.title(fStep {i*(T//len(generated_images))}) plt.axis(off) plt.show()通过这个完整实现我们不仅理解了DDPM的核心原理还获得了可以实际运行的代码。虽然我们的示例基于简单的MNIST数据集但同样的架构经过适当调整可以扩展到更复杂的图像生成任务。
http://www.zskr.cn/news/1364590.html

相关文章:

  • ALE与SHAP结合:从黑盒模型到可解释灰盒的实战指南
  • 神经符号系统实践:耦合机器学习与本体论提升机器人自主诊断能力
  • 布里渊散射与机器学习势场协同表征MOF力学性能
  • 新电脑到手别急着用!Win11必做的3个存储优化设置(磁盘分区+改默认路径+软件安装避坑)
  • 量子核方法:从经典核技巧到量子特征映射的实践指南
  • Unity Android读取SD卡图片的5种实战方案与选型指南
  • Linux 文本三剑客组合实战(grep + sed + awk)
  • GitHub界面本地化:从语言障碍到无障碍协作的技术演进
  • 2026年4月比较好的探伤仪源头厂家口碑推荐,MP-2B金相磨抛机/棒材拉力试验机/铸件拉力试验机,探伤仪源头厂家推荐 - 品牌推荐师
  • 2026年锦城学院深度解析:民办高校招生竞争白热化与品牌信任构建 - 品牌推荐
  • uLipSync深度配置指南:从音素对齐到跨平台部署
  • 保姆级教程:手把手教你为ESXi 6.7配置主板BIOS(VT-x/VT-d/AES-NI全开)
  • 构建鲁棒机器学习系统:MLOps实战中的数据漂移、模型监控与自动化应对
  • 信用评分模型可解释性:从SHAP到反事实解释的工程实践
  • L2正则化:从防过拟合到抗成员推理攻击的轻量级隐私保护
  • 别再只调0.5了!Cascade R-CNN源码实战:用Python一步步复现多阈值级联检测
  • 利用随机森林从星系图像预测外生恒星质量分数
  • 临床机器学习中缺失值处理:医生信任哪种可解释模型方法?
  • BudgetMLAgent:多智能体协同与级联决策,实现低成本自动化机器学习
  • 客服机器人核心模型评估:从NLU、DM到NLG的Pipeline架构实战对比
  • NVIDIA Profile Inspector终极指南:5步解锁显卡隐藏功能,轻松提升游戏性能30%
  • GitHub汉化插件终极指南:3分钟打造高效中文开发环境
  • 1-3 电压和电流
  • C#调用C++ DLL崩溃的真正原因:调用约定错配详解
  • 腾讯点选VMP环境补全与Hook实战:构建可信浏览器沙盒
  • 【Midjourney怀旧美学权威白皮书】:基于3726张训练集图像反向工程的年代特征数据库(1920–1999分段建模)
  • 从各向同性到各向异性:高精度预测超导转变温度的计算方法与实战
  • 百度网盘全速下载终极指南:5分钟告别限速困扰
  • 充电桩监控系统容器化实践与数据标准化解析
  • ContextMenuManager:重新定义Windows右键菜单的交互设计思维