前语

在本文中,咱们将运用基于 KerasCV 完成的 Stable Diffusion 模型进行图画生成,这是由 stable.ai 开发的文本生成图画的多模态模型。

Stable Diffusion 是一种功能强大的开源的文本到图画生成模型。尽管市场上存在多种开源完成能够让用户依据文本提示轻松创建图画,但 KerasCV 有一些独特的优势来加快图片生成,其中包括 XLA 编译混合精度支撑等特性。所以本文除了介绍如何运用 KerasCV 内置的 StableDiffusion 模块来生成图画,别的咱们还经过比照展示了运用 KerasCV 特性所带来的图片加快优势。

准备

  • N 卡,主张 24 G ,在下文运用 KerasCV 实际生成图画过程中至少需求 20 G
  • 装置 python 3.10 的 anaconda 虚拟环境
  • 装置 tensorflow gpu 2.10
  • 一颗充溢想象力的大脑,主要是用来构建自己的文本 prompt

这儿有一个东西函数 plot_images ,主要是用来把模型生成的图画进行展示。

def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")
    plt.show()

模型作业原理

超分辨率作业能够练习深度学习模型来对输入图画进行去噪,从而将其转换为更高分辨率的作用。为了完成这一目的,深度学习模型并不是经过恢复低分辨率输入图画中丢掉的信息做到的,而是模型运用其练习数据散布来填充最有或许的给定输入的视觉细节。

然后将这个想法面向极限,在纯噪声上运转这样的模型,然后运用该模型不断去噪终究发生一个全新的图画。这就是潜在分散模型的关键思维,在 2020 年的 High-Resolution Image Synthesis with Latent Diffusion Models中提出。

运用 Keras 的 Stable Diffusion 完成高性能文生图

现在要从潜在分散过渡到文本生成图画的作用,需求增加关键字控制生成图画的能力,简略来说就是将一段文本的向量参加到到带噪图片中,然后在数据集上练习模型即可得到咱们想要的文生图模型 Stable Diffusion 。这就发生了 Stable Diffusion 架构,主要由三部分组成:

  • text encoder:可将用户的提示转换为向量。
  • diffusion model:反复对 64×64 潜在图画进行去噪。
  • decoder:将终究生成的 64×64 潜在图画转换为更高分辨率的 512×512 图画。

基本模型架构图如下:

运用 Keras 的 Stable Diffusion 完成高性能文生图

benchmark

咱们运用 keras_cv 中的 StableDiffusion 模块结构一个文生图基准模型 model ,在对模型进行基准测验之前,先履行一次 text_to_image 函数来预热模型,以确保 TensorFlow graph已被跟踪,这样在后续运用模型进行推理时分的速度测验才是准确的。能够从日志中看到第一次运转的时刻是 22 s ,这个不用去管他,咱们只看第二个时刻。

我这儿的提示词是“There is a pink BMW Mini at the exhibition where the lights focus” ,生成 3 张图画,耗时 10.32 s

履行结束之后运转 keras.backend.clear_session() 铲除刚刚运转的模型,以确保不会影响到后面的实验。

model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=False)
model.text_to_image("warming up the model", batch_size=3)
start = time.time()
images = model.text_to_image("There is a pink BMW Mini at the exhibition where the lights focus", batch_size=3)
print(f"Standard model: {(time.time() - start):.2f} seconds")
plot_images(images)
keras.backend.clear_session()

日志打印:

25/25 [==============================] - 22s 399ms/step
25/25 [==============================] - 10s 400ms/step
Standard model: 10.32 seconds

运用 Keras 的 Stable Diffusion 完成高性能文生图

benchmark + Mixed precision

正如日志中打印的信息能够看到,咱们这儿构建的模型现在运用混合精度核算,运用 float16 运算的速度进行核算,同时以 float32 精度存储变量,这是因为 NVIDIA GPU 内核处理相同的操作,运用 float16 比 float32 要快得多。

咱们这儿和上面一样先将模型预热加载,然后针对我的提示词“There is a black BMW Mini at the exhibition where the lights focus”生成了 3 张图画,耗时 5.30s ,能够看到在 benchmark 基础上运用混合精度生成速度提升将近一倍。

keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=False)
print("Compute dtype:", model.diffusion_model.compute_dtype)
print("Variable dtype:",  model.diffusion_model.variable_dtype)
model.text_to_image("warming up the model", batch_size=3)
start = time.time()
images = model.text_to_image( "There is a black BMW Mini at the exhibition where the lights focus", batch_size=3,)
print(f"Mixed precision model: {(time.time() - start):.2f} seconds")
plot_images(images)
keras.backend.clear_session()

日志打印:

Compute dtype: float16
Variable dtype: float32
25/25 [==============================] - 9s 205ms/step
25/25 [==============================] - 5s 202ms/step
Mixed precision model: 5.30 seconds

运用 Keras 的 Stable Diffusion 完成高性能文生图

benchmark + XLA Compilation

XLA(加快线性代数)是一种用于机器学习的开源编译器。XLA 编译器从 PyTorch、TensorFlow 和 JAX 等常用框架中获取模型,并优化模型以在不同的硬件渠道(包括 GPU、CPU 和机器学习加快器)上完成高性能履行。

TensorFlow 和 JAX 附带 XLA , keras_cv.models.StableDiffusion 支撑开箱即用的 jit_compile 参数。 将此参数设置为 True 可启用 XLA 编译,从而显著进步速度。

从日志中能够看到,在 benchmark 基础上运用 XLA 生成时刻减少了 3.34 s

keras.mixed_precision.set_global_policy("float32")
model = keras_cv.models.StableDiffusion(jit_compile=True)
model.text_to_image("warming up the model", batch_size=3)
start = time.time()
images = model.text_to_image("There is a black ford mustang at the exhibition where the lights focus", batch_size=3, )
print(f"With XLA: {(time.time() - start):.2f} seconds")
plot_images(images)
keras.backend.clear_session()

日志打印:

25/25 [==============================] - 34s 271ms/step
25/25 [==============================] - 7s 271ms/step
With XLA: 6.98 seconds

运用 Keras 的 Stable Diffusion 完成高性能文生图

benchmark + Mixed precision + XLA Compilation

最后咱们在 benchmark 基础上同时运用混合精度核算和 XLA 编译,终究生成相同的 3 张图画,时刻仅为 3.96s ,与 benchmark 相比生成时刻减少了 6.36 s ,生成时刻大幅缩短!

keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=True)
model.text_to_image("warming up the model", batch_size=3, )
start = time.time()
images = model.text_to_image( "There is a purple ford mustang at the exhibition where the lights focus", batch_size=3,)
print(f"XLA + mixed precision: {(time.time() - start):.2f} seconds")
plot_images(images)
keras.backend.clear_session()

日志打印:

25/25 [==============================] - 28s 144ms/step
25/25 [==============================] - 4s 152ms/step
XLA + mixed precision: 3.96 seconds

运用 Keras 的 Stable Diffusion 完成高性能文生图

定论

四种状况的耗时比照结果,展示了运用 KerasCV 生成图片确实在速度方面有特别之处:

  • benchmark : 10.32s
  • benchmark + Mixed precision :5.3 s
  • benchmark + XLA Compilation : 6.98s
  • benchmark + Mixed precision + XLA Compilation : 3.96s