当前位置: 首页 > news >正文

YAML 配置深度学习网络

YAML 配置深度学习网络

  • launcher 启动器、加载器
  • scratch 从零初始化模型训练
  • dataset 数据集
  • Video Object Segmentation(视频目标分割)
  • trainer 训练器
    • trainer
    • model
    • data
    • optim
    • loss
    • distributed
    • logging
    • checkpoint

YAML源文件:RobotSeg的robotseg-train.yaml

launcher 启动器、加载器

功能:训练资源 + 启动方式的配置

launcher:num_nodes:1# 使用 1 个计算节点(单机训练)gpus_per_node:8# 每个节点使用 8 张 GPUexperiment_log_dir:./logs/robotseg# 日志/模型保存路径
  • num_nodes: 1 表示单机训练,如果是多机分布式训练,需改为2、3、4…
  • gpus_per_node: 8 表示该单机上使用8块GPU并行训练
  • experiment_log_dir:训练日志、TensorBoard 文件、检查点(checkpoint)、配置备份地址

scratch 从零初始化模型训练

功能:不加载预训练权重,从零开始训练(Train from Scratch)的核心超参配置块,模型怎么训练、用多大资源、训练多久。

scratch:resolution:1024# 图像分辨率:输入图片大小 1024x1024train_batch_size_image:10# 图片训练批次:一次喂 10 张图train_batch_size_video:1# 视频训练批次:一次喂 1 个视频num_train_workers:10# 数据加载线程数:10 个线程并行读数据num_frames:7# 每个视频抽 7 帧图像max_num_objects:1# 每个样本最多处理 1 个物体object_token_num:3# 每个物体用 3 个 token 表示base_lr:3.0e-4# 基础学习率:0.0003vision_lr:6.0e-05# 视觉模型分支学习率:0.00006phases_per_epoch:1# 每个 epoch 执行 1 个训练阶段num_epochs:25# 总共训练 25 轮

dataset 数据集

功能:数据集配置块,告诉代码数据存在哪个文件夹、去哪里读图片、去哪里读标签

dataset:# 1. 数据集总根目录(所有数据的大文件夹)root:/workspace/RobotSeg/dataset# 2. RoboEngine 数据集(仿真数据)roboengine_img_folder:${dataset.root}/RoboEngine/train/image# 图片路径roboengine_gt_folder:${dataset.root}/RoboEngine/train/mask# 掩码标签路径roboengine_file_list_txt:train/txt/roboengine_train_list.txt# 样本列表文件# 3. VRS 数据集(真实机器人数据)vrs_img_folder:${dataset.root}/VRS/train/image# 图片路径vrs_gt_folder:${dataset.root}/VRS/train/mask_gt_dinov3# 掩码标签路径vrs_file_list_txt:train/txt/vrs_train_list.txt# 样本列表文件

注意这里使用${dataset.root}变量引用,等价于/workspace/RobotSeg/dataset。

Video Object Segmentation(视频目标分割)

功能:视频数据的训练增强(Transforms)配置,作用是在训练前,对视频帧自动做随机变形、颜色抖动、缩放等处理,让模型更鲁棒、泛化更强。

# 视频目标分割的数据增强vos:train_transforms:# 1. 组合增强器-_target_:train.dataset.transforms.ComposeAPItransforms:# --------------------------# 【几何增强】对所有帧保持一致# --------------------------# 随机水平翻转-_target_:train.dataset.transforms.RandomHorizontalFlipconsistent_transform:True# 随机仿射变换(旋转、拉伸)-_target_:train.dataset.transforms.RandomAffinedegrees:25# 最大旋转 ±25°shear:20# 最大拉伸 ±20°image_interpolation:bilinearconsistent_transform:True# 随机 resize 到统一分辨率-_target_:train.dataset.transforms.RandomResizeAPIsizes:${scratch.resolution}# 用前面配置的 1024square:true# 变成正方形consistent_transform:True# --------------------------# 【颜色增强】部分帧独立变化# --------------------------# 颜色抖动(亮度、对比度、饱和度)-_target_:train.dataset.transforms.ColorJitterconsistent_transform:Truebrightness:0.1contrast:0.03saturation:0.03hue:null# 5% 概率随机灰度化-_target_:train.dataset.transforms.RandomGrayscalep:0.05consistent_transform:True# 第二次颜色抖动(独立增强,让数据更多样)-_target_:train.dataset.transforms.ColorJitterconsistent_transform:Falsebrightness:0.1contrast:0.05saturation:0.05hue:null# --------------------------# 【最后:标准化】# ---------------------------_target_:train.dataset.transforms.ToTensorAPI# 转张量-_target_:train.dataset.transforms.NormalizeAPI# 归一化mean:[0.485,0.456,0.406]# ImageNet 均值std:[0.229,0.224,0.225]# ImageNet 方差

