type
status
date
slug
summary
tags
category
icon
password

1. 技术背景:为什么 MoE 训练容易训推不一致

要理解路由回放,首先需要理解现代 LLM 强化学习(如 PPO)的训练循环以及底层计算架构。

1.1 PPO 的“三次前向传播”

一次标准的 PPO 迭代包含三个关键步骤,同一个数据 会被计算多次:
  1. Rollout (采样/推理): 智能体在推理引擎中生成数据(Prompt Response )。
      • 产出: 只有文本数据 。此时不保留计算图,因为推理引擎追求极致速度。
  1. Recompute (重计算): 训练引擎拿着数据 同样的模型权重(旧策略 ),重新跑一遍前向传播。
      • 目的: 计算出 的对数概率(Logits)。这是计算 PPO Loss 分母(基准概率)所必须的。
  1. Update (更新): 模型参数开始微调(新策略 ),再次前向传播计算新概率,并结合优势函数计算梯度。
问题的核心在于:第1步(Rollout)和第2步(Recompute)虽然理论上是同一个模型在处理同一个输入,但在工程实现上往往是两个框架。

1.2 致命的“框架鸿沟”

为了追求效率,现代架构通常是物理分离的:
  • 推理引擎 (SGLang/vLLM): 使用 FP8/Int8 量化、非连续内存(PagedAttention)、自定义 CUDA 核。
  • 训练引擎 (Megatron-LM): 使用 BF16/FP32 高精度、标准的 PyTorch 算子。

1.3 误差放大器:MoE 的 Top-K 敏感性

对于传统的 Dense 模型, 的 Logits 误差可能只是让概率分布稍微平移,影响微乎其微。但对于 MoE 架构,这种微小的扰动是致命的。 MoE 的 Router 执行的是 Top-K 离散选择(通常是 Argmax)。这导致 Router 的输出对扰动极度敏感:
  • 蝴蝶效应: Router Logits 上微小的数值差异,可能导致 个被选中的专家中,有一个甚至全部发生变化。
  • 后果: 训练引擎在 Recompute 时激活的专家路径,可能与推理引擎 Rollout 时实际走的路径完全不同。

2. 方案一:Recompute Routing Replay (重计算路由回放)

这一机制主要关注训练循环内部的一致性,是GSPO的解决方案。

2.1 核心逻辑

  • 回放源头: 训练引擎内部的 Recompute 阶段
即缓存 激活的专家,并在计算重要性比率时在 中“回放”这些路由模式
  • 操作步骤:
      1. 当训练引擎执行 Step 2 (Recompute) 时,计算出了一个专家掩码
      1. 系统将这个掩码缓存下来。
      1. 在紧接着的 Step 3 (Update) 中,无论参数如何更新,强制复用
  • 目的: 解决参数更新前后()导致的路由漂移。

3. 方案二:Rollout Routing Replay (R3, 推理路由回放)

这是论文 Stabilizing MoE Reinforcement Learning by Aligning Training and Inference Routers 提出的核心方案,旨在解决跨系统的根本性不一致。

3.1 核心逻辑:端到端的全链路锁定

  • 回放源头: 推理引擎的 Rollout 阶段
  • 操作步骤:
      1. 记录 (Record): 修改推理引擎,在生成 Token 时,将每一层的 Routing Mask () 记录下来。
      1. 透传 (Transfer): 作为数据的一部分(类似 Attention Mask)传给训练引擎。
      1. 回放 (Replay): 在训练引擎的 Step 2 (Recompute)Step 3 (Update) 中,彻底丢弃自己的路由选择权,直接使用

3.2 关键细节:如何保留梯度?

如果我们强制锁定了专家,Router 还能被训练吗?R3 采用混合计算公式:
  • 路径锁定 (): 确保物理上激活的专家与推理时完全一致(解决偏差)。
  • 权重可导 (): 保留训练引擎计算的原始 Logits 参与运算,确保梯度能回传给 Router,教它如何“打分”。

4. 总结对比

维度
Recompute Routing Replay (GSPO)
Rollout Routing Replay (R3)
数据源头
Training Engine (Recompute Step)
Inference Engine (Rollout Step)
解决的问题
误差:参数随时间更新导致的漂移。
误差:系统架构差异导致的漂移。
覆盖范围
仅 Update 阶段一致。
Rollout Recompute Update 全链路一致。
mini_step=1
无效。此时 Recompute 和 Update 本就一致。
有效。依然能纠正推理与训练的偏差。
工程复杂度
。仅需修改训练代码。
中高。需打通推理与训练引擎的 Mask 传输。
核心价值
无法阻止由框架差异引起的训练崩溃。
将 MoE 的稳定性提升至 Dense 模型水平。

 
百度云4机A800测试最近思考:少即是多
Loading...