【必看】Diffusion Transformer与Rectified Flow Transformer:图像生成新突破,你了解多少?
4 个月前
概述
介绍
扩散变换器(DiT)
修正流变换器
蒸馏加速推理
总结
按主题的参考文献
附录
介绍
在过去的几年中,图像生成神经网络的质量、美学效果和提示符的遵循度迅速提升。了解一些关键的技术创新可以帮助我们理解为什么某些模型在性能上大幅超越了之前的深度学习扩散模型。在本文中,我们将重点介绍扩散变换器(DiT),并探讨其在高质量类别条件图像生成和文本到图像生成方面的关键进展。
- 扩散变换器(DiT)
2023年,William Peebles和Saining Xie提出的DiT标志着与早期扩散模型的显著不同(Peebles和Xie,2023)。用于图像生成的扩散模型包括两个关键设计特征:卷积U-Net骨干网络(Ho等,2020)和最近的潜在扩散架构(Rombach等,2022)。DiT保留了潜在扩散架构,并将其作为视觉变换器(ViT)骨干网络的输入(图1)。DiT还引入了时间步长和类别标签作为嵌入,并使用自适应层归一化(adaLN)将条件信息注入模型。该模型学习从潜在预测中去除噪声,条件是时间步长和类别标签的嵌入。
图1. DiT使用潜在扩散架构作为视觉变换器(ViT)的输入,并进行了一些修改。最终层将图像标记序列解码为输出噪声预测和协方差预测Σ,用于训练模型。每个DiT块用自适应层归一化(adaLN)替换标准层归一化,其维度尺度和移位参数从t和c的嵌入向量之和回归,此外,在残差连接之前应用缩放参数(adaLN-Zero)。
前向扩散过程从冻结编码器E获得的输入潜在表示z = E(x)开始。在每个时间步长,高斯噪声逐渐添加到z中。模型被训练为在每个时间步长预测噪声,条件是类别标签。在训练期间,均值对应于噪声预测,模型最小化均方误差(MSE)损失。方差对应于对角协方差,使用KL散度损失来优化这一项。无分类器指导在训练期间随机丢弃c,并用学习的“空”嵌入∅替换。然后可以通过反向扩散过程生成图像。输入用随机高斯噪声初始化,然后在每个时间步长迭代去噪。在步骤t,模型将当前噪声潜在变量、时间步长和类别标签作为输入进行条件处理。最后,去噪的潜在表示使用冻结解码器D解码为图像,x = D(z)。
在模型扩展方面,Peebles等人实验了从3300万到6.75亿参数的配置。在ImageNet上训练的DiT在类别条件256×256生成基准测试中取得了2.27 FID的最先进结果。
以下是运行官方PyTorch DiT模型进行256×256生成的代码。我使用的是Google Colaboratory上的NVIDIA L4 GPU。首先需要设置环境并导入依赖项。
!git clone https://github.com/facebookresearch/DiT.git
import DiT, os
os.chdir('DiT')
os.environ['PYTHONPATH'] = '/env/python:/content/DiT'
!pip install diffusers timm --upgrade
# DiT导入:
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)
device = "cuda"
现在加载模型。
image_size = 256
vae_model = "stabilityai/sd-vae-ft-ema"
latent_size = int(image_size) // 8
# 加载模型:
model = DiT_XL_2(input_size=latent_size).to(device)
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
model.load_state_dict(state_dict)
model.eval() # 重要!
vae = AutoencoderKL.from_pretrained(vae_model).to(device)
使用ImageNet 1K的类别标签从DiT中采样。我选择了四个狗类别来比较结果。类别列表可以在这里找到。
seed = 0
torch.manual_seed(seed)
num_sampling_steps = 100
cfg_scale = 4
#ImageNet类别标签
class_labels = 162,247,248,254
samples_per_row = 2
# 创建扩散对象:
diffusion = create_diffusion(str(num_sampling_steps))
# 创建采样噪声:
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)
# 设置无分类器指导:
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)
# 采样图像:
samples = diffusion.p_sample_loop(
model.forward_with_cfg, z.shape, z, clip_denoised=False,
model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0) # 移除空类别样本
samples = vae.decode(samples / 0.18215).sample
# 保存并显示图像:
save_image(samples, "sample.png", nrow=int(samples_per_row),
normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)
四个类别条件下的采样结果如图2所示。
图2. 官方PyTorch实现的DiT在256x256图像生成中的类别条件图像生成。顶行选择的类别:'beagle'和'Saint Bernard, St Bernard',底行:'Eskimo dog, husky'和'pug, pug-dog',每个类别经过100个采样步骤。
- 修正流变换器
Stability AI在DiT架构的基础上发布了Stable Diffusion 3(SD3),用于文本到图像生成。该版本引入了几个关键更新,包括改进的训练目标、多模态扩散变换器(MMDiT)架构以及结合人类反馈的微调程序。这些修改简要总结如下。
SD3在训练损失中采用了修正流(Esser等,2024),利用最优传输(OT)的概念将数据分布和噪声沿直线路径连接(Liu等,2022;Albergo & Vanden-Eijnden,2022;Lipman等,2023)。简单来说,与遵循随机轨迹的扩散过程不同,OT建立了确定性的直线路径(图3)。在生成建模的背景下,OT使用常微分方程(ODE)定义噪声和样本分布之间的映射。这种方法具有优势,因为前向过程直接影响学习的反向过程,提高了采样效率。修正流是一种基于OT的目标,将前向过程重新定义为数据分布和标准正态分布之间的直线路径。该公式被重新参数化为噪声预测目标,以与扩散训练保持一致。
图3. 扩散路径(左)与最优传输路径(右)下的点源轨迹。这些简化的2D轨迹展示了效率的差异,OT更倾向于直接路径。
SD3的作者发现,固定的文本表示对于图像生成并不理想。相反,MMDiT在操作中混合了图像和文本标记的可学习流,实现了信息的双向流动。MMDiT结合了三种不同的文本编码器——两种基于CLIP,另一种基于T5——来表示文本输入。在每个MMDiT块中,在计算注意力矩阵之前使用查询-键归一化,以减少训练期间注意力对数增长的不稳定性,并进一步简化微调。模型的参数范围从8亿到80亿。
SD3通过直接偏好优化(DPO)进行了微调。DPO是替代人类反馈强化学习的一种方法,由Raflailov等人为语言模型引入,Wallace等人为扩散模型引入(Raflailov等,2023和Wallace等,2023)。DPO已被证明可以直接提高图像生成的质量、提示符的遵循度和文本生成,而无需像强化学习方法那样训练单独的奖励模型。在SD3中,DPO微调与低秩适应(LoRA)矩阵结合在2B和8B参数模型中。结果见附录。
- 蒸馏加速推理
当前生成模型的蒸馏方法旨在提高采样速度,同时保留扩散模型的迭代细化能力。最近的模型发布进一步推进了修正流变换器的蒸馏。Stable Diffusion 3.5(SD3.5)引入了8.1亿参数的大型模型和一个称为Large Turbo的蒸馏版本。Turbo系列利用对抗性扩散蒸馏(ADD)在仅1-4步内实现高效采样,同时保持高图像质量(Sauer等,2023)。
以下是使用Hugging Face Diffusers运行SD3.5 Large Turbo推理的代码。这是一个免费模型,但需要登录Hugging Face令牌。
# 需要登录的模型:使用具有访问权限的HF令牌登录
!huggingface-cli login
现在加载模型。下载需要几分钟。
# 固定随机种子的turbo large
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large-turbo", torch_dtype=torch.bfloat16, use_safetensors=True)
将模型放在GPU上并运行推理。我使用的是Google Colaboratory上的NVIDIA A100 GPU。
pipe = pipe.to("cuda")
# 使用生成器提供固定的起始种子
generator = torch.Generator(device="cuda").manual_seed(0)
prompt = '一张高度细节、美丽的人脸,以非常高的分辨率拍摄,捕捉到迷人的镜头,具有美学吸引力,个体具有友好可亲的表情。'
# 对于Turbo Large,使用指导尺度为0
image = pipe(
prompt,
num_inference_steps=4,
guidance_scale=0.0,
generator=generator
).images[0]
image
SD3.5 Large Turbo在几个采样步骤中实现了与SD3.5 Large相似的性能。蒸馏模型保留了迭代细化能力,同时需要更少的步骤来生成高质量的结果。结果如图4-6所示。
图4. 使用ADD的SD3.5 Large Turbo分别在1步、4步和10步下的结果。
图5. SD3.5 Large在1步、10步和25步下的结果,指导尺度为3.5。
图6. SD3.5 Large在1步、10步和25步下的结果,指导尺度为7.0。
ADD在训练期间使用两个损失函数:1)对抗性损失,通过训练模型欺骗判别器来确保输出位于真实图像的流形上;2)蒸馏损失,使用预训练的扩散模型作为教师,指导模型匹配去噪目标。Black Forest Labs还贡献了开放权重的FLUX.1系列。FLUX.1模型直接从非开放权重的120亿参数专业模型蒸馏而来。FLUX.1 [schnell]使用ADD,而FLUX.1 [dev]采用指导蒸馏来提高采样效率。
以下是运行FLUX.1 [schnell]的代码。该模型也需要Hugging Face令牌访问。
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = "一块古老的石碑部分埋在沙漠中,黄昏时分,石碑上覆盖着发光的符号。铭文用符文和未来主义文字写着:'我们将需要变换器来实现人工智能'。符号微弱地闪烁,在沙子上投下阴影。上方,一个半透明的卷轴展开,上面有更多神秘的符文,天空从粉红色过渡到紫色,地平线上有一个超现实的月亮。"
image = pipe(prompt).images[0]
image
ADD模型的结果,包括SD3.5 Large Turbo和FLUX.1 [schnell],如图7所示。
FLUX.1 [dev]应用指导蒸馏来提高无分类器指导模型的效率。这个两阶段过程首先训练学生模型复制冻结教师模型的输出,然后逐步将学生蒸馏为需要更少采样步骤的版本(Meng等,2023)。FLUX.1 [dev]的结果如图7所示。
图7. 对抗性扩散蒸馏模型的比较,SD3.5 Large Turbo(左)和FLUX.1 [schnell](中)。两个模型均在A100 GPU上运行4个推理步骤,指导尺度为0.0,使用bfloat-16操作。FLUX.1 [dev]运行28步,指导尺度为7(右)。
- 总结
带有变换器的扩散模型在类别标签引导的图像生成方面提升了技术水平。修正流的进一步进展提高了文本到图像生成的质量和采样效率。蒸馏的修正流变换器在几个采样步骤内提供了高度美学的结果。
- 按主题的参考文献
扩散变换器(DiT)
Peebles和Xie,可扩展的扩散模型与变换器,2023.
去噪扩散
Rombach等,高分辨率图像合成与潜在扩散模型,2022.
修正流变换器
Esser等,扩展修正流变换器用于高分辨率图像合成,2024.
修正流
Liu等,Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow, 2022.
Albergo和Vanden-Eijnden,Building Normalizing Flows with Stochastic Interpolants, 2023.
Lipman等,Flow Matching for Generative Modeling, 2023.
直接偏好优化
Raflailov等,Direct Preference Optimization: Your Language Model is Secretly a Reward Model 2023.
Wallace等,Diffusion Model Alignment Using Direct Preference Optimization, 2023.
模型蒸馏
Sauer等,Adversarial Diffusion Distillation, 2023.
Meng等,On Distillation of Guided Diffusion Models, 2023.
最近的模型发布
Introducing Stable Diffusion 3.5(2024年11月访问).
Announcing Black Forest Labs(2024年11月访问).
- 附录
A. 近期和旧版模型的视觉比较
使用Hugging Face Diffusers库在A100 GPU上加载16位模型。推理使用28个采样步骤和指导尺度7运行。生成每个图像面板的提示符在本节末尾显示。
选择变换器结果,按模型大小从小到大排列。
stabilityai/stable-diffusion-3.5-large;总参数81亿。
black-forest-labs/FLUX.1-dev;总参数120亿;指导蒸馏。
选择旧版U-Net模型的结果,按模型大小从小到大排列。
stabilityai/stable-diffusion-2;总参数8600万加一个文本编码器。
stabilityai/stable-diffusion-xl-base-1.0;总参数26亿加两个文本编码器。
提示1
"宇航员在丛林中,冷色调,柔和色彩,细节丰富,8k,图像中集成文字'The future is with diffusion transformers'。"
提示2
"一块古老的石碑部分埋在沙漠中,黄昏时分,石碑上覆盖着发光的符号。铭文用符文和未来主义文字写着:'我们将需要变换器来实现人工智能'。符号微弱地闪烁,在沙子上投下阴影。上方,一个半透明的卷轴展开,上面有更多神秘的符文,天空从粉红色过渡到紫色,地平线上有一个超现实的月亮。"
_B.
FluxAI 中文
© 2025. All Rights Reserved