On the Scalability of Diffusion-based Text-to-Image Generation
公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
目录
0. 摘要
3. 扩展 Denoising 骨干
3.1. 现有的 UNet 设计
3.2. UNet 的受控对比
3.3. UNet 设计消融
3.4. 与 Transformer 的比较
4. 数据集的扩展
4.1. 数据集策划
4.2. 数据清洗
4.3. 通过合成标题扩展知识
4.4. 数据扩展提高了训练效率
5. 更多扩展性质
6. 结论
扩大模型和数据规模对 LLM 的发展非常成功。然而,扩散型文本到图像(T2I)模型的扩展规律尚未完全探索。如何有效地扩展模型以在降低成本的情况下获得更好的性能尚不清楚。不同的训练设置和昂贵的训练成本使得公平的模型比较极为困难。在这项工作中,我们通过对去噪骨干和训练集进行广泛和严格的消融实验,包括在数据集上训练缩放 UNet 和 Transformer 变体,范围从 0.4B 到 4B 参数,并涵盖高达 600M 图像。对于模型的扩展,我们发现交叉注意力的位置和数量区分了现有 UNet 设计的性能。增加 Transformer 块对于改善文本-图像对齐比增加通道数量更具参数效率。然后,我们确定了一种高效的 UNet 变体,比 SDXL 的 UNet 小 45%,速度快 28%。在数据扩展方面,我们表明训练集的质量和多样性比简单的数据集大小更重要。增加标题密度和多样性可以提高文本-图像对齐性能和学习效率。最后,我们提供了用于预测文本-图像对齐性能的扩展函数,作为模型大小、计算和数据集大小的函数。
图 2 给出了 SD2 和 SDXL 的 UNet 的比较。 SDXL 在多个方面改进了 SD2:
- 更少的下采样率。SD2 使用(1, 2, 4, 4)作为倍率,以增加不同下采样级别的通道数。DeepFloyd 采用(1, 2, 3, 4)来减少计算量,而 SDXL 使用(1, 2, 4),完全移除了第四个下采样级别。
- 只在较低分辨率下进行交叉关注。交叉关注仅在特定的下采样率下计算,例如,SD2 将交叉关注应用于下采样率(1×、2×、4×),而 SDXL 仅在 2× 和 4× 下采样级别集成文本嵌入。
- 在较低分辨率下进行更多的计算。SDXL 在 2× 和 4× 下采样级别应用更多的 Transformer 块,而 SD2 在所有三个下采样级别应用统一的单一 Transformer 块。
训练。我们在我们的策划数据集 LensArt上 训练模型,该数据集包含 250M 个文本-图像对(详见第 4 节)。我们使用 SDXL 的 VAE 和 OpenCLIP-H [20] 文本编码器(1024 维),没有添加额外的嵌入层或其他条件。我们以 256×256 分辨率训练所有模型,批量大小为 2048,最多 600K 步。我们遵循 LDM [34] 的 DDPM 调度设置。我们使用 AdamW [27] 优化器进行 10K 步热身,然后学习率保持 8e-5 不变。我们采用 BF16 进行混合精度训练,并为大型模型启用 FSDP。
推断和评估。我们在推断中使用 DDIM 采样器 [37] 在 50 个步骤中固定种子和 CFG 比例(7.5)。为了了解训练动态,我们在训练期间监控五个指标的演变。我们发现训练早期的指标可以帮助预测最终模型的性能。具体来说,我们使用以下指标来衡量构图能力和图像质量:
- TIFA [19],通过视觉问答(VQA)衡量生成图像对其文本输入的忠实度。它包含由语言模型生成的 4K 个收集提示和相应的问答对。通过检查现有的 VQA 模型是否可以使用生成的图像回答这些问题来计算图像忠实度。TIFA 允许对生成的图像进行细粒度和可解释的评估。
- ImageReward [40],用于近似人类偏好。我们计算在 MSCOCO-10K 提示下生成的图像的平均 ImageReward 分数。尽管 ImageReward 不是一个归一化分数,但其分数在 [-2, 2] 的范围内,对图像的平均评分提供了有意义的统计信息,以便允许跨模型进行比较。
- 由于空间限制,我们主要展示 TIFA 和 ImageReward,并在附录中提供其他指标(CLIP分数[14, 32],FID,HPSv2 [39])的结果。
SDXL vs SD2 vs IF-XL。我们在上述受控设置中比较了 SDXL [31]、DeepFloyd-IF [9]、SD2 [34] 及其扩展版本的几种现有 UNet 模型的设计。具体来说,我们比较了a)SD2 UNet(0.9B)b)具有 512 个初始通道的 SD2 UNet(2.2B)c)SDXL 的 UNet(2.4B)d)具有 512 通道的 2.0B 的 DeepFloyd 的 IF-XL UNet。
- 图 3 显示了朴素扩展的 SD2-UNet(C512,2.2B)在相同的训练步骤下比基础 SD2 模型取得更好的 TIFA 分数。然而,就训练 FLOPs 而言,收敛速度较慢,这表明增加通道数是一种有效但不是高效的方法。
- SDXL 的 UNet 在 150K 步内实现了 0.82 的 TIFA,比 SD2 UNet 快 6 倍,比 SD2-C512 快 3 倍。尽管其训练迭代速度(FLOPS)比 SD2 慢 2 倍,但它仍以 2 倍的降低训练成本实现了相同的 TIFA 分数。
- SDXL UNet 还可以获得比其他模型高得多的 TIFA 分数(0.84)。
- 因此,SDXL 的 UNet 设计在性能和训练效率方面明显优于其他模型,推动了帕累托前沿。
现在我们已经验证了 SDXL 比 SD2 和 DeepFloyd 变体具有更好的 UNet 设计。问题是它为什么表现出色,以及如何有效地进一步改进它。在这里,我们通过探索其设计空间来研究如何改进 SDXL 的 UNet。
搜索空间。表 1 显示了不同的 UNet 配置及其在 256 分辨率下的计算复杂度。我们主要变化初始通道和 transformer 深度。为了理解设计空间的每个维度的影响,我们选择了一些变体模型并使用相同的配置对它们进行训练。这构成了我们 UNet 架构的主要 “搜索空间”。关于 VAE、训练迭代次数和 batch 大小的更多消融可以在附录中找到。
初始通道的影响。我们用不同的通道数量训练以下 SDXL UNet 变体:128、192 和 384,对应的参数分别为 0.4B、0.9B 和 3.4B。
- 图 4(a)显示,将通道数从 320 减少到 128 的 SDXL UNet 仍然可以优于带有 320 个通道的 SD2 UNet,这表明可以通过合适的架构设计,实现更少的通道数有更好的质量。然而,与 320 个通道的 SDXL UNet 相比,TIFA(也称为 ImageReward/CLIP)分数较差,这表明了通道在视觉质量中的重要性。
- 将 SDXL UNet 通道数从 320 增加到 384,参数数量从 2.4B 增加到 3.4B,它在 600K 训练步骤时也比基线 320 个通道获得更好的指标。
- 注意,初始通道数 C 实际上与 UNet 的其他超参数相关联,例如,1)时间步长嵌入 T 的维度是 4C;2)注意力头的数量与通道数成线性关系,即 C/4。如表 1 所示,当 C 变化时,注意力层的计算比例保持稳定(64%)。这解释了为什么增加 UNet 的宽度也会带来对齐改进,如图 4 所示。
Transformer Depth 的影响。Transformer Depth(TD)设置控制了特定输入分辨率下的 Transformer 块数量。SDXL 在 2× 和 4× 下采样级别分别应用了 2 个和 10 个 Transformer 块。为了理解其影响,我们使用不同的 TD 训练了表 1 中显示的变体,参数范围从 0.9B 到 3.2B。具体地,我们首先在 4× 下采样率上改变 TD,得到 TD2、TD4、TD12 和 TD14,然后我们进一步在 2× 下采样率上改变深度,得到 TD4_4、TD4_8 和 TD4_12。注意,随着 TD 的增加,注意力操作的部分也相应增加。
- 图4(b)显示,将 4× 下采样率上的 TD 从 2 增加到 14 会持续提高 TIFA 分数。
- 从 TD4 和 TD4_4 的比较中,我们可以看到在 2× 分辨率下增加 Transformer 深度(2 → 4)也会提高 TIFA 分数。
- TD4_4 在与 SDXL 的 UNet 相比具有竞争性能的同时,参数减少了 45%,推理计算量减少了 28%。
- 在附录中,我们展示了相对于 SDXL UNet,TD4_4 在墙钟(wall-clock)训练时间方面以 1.7 倍的速度实现了相同的 TIFA 分数。TD4_8 几乎与 SDXL 的 UNet 具有相同的性能,但参数减少了 13%。
- 由于文本-图像对齐(TIFA)主要涉及图像中的大对象,因此在效率考虑之外,将更多的交叉计算分配给较低分辨率或全局图像级别是有帮助的。
同时扩展通道和 Transformer 深度。鉴于通道和 Transformer 深度的影响,我们进一步探索了扩大通道数量(从 320 增加到 384)和 Transformer 深度([0,2,10] → [0,4,12])的效果。图 4(c)显示,在训练过程中,它的 TIFA 分数略高于 SDXL-UNet。然而,与仅增加通道或 Transformer 深度相比的优势并不明显,这意味着在诸如 TIFA 之类的指标下,模型继续扩展的性能存在限制。
可视化 UNet 扩展效果。图 5 显示了使用相同提示生成的不同 UNet 生成的图像。我们可以看到,随着通道数或 Transformer 深度的增加,图像与给定的提示(例如,颜色、计数、空间、对象)更加对齐。某些 UNet 变体生成的图像比原始的 SDXL UNet(C320)更好,即,SDXL-C384 和 SDXL-TD4_8 都以更准确的方式生成第四个提示的图像。
DiT [30] 表明,增加 Transformer 的复杂度可以在 ImageNet 上实现类别条件图像生成的一致性改进图像保真度。PixArt-α [5] 将 DiT 扩展到具有类似骨干结构的文本条件图像生成。然而,在受控设置中与 UNet 进行公平比较还存在不足。为了与 UNet 进行比较并了解其扩展性,我们训练了多个缩放版本的 PixArt-α,保持其他组件和设置与之前的消融相同。表 2 显示了我们缩放变体的配置。与原始 PixArt-α 模型的区别在于:1)我们使用 SDXL 的 VAE 代替 SD2 的 VAE;2)我们使用 OpenCLIP-H 文本编码器代替 T5-XXL [7],token 嵌入维度从 4096 减少到 1024,token 长度为 77 而不是 120。
消融空间。我们在以下维度上对 PixArt-α 模型进行了消融:
- 隐藏维度 h:PixArt-α 继承了 DiT-XL/2 [30] 的设计,具有 1152 维度。我们还考虑了 1024 和 1536。
- Transformer 深度 d:我们将 Transformer 深度从 28 扩展到 56。
- 标题(caption)嵌入:标题嵌入层将文本编码器的输出映射到维度h。当隐藏维度与文本嵌入相同时(即 1024),我们可以跳过标题嵌入直接使用 token 嵌入。
模型扩展的效果。如图 6 所示,扩展隐藏维度 h 和模型深度 d 都会导致文本-图像对齐和图像保真度的提高,而扩展深度 d 会线性改变模型的计算量和大小。d56 和 h1536 变体都以与基线 d28 模型相似的参数大小和计算量实现了约 1.5 倍更快的收敛速度。
与 UNet 的比较。相对于在相同步骤中训练的 SD2-UNet,PixArt-α 变体的 TIFA 和 ImageReward 分数较低,例如,SD2 UNet 在 250K 步时达到 0.80 TIFA 和 0.2 ImageReward,而 0.9B PixArt-α变体达到 0.78 和 0.1。PixArt-α [5] 还报告说,训练过程中没有使用 ImageNet 预训练会导致生成的图像与使用预训练 DiT 权重初始化的模型相比出现失真,后者在 ImageNet 上训练了 7M 步 [30]。尽管 DiT [30] 证明了 UNet 并不是扩散模型的必需品,但 PixArt-α 变体需要更长的迭代次数和更多的计算才能达到与 UNet 相似的性能。我们将此改进留待未来工作,并期待架构改进能够缓解这个问题,例如 [11, 12, 41] 中所做的工作。
我们策划了名为 LensArt 和 SSTK 的数据集。
- LensArt 是从 10 亿个有噪的网络图像文本对中获取的 2.5 亿个图像文本对。我们应用了一系列自动过滤器来消除数据噪声,包括但不限于不安全内容、低审美图像、重复图像和小图像。
- SSTK 是另一个内部数据集,约有 3.5 亿条清理后的数据。
- 表 3 显示了数据集的统计信息。更详细的分析可以在附录中看到。
训练数据的质量是数据扩展的前提条件。与使用有噪数据源训练相比,高质量的子集不仅可以提高图像生成质量,还可以保留图像文本对齐。LensArt 比其未经筛选的 10 亿数据源小 4 倍,移除了数亿条有噪数据。然而,使用这个高质量子集训练的模型将生成图像的平均审美分数 [23] 从 5.07 提高到 5.20。这是因为 LensArt 的平均审美分数为 5.33,高于 LensArt-raw 中的 5.00。此外,如图 7 所示,使用 LensArt 训练的 SD2 模型在 TIFA 分数上与使用原始版本训练的模型相比达到了类似的水平,表明过滤不会损害图像文本对齐。原因是在激进的过滤下仍保留了足够的常识知识,同时消除了大量的重复和长尾数据。
为了增加较小但质量较高的数据的有效文本监督,我们采用了一种内部图像标题模型,类似于 BLIP2 [24],来生成合成标题。如图 9 所示,标题模型为每个图像生成五个通用描述,按预测置信度排序。其中一个合成标题和原始 alt-text 以 50% 的概率随机选取与图像配对进行模型训练。因此,我们将图像文本对加倍,并显著增加了图像-名词对,如表 3 所示。通过合成标题扩展的文本监督,使图像文本对齐和保真度得到一致提升,如表 4 所示。具体来说,LensArt 的消融表明,合成标题显著提高了 ImageReward 分数。此外,我们发现随机从前 5 个合成标题中选择一个略优于始终选择前 1 个,这被采用为合成标题的默认训练方案。与 PixArt-α 不同,后者始终用长合成标题替换原始标题,我们提供了一种通过随机翻转标题来增加图像文本对的替代方法,这与 DALL-E3 的标题增强工作一致 [3]。
组合数据集。随着数据集规模的增加,文本图像对齐和图像质量可以进一步提高。在这里,我们比较了在不同数据集上训练的 SD2 模型,并比较了它们的收敛速度:具有合成标题的 1)LensArt 2)SSTK 和 3)LensArt + SSTK 。我们还将使用未经筛选的 LensArt-raw 进行训练作为基线。图7 显示,与仅在 LensArt 或 SSTK 上训练的模型相比,将 LensArt 和 SSTK 组合起来可以显著提高收敛速度和两个指标的上限。使用 LensArt + SSTK 训练的 SDXL 模型在 100K 步时达到了 0.82+ TIFA 分数,比仅使用 LensArt 训练的 SDXL 快 2.5 倍。
在更大数据集上,高级模型表现更佳。图 8 显示,当在扩展(组合)数据集上训练时,SD2 模型可以获得显著的性能提升。即使在使用扩展数据集训练时,SDXL 仍然比 SD2 模型获得性能提升,这表明容量较大的模型在数据集规模增加时具有更好的性能。
性能与模型 FLOPs 之间的关系。对于所有检验的 SD2 和 SDXL 变体,图 10(a-b)显示了在固定步数(即,600K)获得的 TIFA 分数与模型计算复杂度(GFLOPs)以及模型大小(#Params)之间的相关性。我们看到 TIFA 分数与 FLOPs 的相关性稍微好于参数,表明在训练预算充足时,模型计算的重要性,这与我们在第 3 节中的发现一致。
性能与数据量之间的关系。图11(c)显示了 SD2 的 TIFA 分数与数据集大小(以图像-名词对的数量表示)之间的相关性。每个图像-名词对定义为一个图像与其标题中的一个名词配对。它衡量了细粒度文本单元与图像之间的交互。当扩展清理后的数据时,我们看到 TIFA 与图像-名词对的规模呈线性相关。与具有类似数量的图像-名词对的 LensArt-raw 相比,LensArt+SSTK 要好得多,这表明了数据质量的重要性。
数值缩放定律。LLMs 的缩放定律 [17, 21] 揭示了 LLM 的性能作为数据集大小、模型大小和计算预算的函数具有精确的幂律缩放。在这里,我们为 SDXL 变体和 SD2 拟合了类似的缩放函数。TIFA 分数 S 可以是总计算 C(GFLOPs)、模型参数大小 N(M参数)和数据集大小 D(M图像-名词对)的函数,如图 11 所示。具体来说,通过帕累托边界数据点,我们可以拟合幂律函数为 S = 0.47C^0.02,S = 0.77N^0.11 和 S = 0.64D^0.03,它们在给定充分训练的情况下近似于性能范围。与 LLMs 类似,我们看到较大的模型更具样本效率,而较小的模型更具计算效率。
低分辨率下的模型评估。人们可能会想知道模型的相对性能是否会在高分辨率训练时发生变化,从而缓解模型之间的差距。在附录中,我们展示了持续训练 512 分辨率的模型略微改善了它们的 256 分辨率指标,但没有明显变化。尽管可以通过高质量的微调来改善图像质量和审美 [8],但在相同数据上训练时,较差的模型很难超越,特别是当高分辨率数据远少于其低分辨率版本时。大多数构图能力是在低分辨率下开发的,这使我们能够在低分辨率训练的早期阶段评估模型的性能。