Diffusers and Quanto giving hope to the GPU-challenged
Image generation tools are hotter than ever, and they’ve never been more powerful. Models like PixArt Sigma and Flux.1 are leading the charge, thanks to their open weights and permissive licenses. This openness allows for creative tinkering, including training LoRAs without sending data outside your computer.
However, working with these models can be challenging if you’re using older or less VRAM-rich GPUs. Typically, there’s a trade-off between quality, speed, and VRAM usage. In this blog post, we’ll focus on optimizing for speed and lower VRAM usage while maintaining as much quality as possible. This approach works exceptionally well for PixArt due to its smaller size, but results might vary with Flux.1. I’ll share some alternative solutions for Flux.1 at the end of this post.
Both PixArt Sigma and Flux.1 are transformer-based, which means they benefit from the same quantization techniques used by large language models (LLMs). Quantization involves compressing the model’s components to use less memory. It allows you to keep all model components in GPU VRAM simultaneously, leading to faster generation speeds compared to methods that move weights between the GPU and CPU, which can slow things down.
Let’s dive into the setup!
Setting Up Your Local Environment
First, ensure you have Nvidia drivers and Anaconda installed.
Next, create a Python environment and install the main requirements:
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
Then install the Diffusers and Quanto libraries:
pip install pillow==10.3.0 loguru~=0.7.2 optimum-quanto==0.2.4 diffusers==0.30.0 transformers==4.44.2 accelerate==0.33.0 sentencepiece==0.2.0
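Before moving on, it is worth checking that PyTorch actually sees your GPU; here is a minimal sanity check (it assumes nothing beyond the packages installed above):

import torch

# Confirm that the CUDA build of PyTorch is installed and a GPU is visible.
print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")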
Quantization Code
Here’s a simple script to get you started for PixArt-Sigma:
from optimum.quanto import qint8, qint4, quantize, freeze
from diffusers import PixArtSigmaPipeline
import torch

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
)
quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)
quantize(pipeline.text_encoder, weights=qint4, exclude="proj_out")
freeze(pipeline.text_encoder)
pipe = pipeline.to("cuda")
for i in range(2):
    generator = torch.Generator(device="cpu").manual_seed(i)
    prompt = "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed"
    image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator).images[0]
    image.save(f"Sigma_{i}.png")
Understanding the Script: here are the major steps of the implementation:
- Import Necessary Libraries: We import libraries for quantization, model loading, and GPU handling.
- Load the Model: We load the PixArt Sigma model in half-precision (float16) to CPU first.
- Quantize the Model: We apply quantization to the transformer and text encoder components of the model, at different levels: the text encoder is quantized to qint4 since it is quite large, while the transformer (the vision part) is quantized to qint8. With the transformer at qint8 the full pipeline uses about 7.5 GB of VRAM; left unquantized it would use around 8.5 GB (a sketch for measuring this on your own machine follows the list).
- Move to GPU: We move the pipeline to the GPU with .to("cuda") for faster processing.
- Generate Images: We use the pipe object to generate images based on a given prompt and save the output.
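If you want to verify the VRAM figures on your own machine, here is a minimal sketch using PyTorch’s built-in memory statistics; run it right after the generation loop above:

import torch

# Peak VRAM allocated by the pipeline during generation, in GB.
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak VRAM allocated: {peak_gb:.2f} GB")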
Running the Script
Save the script and run it in your environment. You should see images generated from the prompt “Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed” saved as Sigma_0.png and Sigma_1.png. Generation takes 6 seconds on an RTX 3080 GPU.
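If you want to reproduce the timing on your own GPU, here is a small sketch that assumes the script above has already been run (torch.cuda.synchronize makes sure the timer covers the whole generation):

import time
import torch

torch.cuda.synchronize()
start = time.perf_counter()
image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator).images[0]
torch.cuda.synchronize()
print(f"Generation took {time.perf_counter() - start:.1f} s")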
You can achieve similar results with Flux.1 Schnell despite its additional components, but it requires more aggressive quantization, which noticeably lowers quality (unless you have access to more VRAM, say 16 or 25 GB).
import torch
from optimum.quanto import qint2, qint4, quantize, freeze
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
quantize(pipe.text_encoder, weights=qint4, exclude="proj_out")
freeze(pipe.text_encoder)
quantize(pipe.text_encoder_2, weights=qint2, exclude="proj_out")
freeze(pipe.text_encoder_2)
quantize(pipe.transformer, weights=qint4, exclude="proj_out")
freeze(pipe.transformer)
pipe = pipe.to("cuda")
for i in range(10):
    generator = torch.Generator(device="cpu").manual_seed(i)
    prompt = "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed"
    image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator, num_inference_steps=4).images[0]
    image.save(f"Schnell_{i}.png")
We can see that quantizing the second text encoder (T5) to qint2 and the vision transformer to qint4 might be too aggressive, and it has a significant impact on quality for Flux.1 Schnell.
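If you do have the extra VRAM (roughly the 16 GB mentioned earlier), a less aggressive variant is to keep the large T5 text encoder at qint4 instead of qint2; this is only a sketch of the lines that change, and I have not benchmarked it on every card:

# Trade a few GB of VRAM for quality by relaxing the T5 encoder from qint2 to qint4.
quantize(pipe.text_encoder_2, weights=qint4, exclude="proj_out")
freeze(pipe.text_encoder_2)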
Here are some alternatives for running Flux.1 Schnell: if PixArt-Sigma is not sufficient for your needs and you don’t have enough VRAM to run Flux.1 at acceptable quality, you have two main options:
- ComfyUI or Forge: These are GUI tools popular with enthusiasts; they mostly sacrifice speed for quality.
- Replicate API: It costs $0.003 per image generation for Schnell.
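For reference, generating an image through the Replicate API only takes a few lines with their Python client (pip install replicate and set REPLICATE_API_TOKEN); the model slug and input keys below are my best guess, so double-check them against Replicate’s catalog:

import replicate

# Run Flux.1 Schnell remotely; treat the slug and input schema as assumptions.
output = replicate.run(
    "black-forest-labs/flux-schnell",
    input={"prompt": "Cyberpunk cityscape, small black crow, neon lights"},
)
print(output)  # typically a list of URLs to the generated image(s)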
Deployment
I had a little fun deploying PixArt Sigma on an older machine I have. Here is a brief summary of how I went about it:
First, the list of components:
- HTMX and Tailwind: These are like the face of the project. HTMX helps make the website interactive without a lot of extra code, and Tailwind gives it a nice look.
- FastAPI: It takes requests from the website and decides what to do with them.
- Celery Worker: Think of this as the hard worker. It takes the orders from FastAPI and actually creates the images.
- Redis Cache/Pub-Sub: This is like the communication center. It helps different parts of the project talk to each other and remember important stuff.
- GCS (Google Cloud Storage): This is where we keep the finished images.
Now, how do they all work together? Here’s a simple rundown:
- When you visit the website and make a request, HTMX and Tailwind make sure it looks good.
- FastAPI gets the request and tells the Celery Worker what kind of image to make through Redis.
- The Celery Worker goes to work, creating the image.
- Once the image is ready, it gets stored in GCS, so it’s easy to access.
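For the curious, here is a stripped-down sketch of how FastAPI hands work to the Celery worker over Redis; the function names, Redis URLs, and the GCS URL are illustrative placeholders rather than the exact code of the deployed service:

from celery import Celery
from fastapi import FastAPI

celery_app = Celery(
    "image_tasks",
    broker="redis://localhost:6379/0",   # Redis as the message broker
    backend="redis://localhost:6379/1",  # and as the result store
)

@celery_app.task
def generate_image(prompt: str) -> str:
    # The worker process loads the quantized PixArt pipeline once, renders the
    # prompt, uploads the PNG to GCS and returns its public URL (placeholder here).
    return "https://storage.googleapis.com/<bucket>/<image>.png"

app = FastAPI()

@app.post("/generate")
def generate(prompt: str):
    task = generate_image.delay(prompt)  # hand the job off to the worker via Redis
    return {"task_id": task.id}          # the front end polls this id or listens on Redis pub/sub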
Service URL: https://image-generation-app-340387183829.europe-west1.run.app
Conclusion
By quantizing the model components, we can significantly reduce VRAM usage while maintaining good image quality and improving generation speed. This method is particularly effective for models like PixArt Sigma. For Flux.1, while the results might be mixed, the principles of quantization remain applicable.