参考各_target_的实现可以看的更清晰:

classComposeAPI:def__init__(self,transforms):# 需要的参数就是transformsself.transforms=transforms# 所以在上面的YAML中,- _target_: train.dataset.transforms.ComposeAPI# 后面跟着一个 transforms 参数
classRandomHorizontalFlip:def__init__(self,consistent_transform,p=0.5):self.p=p self.consistent_transform=consistent_transform# 这就是为啥- _target_: train.dataset.transforms.RandomHorizontalFlip# 后面跟着一个 consistent_transform: True 参数
  • 其它的_target_类似
vos └── train_transforms # 训练阶段视频帧预处理/数据增强流水线 └── ComposeAPI # 容器:按顺序串行执行所有变换 ├─ 1. RandomHorizontalFlip 随机水平翻转 │ └─ consistent_transform: True # 视频所有帧保持一致变换 ├─ 2. RandomAffine 随机仿射变换 │ ├─ degrees: 25 # 旋转角度范围 ±25° │ ├─ shear: 20 # 错切变形 ±20° │ ├─ image_interpolation: bilinear # 双线性插值 │ └─ consistent_transform: True # 帧间变换统一 ├─ 3. RandomResizeAPI 随机缩放到指定尺寸 │ ├─ sizes: ${scratch.resolution} # 引用全局分辨率 1024 │ ├─ square: true # 输出正方形图像 │ └─ consistent_transform: True # 帧间变换统一 ├─ 4. ColorJitter 颜色抖动(第一组) │ ├─ consistent_transform: True │ ├─ brightness: 0.1 │ ├─ contrast: 0.03 │ ├─ saturation: 0.03 │ └─ hue: null # 不做色相变换 ├─ 5. RandomGrayscale 随机转灰度图 │ ├─ p: 0.05 # 5% 概率执行 │ └─ consistent_transform: True ├─ 6. ColorJitter 颜色抖动(第二组) │ ├─ consistent_transform: False # 每一帧独立变换 │ ├─ brightness: 0.1 │ ├─ contrast: 0.05 │ ├─ saturation: 0.05 │ └─ hue: null ├─ 7. ToTensorAPI 图像转 PyTorch 张量 └─ 8. NormalizeAPI 归一化 ├─ mean: [0.485, 0.456, 0.406] # ImageNet 均值 └─ std: [0.229, 0.224, 0.225] # ImageNet 标准差

trainer 训练器

功能:训练任务总调度器,是整个训练流程的入口与总指挥,把模型、数据、优化器、损失、日志、分布式、权重读写所有模块组装到一起,统一调度运行,启动训练后,所有逻辑都由它接管。

# ==============================================================================# 训练器总配置:整个模型训练的“大脑”,控制所有流程# ==============================================================================trainer:_target_:train.trainer.Trainer# 训练器类(代码入口)mode:train_only# 模式:只训练,不测试max_epochs:${times:${scratch.num_epochs},${scratch.phases_per_epoch}}# 总训练轮数 = 25*1=25accelerator:cuda# 使用GPU加速seed_value:2026# 随机种子,保证实验可复现# ============================================================================# 模型结构:RobotSeg(基于SAM 2.1的视频目标分割模型)# ============================================================================model:_target_:train.model.robotseg.RobotSegTrain# 训练用的模型主体# ----------------------------# 1. 图像编码器:提取图像特征# ----------------------------image_encoder:_target_:robotseg.modeling.backbones.image_encoder.ImageEncoderscalp:1trunk:_target_:robotseg.modeling.backbones.hieradet.Hiera# SAM 2.1主干:HieraTransformerembed_dim:96# 基础特征维度num_heads:1# 注意力头数stages:[1,2,7,2]# 网络层级结构global_att_blocks:[5,7,9]# 全局注意力层位置window_pos_embed_bkg_spatial_size:[7,7]drop_path_rate:0.1# 随机深度,防止过拟合neck:_target_:robotseg.modeling.backbones.image_encoder.FpnNeck# FPN特征金字塔position_encoding:_target_:robotseg.modeling.position_encoding.PositionEmbeddingSine# 位置编码num_pos_feats:256normalize:truescale:nulltemperature:10000d_model:256# 特征通道数backbone_channel_list:[768,384,192,96]# 各层级输出通道fpn_top_down_levels:[2,3]# FPN顶层向下融合层fpn_interp_model:nearest# 上采样方式# ----------------------------# 2. 记忆注意力:视频时序信息融合# ----------------------------memory_attention:_target_:robotseg.modeling.memory_attention_with_structure.MemoryAttentiond_model:256pos_enc_at_input:truelayer:_target_:robotseg.modeling.memory_attention_with_structure.MemoryAttentionLayeractivation:reludim_feedforward:2048dropout:0.1pos_enc_at_attn:false# 自注意力(当前帧特征)self_attention:_target_:robotseg.modeling.sam.transformer.RoPEAttentionrope_theta:10000.0feat_sizes:[64,64]embedding_dim:256num_heads:1downsample_rate:1dropout:0.1# 交叉注意力(当前帧 ↔ 记忆帧)cross_attention:_target_:robotseg.modeling.sam.transformer.RoPEAttentionrope_theta:10000.0feat_sizes:[64,64]rope_k_repeat:Trueembedding_dim:256num_heads:1downsample_rate:1dropout:0.1kv_in_dim:64# 结构感知交叉注意力cross_attention_structure:_target_:robotseg.modeling.sam.transformer.RoPEAttentionrope_theta:10000.0feat_sizes:[64,64]
http://www.zskr.cn/news/1488257.html

