VQ-VAE1¶
离散空间¶
VAE和VQ-VAE的根本区别在于VAE学习连续的潜在表示, 而VQ-VAE学习离散的潜在表示. 一般来说, 我们在现实世界中遇到的许多数据都倾向于离散表示. 例如, 人类语音可以由离散的因素和语言很好地表示. 此外, 图像还包含具有一些离散限定字符集的离散对象. 可以想象用一个离散变量表示对象类型, 一个用于表示其颜色, 一个用于表示其大小, 一个用于表示其方向, 一个用于表示其形状, 一个用于表示其纹理, 一个用于表示背景颜色, 一个用于表示背景纹理, 等等.. 除了表示之外, 还有许多算法, 例如transformer旨在处理离散数据, 因此我们希望有一个离散的数据表示提供这些算法使用.
我们如何学习离散表示呢? 乍一看, 这似乎非常具有挑战性, 因为一半来说, 离散的东西在深度学习中并不太好做. 幸运的是, VQ-VAE设法使深度学习为这项任务工作, 只需要对原始的自编码器做一些调整.
量化自编码器¶
VQ-VAE通过向网络添加离散的codebook组件来扩展标准自编码器. codebook是与之相应索引关联的向量列表. 将编码器网络的输出和codebook的所有向量进行比较, 并将欧氏距离最近的codebook向量喂给解码器. 数学上可以写成\(z_q(x)=\argmin _i||z_e(x)-e_i||_2\), 其中, \(z_e(x)\)是原始输入的encoder向量, 比如\(i\)表示第\(i\)个codebook向量, \(z_q(x)\)表示生成的量化矢量, 作为输入传递给解码器. 这个\(\argmin\)的操作有点让人担忧, 因为它无法传递梯度. 为了解决这个问题, 在反向传播的时候, 会对那个被选中的codebook向量直接设定为\(1\), 即\(\frac{\partial z_q(x)}{z_e(x)}\simeq 1\), 而其他的codebook中的向量其梯度为\(0\), 即\(\frac{\partial z_q(x)}{z_{\neq e(x)}}=0\). 然后, 解码器的任务就是重构来自量化矢量的输入, 就像在标准的VAE公式中的那样.
矢量网格¶
可能感到困惑的是, 当解码器只能接受一组codebook向量作为输入的时候, 人们怎么能指望它产生大量多样化的图像呢? 我们需要为美俄训练点提供一个唯一的离散值, 以便能够重建所有的数据. 如果情况确实如此, 那么模型难道不会通过将每个训练点映射到不同的离散的code来记住数据吗?
如果编码器只输出一个矢量, 这的确会成为问题, 但是在实际的VQ-VAE中, 编码器通常会产生一系列矢量. 例如, 对于图像, 编码器可能会输出一个32*32的矢量网格, 网格中每个位置都有一个向量要被量化. 虽然所有这些向量都使用同一个codebook, 但是由于网格的尺寸很大, 每个位置都可以独立选择它在codebook中对应的向量, 最终会得到一个庞大的组合空间. 例如, 假设我们正在处理图像, 我们有一个尺寸为512的码本, 我们的编码器输出一个32*32的矢量网格. 那么, 可以输出\(512^{32\times 32}\)个不同的图像. 可以看到, 这个数字是非常庞大的.
当然, 模型仍然可以记住数据, 但是通过在编码器中嵌入正确的归纳偏置(卷积网络的那套平移不变性...)和使用合理的因空间结构, 如上面的创建一个矢量网格, 模型应该能够学习到一个很好地表示数据的离散空间.
学习码本¶
就像编码器和解码器网络一样, codebook是通过梯度下降来学习的. 理想情况下, 我们的编码器将输入一个接近学习到的codebook向量. 这个码表并不是固定不变的, 也像网络中的其他部分一样, 可以通过训练不断更新其中的向量. 这是一个双向学习的过程: 1) codebook的更新: codebook需要学习一组向量, 使他们能够覆盖或者代表encoder输出的各种可能情况; 2) encoder也在学习产生出更容易被codebook表征的向量. 这样就形成了一个相互依赖, 共同演化的过程: 如果codebook改变, encoder就要跟着做出相应的调整; 如果encoder改变, codebook也要跟着微调. 这两个问题可以通过向损失函数添加项来解决. 整个VQ-VAE的损失函数是:
在这里, 我们使用的是和上一小节相同的符号, \(\text{sg}[x]\)代表"停止梯度". 第一项是标准的重构损失, 第二项是codebook对齐损失, 其目标是使得所选的codebook矢量尽可能接近编码器的输出, 编码器输出有一个停止梯度运算符, 因为这项仅用于更新codebook. 第三项和第二项类似, 但是它将停止梯度放在codebook向量上, 因为它旨在更新编码器输出, 让其尽可能接近codebook向量. 这项称为codebook损失, 其对总体损失的重要性由超参数\(\beta\)调整. 当然, 如果有多个, 例如刚才的那个矢量网格, 则最后两项在模型的每个量化向量的输出上取平均值.
这样, 我们就可以完整地训练一个VQ-VAE, 能够重组一组不同的图像. 我们还可以训练VQ-VAE来重构其他模态, 如音频或者视频.
-
深入理解 VQ-VAE. (2022, June 2). Sunlin-ai. https://sunlin-ai.github.io/2022/06/02/VQ-VAE.html ↩