Gumbel-Softmax

背景:离散决策(one-hot 选择)是不可导的,不能直接在神经网络里反向传播梯度。

解决方法: 用 Gumbel-Softmax trick

Gumbel-Softmax

普通 softmax:

  • 各个符号的含义

    • logitsi
      模型输出的第 i 类的”原始分数”,还没有经过归一化。

    • τ (temperature, 温度参数)
      控制分布的”平滑程度”。当 τ=1 时,就是普通的 softmax。τ 越大,分布越平滑,概率越接近均匀分布。

    • exp(·)
      指数函数,把分数变成正数。

    • 分母
      对所有类别的指数和求和,用来做归一化,保证结果是概率分布。

    • 结果 pi
      这是第 i 类的概率,满足 $\sum_i p_i = 1$。


Gumbel-Softmax 的公式:

在 logits 上加上 Gumbel 噪声,再 softmax,有两种模式:

1. hard = False (连续模式)

  • 输出一个概率向量(类似 soft one-hot)
  • 比如:[0.7, 0.2, 0.1]
  • 这个输出是连续的,可以直接反向传播梯度
  • 好处:完全可导,可以用作常规训练
  • 缺点:不是真正的离散化,模型真正使用时可能表现不同

2. hard = True (离散模式)

  • 在前向传播时,会输出”硬化”为 one-hot,比如上面的数据会变成:[1, 0, 0]
  • 这样看起来像真正的离散选择
  • 但在反向传播时,仍然是按照 soft 版本的梯度(就是那个 [0.7, 0.2, 0.1])- 这叫 straight-through estimator
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

import torch
import torch.nn.functional as F

# 假设 logits 来自一个 3 类分类器
logits = torch.tensor([[2.0, 1.0, 0.1]], requires_grad=True)

tau = 1.0 # 温度参数

# ---------------------------
# 1. hard=False (软 one-hot,连续概率分布)
# ---------------------------
y_soft = F.gumbel_softmax(logits, tau=tau, hard=False)
print("hard=False 输出:", y_soft)

loss_soft = y_soft.sum() # 简单构造一个损失
loss_soft.backward()
print("hard=False 梯度:", logits.grad)

# 梯度清零
logits.grad.zero_()

# ---------------------------
# 2. hard=True (前向 one-hot,但梯度来自 soft 版本)
# ---------------------------
y_hard = F.gumbel_softmax(logits, tau=tau, hard=True)
print("\nhard=True 输出:", y_hard)

loss_hard = y_hard.sum()
loss_hard.backward()
print("hard=True 梯度:", logits.grad)

# hard=False 输出: tensor([[0.6591, 0.2479, 0.0930]], grad_fn=<GumbelSoftmaxBackward>)
# hard=False 梯度: tensor([[0., 0., 0.]])

# hard=True 输出: tensor([[1., 0., 0.]], grad_fn=<GumbelSoftmaxBackward>)
# hard=True 梯度: tensor([[0., 0., 0.]])