相关文章:

  • 从ImageNet到CLIP:手把手带你用PyTorch复现对比学习的关键训练技巧(附避坑指南)
  • 如何快速掌握Reloaded-II:终极游戏Mod加载器完全指南
  • S32G LLCE CAN硬件对象配置详解与CAN2CAN应用实战
  • 10分钟搞定黑苹果:OpCore-Simplify一键自动化EFI配置工具终极指南
  • NXP DPAA2 SerDes Lane复位操作:解决链路正常但数据不通的底层调试方法
  • 2026 年 6 月沈阳手表回收行情,变现干货速看 - 讯息早知道
  • 专利
  • 无线RS-232通信系统设计:基于动态直流平衡编码的可靠链路实现
  • 保姆级教程:用Kali Linux和Aircrack-ng抓取自家智能家居的加密流量(附Wireshark解密配置)
  • 招聘数据一键抓取分析包:智联/拉勾/51job多平台Python爬虫+词云可视化
  • UKI.js终极指南:10分钟掌握轻量级Web应用UI工具包
  • 沈阳手表回收常见压价套路,内行干货拆解 - 讯息早知道
  • 长沙家居定制厂家实力解析:湖南桦美家家居全维度展示 - 互联网科技品牌测评
  • Steam创意工坊下载终极解决方案:WorkshopDL跨平台模组管理工具
  • 8.2 | 负压收集+生物滤池+化学洗涤:除臭系统的三级防线设计
  • 2026 深圳奢包回收测评榜单:爱马仕香奈儿回收优选机构盘点! - 奢侈品交易观察员
  • 向量空间JBoltAI:企业大脑与数字员工的双引擎
  • 如何用STIX Two字体彻底解决学术文档的排版难题:终极指南
  • CANoe诊断安全访问避坑指南:二次封装DLL时LoadLibrary失败与路径问题的解决
  • 深圳收的顶本地老牌回收商家,专注高端首饰,各大奢侈品牌全覆盖 - 奢侈品回收测评
  • 2026专业的通风设备公司推荐及行业发展解析 - 品牌排行榜
  • 告别虚拟机!用DosBox+MASM6.15在Win10/Win11上快速搭建汇编学习环境(保姆级图文)
  • 2026 年电动汽车充电桩厂家排名怎么选?结合市场数据解析电动汽车充电桩品牌排名,客观对比各厂家综合实力与适配场景 - 栗子测评
  • WebPShop:Photoshop最佳WebP插件,轻松优化网页图片和动画
  • 2026成都卖黄金别乱选!6 家主流回收机构深度盘点,新手也能安心变现 - 薛定谔的梨花猫
  • 2026 东莞正规专业回收公司推荐|钨钢铣刀 钨钢粒 钨钢粉 钨钢泥 线路板 电缆线 紫铜红铜 铜渣铜线 锡块锡条锡线回收指南 - 星际AI
  • 2026年哈尔滨市CPPM考试最新全攻略:科目题型、通过率、备考重点及官方双认证报考机构推荐 - 众智商学院课程中心
  • eBay账户保护机制深度解读:为什么你的竞价会被限制?如何主动预防?
  • TJA1101A汽车以太网PHY寄存器配置与低功耗模式实战指南
  • 2026年超高效过滤器深度解析:高效净化技术与应用 - 品牌排行榜