Skip to content

重参数化技巧

反向传播和可导性的基础

反向传播是神经网络训练中用来优化参数的核心算法. 它通过计算损失函数对网络参数的梯度, 利用链式法则将梯度从输出层逐步传递到输入层. 要实现这一点, 网络中的每一步操作必须是可导的, 也就是说, 每个操作都可以定义一个明确的导数. 如果某个操作不可导, 梯度就无法通过这个操作传播, 反向传播就会中断. 在数学上, 一个函数\(f(x)\)是可导的, 前提是对于输入\(x\)的微小变化, 输出\(f(x)\)的变化是平滑且确定的. 然而, 如果某个操作的结果是随机的或不连续的, 那么它的导数就无法被定义.

采样过程的随机性

现在来看从一个分布中采样的过程. 以正态分布为例, 采样可以表示为:

\[ z \sim \mathcal{N}(\mu, \sigma^2) \]

其中, \( \mu \) 是均值, \( \sigma^2 \) 是方差, \( z \) 是采样的结果. 关键在于, 这个过程是随机的. 即使 \( \mu \)\( \sigma^2 \) 是固定的, 每次采样得到的 \( z \) 值都会不同. 这种随机性意味着采样操作不是一个确定的函数, 而是一个概率过程.

例如, 假设 \( \mu = 0 \)\( \sigma^2 = 1 \), 某次采样可能得到 \( z = 0.5 \), 下一次可能是 \( z = -1.2 \). 这种输出值的跳跃性使得我们无法定义一个稳定的 \( \frac{\partial z}{\partial \mu} \)\( \frac{\partial z}{\partial \sigma} \).

为什么不可导

在神经网络中, 假设有一个损失函数 \( L \) 依赖于采样值 \( z \), 即 \( L = L(z) \). 为了通过反向传播优化参数 \( \mu \)\( \sigma \), 我们需要计算梯度 \( \frac{\partial L}{\partial \mu} \)\( \frac{\partial L}{\partial \sigma} \). 根据链式法则, 这需要知道 \( \frac{\partial z}{\partial \mu} \)\( \frac{\partial z}{\partial \sigma} \). 然而, 由于 \( z \) 是随机采样的结果, 它的变化不随 \( \mu \)\( \sigma \) 平滑变化, 因此这些偏导数不存在, 或者说不是确定的值.

为了更直观地理解:

  • 如果分布是离散的(例如二项分布), 采样值会在离散点之间跳跃, 显然不连续, 也不可导.
  • 即使是连续分布(如正态分布), 采样过程的随机性仍然会导致 \( z \) 的值没有确定的变化规律, 无法计算导数.

简单来说, 采样的输出 \( z \) 对输入参数 \( \mu \)\( \sigma \) 的依赖关系不是一个平滑的函数, 而是依赖于随机性, 这违反了可导性的要求.

如何解决这个问题

尽管直接采样不可导, 但在神经网络中, 我们可以通过重参数化技巧绕过这个问题. 以正态分布为例, 我们可以将采样过程改写为:

\[ z = \mu + \sigma \epsilon \]

其中, \( \epsilon \sim \mathcal{N}(0, 1) \) 是从标准正态分布中采样的随机噪声, 但我们将其视为一个固定的值(在每次前向传播中采样一次后保持不变). 这样, \( z \) 就变成了 \( \mu \)\( \sigma \) 的一个确定函数, 其导数可以计算为:

\[ \frac{\partial z}{\partial \mu} = 1 \]
\[ \frac{\partial z}{\partial \sigma} = \epsilon \]

有了这些导数, 梯度就可以通过链式法则传播, 从而实现反向传播:

\[ \frac{\partial L}{\partial \mu} = \frac{\partial L}{\partial z} \cdot 1 \]
\[ \frac{\partial L}{\partial \sigma} = \frac{\partial L}{\partial z} \cdot \epsilon \]

这种方法将随机性从梯度计算中剥离, 使得整个过程可导.

