重参数化技巧
反向传播和可导性的基础¶
反向传播是神经网络训练中用来优化参数的核心算法. 它通过计算损失函数对网络参数的梯度, 利用链式法则将梯度从输出层逐步传递到输入层. 要实现这一点, 网络中的每一步操作必须是可导的, 也就是说, 每个操作都可以定义一个明确的导数. 如果某个操作不可导, 梯度就无法通过这个操作传播, 反向传播就会中断. 在数学上, 一个函数\(f(x)\)是可导的, 前提是对于输入\(x\)的微小变化, 输出\(f(x)\)的变化是平滑且确定的. 然而, 如果某个操作的结果是随机的或不连续的, 那么它的导数就无法被定义.
采样过程的随机性¶
现在来看从一个分布中采样的过程. 以正态分布为例, 采样可以表示为:
其中, \( \mu \) 是均值, \( \sigma^2 \) 是方差, \( z \) 是采样的结果. 关键在于, 这个过程是随机的. 即使 \( \mu \) 和 \( \sigma^2 \) 是固定的, 每次采样得到的 \( z \) 值都会不同. 这种随机性意味着采样操作不是一个确定的函数, 而是一个概率过程.
例如, 假设 \( \mu = 0 \) 且 \( \sigma^2 = 1 \), 某次采样可能得到 \( z = 0.5 \), 下一次可能是 \( z = -1.2 \). 这种输出值的跳跃性使得我们无法定义一个稳定的 \( \frac{\partial z}{\partial \mu} \) 或 \( \frac{\partial z}{\partial \sigma} \).
为什么不可导¶
在神经网络中, 假设有一个损失函数 \( L \) 依赖于采样值 \( z \), 即 \( L = L(z) \). 为了通过反向传播优化参数 \( \mu \) 和 \( \sigma \), 我们需要计算梯度 \( \frac{\partial L}{\partial \mu} \) 和 \( \frac{\partial L}{\partial \sigma} \). 根据链式法则, 这需要知道 \( \frac{\partial z}{\partial \mu} \) 和 \( \frac{\partial z}{\partial \sigma} \). 然而, 由于 \( z \) 是随机采样的结果, 它的变化不随 \( \mu \) 或 \( \sigma \) 平滑变化, 因此这些偏导数不存在, 或者说不是确定的值.
为了更直观地理解:
- 如果分布是离散的(例如二项分布), 采样值会在离散点之间跳跃, 显然不连续, 也不可导.
- 即使是连续分布(如正态分布), 采样过程的随机性仍然会导致 \( z \) 的值没有确定的变化规律, 无法计算导数.
简单来说, 采样的输出 \( z \) 对输入参数 \( \mu \) 和 \( \sigma \) 的依赖关系不是一个平滑的函数, 而是依赖于随机性, 这违反了可导性的要求.
如何解决这个问题¶
尽管直接采样不可导, 但在神经网络中, 我们可以通过重参数化技巧绕过这个问题. 以正态分布为例, 我们可以将采样过程改写为:
其中, \( \epsilon \sim \mathcal{N}(0, 1) \) 是从标准正态分布中采样的随机噪声, 但我们将其视为一个固定的值(在每次前向传播中采样一次后保持不变). 这样, \( z \) 就变成了 \( \mu \) 和 \( \sigma \) 的一个确定函数, 其导数可以计算为:
有了这些导数, 梯度就可以通过链式法则传播, 从而实现反向传播:
这种方法将随机性从梯度计算中剥离, 使得整个过程可导.
Gumbel-softmax¶
离散采样的不可微性问题¶
假设我们有一个离散分布, 比如通过 softmax 或 log-softmax 得到的类别概率 \(\pi_i\). 在某些任务中(例如强化学习, 生成模型或变分自编码器), 我们需要从这个分布中采样一个离散值, 比如选择一个类别 \(y\).
这就会导致一个问题: 直接采样(例如 \(y \sim \text{Categorical}(\pi)\))是一个离散的随机过程, 无法定义梯度. 这意味着反向传播无法通过采样步骤传播梯度, 导致模型无法端到端地训练.
Gumbel-softmax的作用¶
Gumbel-softmax 提供了一种解决方案: 它通过引入 Gumbel 噪声, 将离散采样过程转化为一个连续的, 可微的近似. 这样, 网络可以在训练时使用这个连续近似来计算梯度, 而在推理时仍可以使用离散采样.
核心思想是在 log-softmax 的对数概率基础上加入 Gumbel 噪声, 生成一个"软化"的概率分布, 这个分布既接近离散采样的结果, 又是可微的.
Gumbel-softmax的工作原理¶
具体来说, Gumbel-softmax 的操作步骤如下:
-
从 log-softmax 获取对数概率:
假设网络输出了一组 logits \(\mathbf{z}\), 通过 log-softmax 得到每个类别的对数概率 \(\log \pi_i\).
-
加入 Gumbel 噪声:
从 Gumbel(0,1) 分布中采样独立噪声 \(g_i\), 将其加到对数概率上, 得到"扰动后的 logits":
\[ \tilde{z}_i = \log \pi_i + g_i \] -
应用 softmax 变换:
对扰动后的 logits \(\tilde{z}_i\) 应用 softmax 函数, 生成一个连续的概率分布 \(\tilde{y}\). 公式如下:
\[ \tilde{y}_i = \frac{\exp(\tilde{z}_i / \tau)}{\sum_{j} \exp(\tilde{z}_j / \tau)} \]其中 \(\tau\) 是温度参数, 用于控制输出的"软硬"程度:
- 当 \(\tau\) 很小时(\(\tau \to 0\)), \(\tilde{y}\) 趋近于 one-hot 向量, 接近离散采样的结果.
- 当 \(\tau\) 较大时, \(\tilde{y}\) 更平滑, 更像普通的 softmax 输出.
通过这个过程, Gumbel-softmax 生成了一个可微的连续输出 \(\tilde{y}\), 它既能近似离散采样的行为, 又允许梯度通过这个近似传播, 从而优化模型参数.
为什么要在 log-softmax 后面加 Gumbel-softmax?¶
- 解决不可微性:
Log-softmax 本身是可微的, 可以直接用于连续概率输出的任务(例如分类). 但如果任务需要从离散分布中采样, 单纯的 log-softmax 不够, 因为采样步骤不可微. Gumbel-softmax 通过引入噪声和 softmax 变换, 提供了一个可微的替代方案. - 支持端到端训练:
在需要离散随机变量的场景中(比如生成离散序列或选择离散动作), Gumbel-softmax 使得整个模型可以通过梯度下降进行端到端的优化. - 灵活性:
在训练时, 可以使用 Gumbel-softmax 的连续近似; 在推理时, 可以切换到真正的离散采样. 这种灵活性非常适合需要在训练和推理之间平衡的任务.
温度参数 \(\tau\) 的作用¶
温度参数 \(\tau\) 是 Gumbel-softmax 的一个关键控制因素:
- 小的 \(\tau\): 输出更接近离散采样(one-hot 向量), 但可能导致梯度稀疏, 训练不稳定.
- 大的 \(\tau\): 输出更平滑, 梯度更稳定, 但偏离离散采样的目标.
在实践中, 通常从较大的 \(\tau\) 开始, 随着训练进行逐渐减小 \(\tau\), 以在探索和收敛之间取得平衡.