# import gradio as gr
# import torch
# import spaces
# from diffusers import FluxPipeline, DiffusionPipeline
# from torchao.quantization import autoquant

# # # # normal FluxPipeline
# pipeline_normal = FluxPipeline.from_pretrained(
#     "sayakpaul/FLUX.1-merged",
#     torch_dtype=torch.bfloat16
# ).to("cuda")
# pipeline_normal.transformer.to(memory_format=torch.channels_last)
# pipeline_normal.transformer = torch.compile(pipeline_normal.transformer, mode="max-autotune", fullgraph=True)

# # # optimized FluxPipeline
# # pipeline_optimized = FluxPipeline.from_pretrained(
# #     "camenduru/FLUX.1-dev-diffusers",
# #     torch_dtype=torch.bfloat16
# # ).to("cuda")
# # pipeline_optimized.transformer.to(memory_format=torch.channels_last)
# # pipeline_optimized.transformer = torch.compile(
# #     pipeline_optimized.transformer,
# #     mode="max-autotune",
# #     fullgraph=True
# # )

# # # wrap the autoquant call in a try-except block to handle unsupported layers
# # for name, layer in pipeline_optimized.transformer.named_children():
# #     try:
# #         # apply autoquant to each layer
# #         pipeline_optimized.transformer._modules[name] = autoquant(layer, error_on_unseen=False)
# #         print(f"Successfully quantized {name}")
# #     except AttributeError as e:
# #         print(f"Skipping layer {name} due to error: {e}")
# #     except Exception as e:
# #         print(f"Unexpected error while quantizing {name}: {e}")

# # pipeline_optimized.transformer = autoquant(
# #     pipeline_optimized.transformer,
# #     error_on_unseen=False
# # )

# pipeline_optimized = pipeline_normal

# @spaces.GPU(duration=120)
# def generate_images(prompt, guidance_scale, num_inference_steps):
#     # # generate image with normal pipeline
#     # image_normal = pipeline_normal(
#     #     prompt=prompt,
#     #     guidance_scale=guidance_scale,
#     #     num_inference_steps=int(num_inference_steps)
#     # ).images[0]

#     # generate image with optimized pipeline
#     image_optimized = pipeline_optimized(
#         prompt=prompt,
#         guidance_scale=guidance_scale,
#         num_inference_steps=int(num_inference_steps)
#     ).images[0]
#     return image_optimized

# # set up Gradio interface
# demo = gr.Interface(
#     fn=generate_images,
#     inputs=[
#         gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
#         gr.Slider(1.0, 10.0, step=0.5, value=3.5, label="Guidance Scale"),
#         gr.Slider(10, 100, step=1, value=50, label="Number of Inference Steps")
#     ],
#     outputs=[
#         gr.Image(type="pil", label="Optimized FluxPipeline")
#     ],
#     title="FluxPipeline Comparison",
#     description="Compare images generated by the normal FluxPipeline and the optimized one using torchao and torch.compile()."
# )

# demo.launch()
import gradio as gr
import torch
from optimum.quanto import quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer, T5TokenizerFast
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
import subprocess
import spaces
import os
# Data type for inference
dtype = torch.bfloat16

# Hugging Face repository and revision settings
repo_name = "FLUX.1-schnell-4bit"
bfl_repo = "black-forest-labs/FLUX.1-schnell"
revision = "refs/pr/1"

# Ensure the local directory exists and download the quantized model files;
# check=True makes a failed download raise instead of failing silently later
os.makedirs(repo_name, exist_ok=True)
subprocess.run([
    "huggingface-cli", "download", "PrunaAI/" + repo_name,
    "--local-dir", repo_name,
    "--local-dir-use-symlinks", "False"
], check=True)
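
# An equivalent programmatic alternative (a sketch, assuming huggingface_hub is
# installed, which the huggingface-cli call above already implies): download the
# snapshot directly instead of shelling out.
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id="PrunaAI/" + repo_name, local_dir=repo_name)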
# Load scheduler, CLIP text encoder/tokenizer, and VAE from the pre-trained repos
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", revision=revision)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype, revision=revision)

# Load text_encoder_2 from the downloaded .pt file and tokenizer_2 from the hub.
# The .pt files are full pickled modules, so PyTorch >= 2.6 needs weights_only=False.
text_encoder_2 = torch.load(os.path.join(repo_name, "text_encoder_2.pt"), weights_only=False)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder="tokenizer_2", revision=revision)

# Load the pre-quantized transformer (also a full pickled module)
transformer = torch.load(os.path.join(repo_name, "transformer.pt"), weights_only=False)
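
# If the pre-quantized checkpoint were unavailable, the `quantize` import above
# could be used to quantize the full-precision transformer instead. A minimal
# sketch (assumes optimum.quanto's qint4 weights; not what this Space ships with):
# from optimum.quanto import freeze, qint4
# from diffusers import FluxTransformer2DModel
# transformer = FluxTransformer2DModel.from_pretrained(
#     bfl_repo, subfolder="transformer", torch_dtype=dtype, revision=revision
# )
# quantize(transformer, weights=qint4)
# freeze(transformer)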
# Create the pipeline from the individually loaded components
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=transformer,
)

# Enable model CPU offload to save memory
pipe.enable_model_cpu_offload()
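
# If memory is still tight, diffusers also offers sequential offload, which is
# slower but moves submodules to the GPU one at a time (an alternative, not what
# this Space uses):
# pipe.enable_sequential_cpu_offload()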
# Define the image generation function.
# spaces.GPU is only needed on ZeroGPU hardware and is a no-op elsewhere;
# the commented-out version above used the same decorator.
@spaces.GPU(duration=120)
def generate_image(prompt, guidance_scale, num_inference_steps):
    # Fixed seed for reproducible outputs
    generator = torch.Generator().manual_seed(12345)
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=256,
        generator=generator,
    ).images[0]
    return image
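
# Quick sanity check outside Gradio (a sketch; FLUX.1-schnell is a distilled model,
# so guidance_scale=0.0 and around 4 steps are the usual settings):
# generate_image("a tiny astronaut hatching from an egg on the moon", 0.0, 4).save("sample.png")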
# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# FLUX.1-schnell 4-bit Quantized Model")

    # Text prompt input
    prompt_input = gr.Textbox(lines=2, label="Prompt", placeholder="Enter your prompt here...")
    # Guidance scale slider
    guidance_scale_input = gr.Slider(0.0, 10.0, step=0.1, value=7.5, label="Guidance Scale")
    # Number of inference steps slider
    inference_steps_input = gr.Slider(4, 50, step=1, value=25, label="Number of Inference Steps")

    # Button to trigger generation
    generate_button = gr.Button("Generate Image")
    # Output image
    output_image = gr.Image(label="Generated Image", type="pil")

    # Connect the button to the image generation function
    generate_button.click(
        fn=generate_image,
        inputs=[prompt_input, guidance_scale_input, inference_steps_input],
        outputs=[output_image],
    )

# Launch the Gradio app
demo.launch()