Gumbel-softmax

离散采样的不可微性问题

假设我们有一个离散分布, 比如通过 softmax 或 log-softmax 得到的类别概率 \(\pi_i\). 在某些任务中(例如强化学习, 生成模型或变分自编码器), 我们需要从这个分布中采样一个离散值, 比如选择一个类别 \(y\).

这就会导致一个问题: 直接采样(例如 \(y \sim \text{Categorical}(\pi)\))是一个离散的随机过程, 无法定义梯度. 这意味着反向传播无法通过采样步骤传播梯度, 导致模型无法端到端地训练.

Gumbel-softmax的作用

Gumbel-softmax 提供了一种解决方案: 它通过引入 Gumbel 噪声, 将离散采样过程转化为一个连续的, 可微的近似. 这样, 网络可以在训练时使用这个连续近似来计算梯度, 而在推理时仍可以使用离散采样.

核心思想是在 log-softmax 的对数概率基础上加入 Gumbel 噪声, 生成一个"软化"的概率分布, 这个分布既接近离散采样的结果, 又是可微的.

Gumbel-softmax的工作原理

具体来说, Gumbel-softmax 的操作步骤如下:

  1. 从 log-softmax 获取对数概率:

    假设网络输出了一组 logits \(\mathbf{z}\), 通过 log-softmax 得到每个类别的对数概率 \(\log \pi_i\).

  2. 加入 Gumbel 噪声:

    从 Gumbel(0,1) 分布中采样独立噪声 \(g_i\), 将其加到对数概率上, 得到"扰动后的 logits":

    \[ \tilde{z}_i = \log \pi_i + g_i \]
  3. 应用 softmax 变换:

    对扰动后的 logits \(\tilde{z}_i\) 应用 softmax 函数, 生成一个连续的概率分布 \(\tilde{y}\). 公式如下:

    \[ \tilde{y}_i = \frac{\exp(\tilde{z}_i / \tau)}{\sum_{j} \exp(\tilde{z}_j / \tau)} \]

    其中 \(\tau\) 是温度参数, 用于控制输出的"软硬"程度:

    • \(\tau\) 很小时(\(\tau \to 0\)), \(\tilde{y}\) 趋近于 one-hot 向量, 接近离散采样的结果.
    • \(\tau\) 较大时, \(\tilde{y}\) 更平滑, 更像普通的 softmax 输出.

通过这个过程, Gumbel-softmax 生成了一个可微的连续输出 \(\tilde{y}\), 它既能近似离散采样的行为, 又允许梯度通过这个近似传播, 从而优化模型参数.


为什么要在 log-softmax 后面加 Gumbel-softmax?

  • 解决不可微性:
    Log-softmax 本身是可微的, 可以直接用于连续概率输出的任务(例如分类). 但如果任务需要从离散分布中采样, 单纯的 log-softmax 不够, 因为采样步骤不可微. Gumbel-softmax 通过引入噪声和 softmax 变换, 提供了一个可微的替代方案.
  • 支持端到端训练:
    在需要离散随机变量的场景中(比如生成离散序列或选择离散动作), Gumbel-softmax 使得整个模型可以通过梯度下降进行端到端的优化.
  • 灵活性:
    在训练时, 可以使用 Gumbel-softmax 的连续近似; 在推理时, 可以切换到真正的离散采样. 这种灵活性非常适合需要在训练和推理之间平衡的任务.

温度参数 \(\tau\) 的作用

温度参数 \(\tau\) 是 Gumbel-softmax 的一个关键控制因素:

  • 小的 \(\tau\): 输出更接近离散采样(one-hot 向量), 但可能导致梯度稀疏, 训练不稳定.
  • 大的 \(\tau\): 输出更平滑, 梯度更稳定, 但偏离离散采样的目标.

在实践中, 通常从较大的 \(\tau\) 开始, 随着训练进行逐渐减小 \(\tau\), 以在探索和收敛之间取得平衡.