扩散模型的步数蒸馏,指的是在教师模型的指导下,学生模型学会用更少的步数(对应inference steps)生成相似质量的图片。
渐进式蒸馏 Progressive Distillation
渐进式蒸馏方法可以说是最典型的一个步数蒸馏的方法了。
想象你要教一个新手画家(学生模型)快速画画。
原本的画法是:老师(扩散模型)需要画100笔,每一笔都慢慢修正细节(对应扩散模型的100步去噪)。但新手没耐心画100笔,想几笔搞定。这时候就需要“蒸馏”老师的技巧,让新手学会用更少的步骤画出差不多的效果。
怎么教呢?渐进式蒸馏使用了跳步学习的思想:
- 老师先按老方法画完100笔,但记录下关键中间步骤(比如每隔5笔记录一次)。
- 然后告诉新手:“别一步一步画了,你直接从第0笔跳到第5笔,再跳到第10笔,跳过这些中间步骤”。
- 新手练习时,就要模仿老师跳多步后的结果(比如一笔顶老师五笔的效果),这就完成了单次的步数蒸馏(100步->20步)
- 这时候这个20步模型作为教师模型,再去教下一个学生用更少的步数画出相同的结果,反复练几次,新手就能用很少很少的次数画出老师100笔的效果了。
这个渐进压缩步数的过程就是渐进式蒸馏的核心。
为什么需要渐进蒸馏:
1. 直接学习很难训练,容易模式崩塌。
2. 逐步蒸馏避免误差跳跃过大,积累误差。
分数蒸馏 Score Distillation
直接的渐进式蒸馏技术,在压缩后几步的时候效果会急剧下降,于是DMD提出了分数蒸馏的改进方法,通过最小化生成分布与真实分布之间的KL散度,确保生成图像与原始扩散模型输出在分布层面一致,从而使得模型画出来的结果也和原始模型一样好。
提到KL散度大家可能会联想到GAN,这篇文章也提到,对抗蒸馏(下文会解释)的方法一般是引入判别器,区分教师和学生的生成成果,通过对抗loss迫使学生欺骗判别器。作者认为: “对抗训练需要复杂的平衡,且容易导致模式崩塌(mode collapse),而分布匹配通过显式的最小化KL散度,能更稳定地实现一步生成。
对抗训练 Adversarial Training
对抗训练通常是通过构建一个生成对抗网络(GAN)的架构,其中学生模型作为生成器(Generator,通常用教师模型进行初始化),负责生成样本;另外引入一个判别器(Discriminator),用于区分生成的样本是来自学生模型还是教师模型,从而让学生模型的分布接近教师模型的分布。SDXL-Turbo采用的蒸馏方案就是Adversarial Diffusion Distillation(ADD)。
由于引入对抗机制,GAN方法通常生成质量都会比较接近教师模型,但正如前文所说,GAN面临着难以训练,且容易模式崩塌的问题。并且SDXL-Turbo采用的D是传统的图片编码backbone(DINOv2),不支持latent输入,限制了更大分辨率的图片生成,并且只能在t=0(也就是干净去噪的图片)上使用,无法兼容渐进式的蒸馏方法。
SDXL-Lightning结合了对抗蒸馏和渐进式蒸馏,采用和G一样的网络结构的D(都是pre-trained Diffusion Unet)来支持对t的输入,先直接把模型从 128 步直接蒸馏到 32 步,然后按照按32->8->4 ->2 ->1的顺序,增加对抗损失进行渐进式蒸馏。
饺子包完了下一篇就可以蘸醋吃了大家再等等!