笔记|扩散模型(一八):Flow Matching 理论详解
在 Stable Diffusion 3 中,模型是通过 Flow Matching 的方法训练的。从这个方法的名字来看,就知道它和 Flow-based Model 有比较强的关联,因此在正式开始介绍这个方法之前先交代一些 Flow-based Model 相关的背景知识。
Flow-based Models
Normalizing Flow
Normalizing Flow 是一种基于变换对概率分布进行建模的模型,其通过一系列离散且可逆的变换实现任意分布与先验分布(例如标准高斯分布)之间的相互转换。在 Normalizing Flow 训练完成后,就可以直接从高斯分布中进行采样,并通过逆变换得到原始分布中的样本,实现生成的过程。(有关 Normalizing Flow 的详细理论介绍可以移步我的这篇文章观看)
从这个角度看,Normalizing Flow 和 Diffusion Model 是有一些相通的,其做法的对比如下表所示。从表中可以看到,两者大致的过程是非常类似的,尽管依然有些地方不一样,但这两者应该可以通过一定的方法得到一个比较统一的表示。
模型 | 前向过程 | 反向过程 |
---|---|---|
Normalizing Flow | 通过显式的可学习变换将样本分布变换为标准高斯分布 | 从标准高斯分布采样,并通过上述变换的逆变换得到生成的样本 |
Diffusion Model | 通过不可学习的 schedule 对样本进行加噪,多次加噪变换为标准高斯分布 | 从标准高斯分布采样,通过模型隐式地学习反向过程的噪声,去噪得到生成样本 |
Continuous Normalizing Flow
Continuous Normalizing Flow(CNF),也就是连续标准化流,可以看作
Normalizing Flow 的一般形式。CNF 将原本 Normalizing Flow
中离散的变换替换为连续的变换,并用常微分方程(ODE)来表示,可以写成以下的形式:
在 Normalizing Flow 中存在 Change of Variable
Theory,这个定理是用来保证概率分布在进行变化时,概率密度在全体分布上的积分始终为
1 的一个式子(具体解释可以看上边给出的那篇 Normalizing Flow
的文章),其形式为:
连续性方程
概率分布在向量场中进行变换这一过程可以用物理学中的传输行为来建模。这是因为不管概率分布如何变换,其在全体分布上的积分始终为
1,因此可以认为概率密度也是一个守恒的物理量,可以类比物理学中的质量、电荷等的传输行为进行建模。这个建模方式就是连续性方程,其在物理学中定义如下:
类比到概率分布,这个方程可以写成:
在讲解 Score-based Model 的文章中,我们用随机微分方程(SDE)统一了 SMLD(Score Matching with Langevin Dynamics)和 DDPM,并且将 SDE 转化为了 ODE 概率流。也就是说,扩散模型同样能够用一个 ODE 来表示,因此,扩散模型也应当能够利用 CNF 的训练方式进行训练,这个训练的方式就是 Flow Matching。
Flow Matching
符号定义
在正式开始介绍之前我们先介绍一下各个概念以及符号定义。借用一下之前介绍 SDE 时的一张图,如下所示。在 Flow Matching 中存在以下几个概念:
- 数据分布
、 、 :这个不用多解释,不过需要注意下标定义为 0 是标准高斯分布、1 是样本,这个定义和 DDPM 是相反的,需要注意 - Flow
或 :这个也就是对分布进行变换的操作,例如 - 向量场
:这个相当于下图中的青色箭头,样本沿着箭头的方向传输 - 概率路径
:这个相当于下图中的浅绿色曲线
在实际上进行训练时,神经网络建模的是向量场
概述
Flow Matching 的训练目标和 Score Matching
是比较类似的,学习的目标就是通过学习拟合一个向量场
不过实际上这个公式并不实用,首先能够满足
从条件概率路径和向量场构造
虽然我们不知道
定理一:给定向量场
证明:首先,对于边缘概率路径
Conditional Flow Matching
虽然基于上述过程已经推导出了
定理二:假定对于所有
证明:首先把两个二次项都展开,然后证明右侧是相等的。注意,虽然右侧都有
如此即证明了上述的定理。这样,我们的训练就不再依赖于一个抽象的边缘向量场,而是依赖于
条件概率路径和向量场
上面我们已经证明了条件概率路径和条件向量场可以等价于边缘概率路径和边缘向量场,并且用
CFM 的方式进行训练和 Flow Matching 的效果是相同的。但现在
作者给出的条件概率路径的形式为:
同时,作者将 flow 定义为以下形式:
定理三:令
讨论
Flow Matching 定义了一种特定形式的高斯概率路径,当选择不同的均值和方差时,有几种特殊的情况:
- Variance Exploding:
,其中 、 ,并且 是递增函数, 、 。这种过程能够使模型生成数据时探索范围更广的空间,有助于生成多样的样本。 - Variance Preserving:
,其中 、 。这种过程在引入噪声的同时保持整体方差不变,这样能使数据的分布比较稳定。(可以看出 DDPM 就是这种过程) - Optimal Transport Conditional: 定义均值和方差为
、 。可以求得最优传输路径是直线,因此可以更快地训练和采样。(这个比较类似于 Rectified Flow)
总结
Flow Matching 的确理论性比较强,不是特别好理解。概括来说主要是给出了一种用来训练 CNF 的方法,并且提出了三个定理分别用来解决 flow 的表示问题、loss 函数的设计问题以及具体实现方式的问题。同时 flow matching 也统一了 score matching 和 DDPM,非常巧妙。(学到这里终于快要把 stable diffusion 3 的拼图拼完了,真不容易)
参考资料: