type
status
date
slug
summary
tags
category
icon
password
阿里最新的RLHF论文,一直没搞懂
“grpo是用两个模型对每个token做重要性采样,gspo是对整个句子做重要性采样。对于一个完整的答案,token有很强的偶然性,而且对每个token的更新都要单独做截断,会导致有的更新到位,有的不到位,学到一些奇怪的东西”
做个对比
后面出现的在GRPO的基础上改进的算法更多了,索性都加进来吧
一、GRPO:基于Token独立加权的梯度
GRPO 的核心问题在于它如何计算策略梯度。
- 目标函数: 我们先看 GRPO 的优化目标:
- 关键项定义:
- 优势 (Advantage) :对于序列 中的所有Token,该值是恒定的。
- 重要性权重 (Importance Weight) :这是Token级别的,每个Token都有自己独立的权重 。
- 梯度分析 (核心问题所在): 对目标函数求导,我们得到 GRPO 的梯度: 公式解读:
- 是Token 的标准策略梯度,代表了模型参数应该移动的方向。
- 关键在于:每个Token的梯度方向,都被其各自的、独立的重要性权重 进行了缩放。
- 后果:由于 是基于单个Token的采样计算的,其值波动极大,充满了“偶然性”。这意味着在一个序列内部,Token 1 的梯度可能被放大 1.2 倍,Token 2 的梯度可能被缩小到 0.8 倍,Token 3 的梯度又被放大 1.5 倍。这种在序列内部不一致的、嘈杂的加权,导致了梯度的巨大方差,最终使训练过程不稳定,甚至导致模型崩溃。
二、GSPO:基于序列统一加权的梯度
GSPO 通过在序列层面进行加权,从根本上解决了这个问题。
- 目标函数: GSPO 的优化目标如下:
- 关键项定义:
- 优势 (Advantage) :与 GRPO 相同,是序列级别的。
- 重要性权重 (Importance Weight) :这是序列级别的,一个序列共享一个权重。
- 梯度分析 (解决方案): 对 GSPO 的目标函数求导: 公式解读:
- 是整个序列中所有Token梯度的总和。
- 关键在于:这个“总梯度”被一个统一的、序列级别的权重 进行缩放。
- 后果:序列 中的所有Token的梯度更新方向是一致的。权重 作为一个整体的调节因子,判断整个序列的“偏离程度”,然后统一地放大或缩小整个序列的更新步长。这种做法消除了 GRPO 中由于Token间权重不一致而引入的内部方差,从而“消除了这个不稳定性因素”。
2.1 GSPO在计算重要性权重的时候为什么用几何平均值,不用算数平均?
特性 | 算术平均 (Arithmetic Mean) | 几何平均 (Geometric Mean) |
公式 | ||
核心逻辑 | 加法逻辑:假设各项是独立的、平等的累加关系。 | 乘法逻辑:假设各项是连续的、相互影响的倍数关系。 |
物理含义 | 代表“总和的平均分摊”。 | 代表“总增长率的平均”。 |
对极端值的敏感度 | 受大数影响大:一个极大的异常值会显著拉高平均值。 | 受小数(接近0)影响大:任何一个项接近0,结果就会急剧下降。 |
- 联合概率的乘法本质 (The Multiplicative Nature of Joint Probability)
- 几何平均:,这保留了概率作为乘法过程的物理意义。它代表了“平均每个 Token 的变化倍数”。
- 算术平均:,这假设各个 Token 之间是独立的加法关系,这违反了序列生成的概率定义。
- 反例:如果使用算术平均,模型会错误地认为“只要大部分 Token 概率变大,哪怕中间有一个关键 Token 概率变成了 0(导致整个句子逻辑崩塌),整体权重依然很高”,这显然是错误的。几何平均在这种情况下会趋近于 0,正确反映了序列的整体质量下降。
这是最根本的原因。大语言模型是自回归模型,一个序列 的联合概率 是所有 Token 条件概率的连乘积,而不是和:
因此,序列整体的概率比率(Probability Ratio)也是一个连乘积:
如果我们想把这个序列级别的比率“归一化”或者“平摊”到每个 Token 的尺度上,数学上正确的逆操作是开 次方根(即几何平均)。
- 与“平均对数似然”的数学等价性
在 LLM 训练中,我们通常优化的是对数似然(Log-Likelihood)。对数函数将乘法转化为加法。请看以下推导:
如果我们计算 Token 级别比率的对数的算术平均值(即平均 Token 层面的变化量):
这清楚地表明:概率比率的几何平均值,直接对应于“对数空间下的算术平均值”。
这意味着 代表了模型在该序列上平均每个 Token 的概率变化幅度。这使得不同长度的序列(长序列和短序列)的权重具有了可比性,实现了长度归一化(Length Normalization),防止长序列对 loss 的贡献被不合理地放大或缩小。
- 数值稳定性与方差控制
- 若每个 Token 的比率略大于 1(如 1.1),长度为 100 的序列,总比率就是 。
- 若每个 Token 的比率略小于 1(如 0.9),长度为 100 的序列,总比率就是 。
如果直接使用序列的总比率 ,数值会随着序列长度呈指数级爆炸或消失(Exploding/Vanishing)。
这种巨大的数值波动会导致梯度估计极不稳定,甚至导致浮点数溢出。
通过取几何平均值(开 次方),我们将数值拉回到了 附近的常数尺度(即单 Token 尺度的变化率),从而极大地降低了估计量的方差,保证了训练的稳定性。
- 总结
GSPO 使用几何平均值是因为:LLM 生成序列是概率的连乘过程,只有几何平均能够正确地将“序列级的概率比率”映射回“平均单 Token 级的概率比率”,同时这种方式在对数空间上等价于算术平均,确保了不同长度序列权重的数值稳定性和公平性。
三、SAPO:基于软自适应门控的梯度
三、SAPO:软自适应策略优化 (Soft Adaptive Policy Optimization)
SAPO 通过引入平滑的门控机制(Soft Gating)和非对称温度控制(Asymmetric Temperature),在保留 Token 级自适应能力的同时,实现了序列级的连贯性,解决了硬截断(Hard Clipping)带来的梯度消失和训练不稳定问题。
3.1 目标函数
SAPO 摒弃了传统的
min/max 硬截断,而是最大化以下目标函数:关键项定义
- 门控函数 (Gating Function) : 这是一个基于 Sigmoid 的平滑函数,用于替代 PPO/GRPO 中的硬截断: 其中 是 Sigmoid 函数。
- 非对称温度 (Asymmetric Temperature) : SAPO 针对正负优势样本采用不同的温度参数,这是为了应对负样本梯度带来的不稳定性: 通常设置 (例如 ),意味着对负样本施加更强的衰减。
3.2 梯度分析 (核心创新)
对 SAPO 目标函数求导,得到加权的对数策略梯度:
权重计算公式:
3.3 公式解读与机制分析
- 连续的信任域 (Continuous Trust Region):
- 权重函数 在 (On-policy)处达到峰值 1,随着比率 偏离 1,权重呈“钟形曲线”平滑衰减,而不是像 GRPO 那样在某个阈值突然变为 0。
- 后果:这保留了那些稍微偏离策略但仍具信息量的样本的梯度信号,同时平滑地抑制了严重偏离策略的噪声,避免了硬截断导致的“梯度消失”或“全有全无”的脆性更新。
- 序列连贯性与Token自适应的统一 (Sequence-Coherent & Token-Adaptive):
- 近似 GSPO:在策略更新步幅较小且序列内方差较低(常见情况)时,SAPO 的平均 Token 门控近似于一个序列级的门控:。这意味着它自然地继承了 GSPO 的序列一致性优势。
- 优于 GSPO:当序列中仅有少数“害群之马”(极度 Off-policy 的 Token)时,GSPO 会因为序列级截断而丢弃整个序列的学习机会;而 SAPO 能够精准打击,仅降低那些异常 Token 的权重,同时保留序列中其他正常 Token 的学习信号,从而提高了样本效率。
- 非对称温度的物理意义:
- 正优势 ():增加采样 Token 的概率,降低未采样 Token 的概率。
- 负优势 ():降低采样 Token 的概率,但这会导致词表中成千上万个其他未采样 Token (Irrelevant Tokens) 的概率被动升高。
- 后果:负梯度的扩散效应更容易引入不稳定性。通过设置 ,SAPO 让负样本的权重 衰减得更快,从而更严格地限制负更新带来的破坏性,显著防止了训练早期的模型坍塌(Collapse)。
- 作者:SimonSun
- 链接:https://simonsun.xyz//article/llm-14
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章
.png?table=collection&id=cb472e47-cf59-4081-bd5f-899a844344db&t=cb472e47-cf59-4081-bd5f-899a844344db)


