跳转至

Gumbel-Softmax

动机

类别分布\(\boldsymbol\pi\)需采样one-hot, 但\(\arg\max\)不可微. Gumbel-Softmax用可微近似解决梯度回传, 思想类似reparameterization trick.

Gumbel-Max采样

\[ y=\arg\max_i\bigl(\log\pi_i+g_i\bigr),\;g_i\sim\text{Gumbel}(0,1) \]

结果与\(\text{Cat}(\boldsymbol\pi)\)同分布, 但依旧不可微.

温度化连续近似

\[ \tilde y_i=\frac{\exp\!\bigl((\log\pi_i+g_i)/\tau\bigr)}{\sum_j\exp\!\bigl((\log\pi_j+g_j)/\tau\bigr)} \]

\(\tau\to0\)趋近one-hot, \(\tau\to\infty\)趋近均匀; 退火训练先大后小兼顾稳定与离散化.

直通(ST)版

  • 前向: 先算\(\tilde{\mathbf y}\), 再\(\mathbf y=\text{one\_hot}(\arg\max\tilde{\mathbf y})\)得到真正离散token.
  • 反向: 梯度仅借\(\tilde{\mathbf y}\)传递, 可微且高效.

训练-推断一致性

前向已含Gumbel噪声→随机采样, 与推断阶段按\(p(k\mid\text{context})\)抽样同分布; 避开"无噪声硬argmax"引起的失配与码本坍缩.