背景:离散决策(one-hot 选择)是不可导的,不能直接在神经网络里反向传播梯度。
解决方法: 用 Gumbel-Softmax trick
Gumbel-Softmax
普通 softmax:
各个符号的含义
logitsi:
模型输出的第 i 类的”原始分数”,还没有经过归一化。τ (temperature, 温度参数):
控制分布的”平滑程度”。当 τ=1 时,就是普通的 softmax。τ 越大,分布越平滑,概率越接近均匀分布。exp(·):
指数函数,把分数变成正数。分母:
对所有类别的指数和求和,用来做归一化,保证结果是概率分布。结果 pi:
这是第 i 类的概率,满足 $\sum_i p_i = 1$。
Gumbel-Softmax 的公式:
在 logits 上加上 Gumbel 噪声,再 softmax,有两种模式:
1. hard = False (连续模式)
- 输出一个概率向量(类似 soft one-hot)
- 比如:
[0.7, 0.2, 0.1] - 这个输出是连续的,可以直接反向传播梯度
- 好处:完全可导,可以用作常规训练
- 缺点:不是真正的离散化,模型真正使用时可能表现不同
2. hard = True (离散模式)
- 在前向传播时,会输出”硬化”为 one-hot,比如上面的数据会变成:
[1, 0, 0] - 这样看起来像真正的离散选择
- 但在反向传播时,仍然是按照 soft 版本的梯度(就是那个
[0.7, 0.2, 0.1])- 这叫 straight-through estimator
1 |
|