# Omnieraser: Hugging Face Space app (runs on ZeroGPU)
import io
import os
import random
import shutil
import uuid

import numpy as np
import torch
import spaces
import gradio as gr
from PIL import Image, ImageCms
from diffusers import FluxTransformer2DModel
from diffusers.utils import load_image

from pipeline_flux_control_removal import FluxControlRemovalPipeline

print(gr.__version__)

torch.set_grad_enabled(False)
device = "cuda"
print(device)
# Set by process_example() when a gallery example is clicked, then consumed
# (and reset) by predict().
image_path = mask_path = None
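# Example image/mask pairs bundled with the Space; each pair shares a UUID
# filename under example/image/ and example/mask/.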
image_examples = [
    [
        "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
        "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
    ],
    [
        "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
        "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
    ],
    [
        "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
        "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
    ],
    [
        "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
        "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
    ],
    [
        "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
        "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png",
    ],
    [
        "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
        "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
    ],
    [
        "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
        "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
    ],
]
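# load_model() widens the transformer's patch-embedding layer so the model can
# accept image conditioning: the input projection is rebuilt with 4x the
# original in_features, the pretrained weights are copied into the first
# quarter of the columns, and the new (condition) columns are zero-initialised,
# so the expanded model initially behaves exactly like the base model.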
def load_model(base_model_path, lora_path):
    global pipe
    transformer = FluxTransformer2DModel.from_pretrained(
        base_model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    gr.Info("Model loading: 40%")
    # Enable image inputs: expand x_embedder to 4x the original input channels.
    with torch.no_grad():
        initial_input_channels = transformer.config.in_channels
        new_linear = torch.nn.Linear(
            transformer.x_embedder.in_features * 4,
            transformer.x_embedder.out_features,
            bias=transformer.x_embedder.bias is not None,
            dtype=transformer.dtype,
            device=transformer.device,
        )
        new_linear.weight.zero_()
        new_linear.weight[:, :initial_input_channels].copy_(transformer.x_embedder.weight)
        if transformer.x_embedder.bias is not None:
            new_linear.bias.copy_(transformer.x_embedder.bias)
        transformer.x_embedder = new_linear
        transformer.register_to_config(in_channels=initial_input_channels * 4)

    pipe = FluxControlRemovalPipeline.from_pretrained(
        base_model_path,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to(device)
    pipe.transformer.to(torch.bfloat16)
    gr.Info("Model loading: 80%")
    gr.Info(f"Injecting LoRA: {lora_path}")
    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
    gr.Info("Model loading: 100%")
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
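# The Space header says "Running on Zero", so the GPU-bound entry point
# presumably needs the `spaces.GPU` decorator to be allocated a device.
# The decorator below is an assumption, inferred from the otherwise unused
# `import spaces` at the top of the file.
@spaces.GPU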
def predict(
    input_image,
    prompt,
    ddim_steps,
    seed,
    scale,
    image_paths,
    mask_paths,
):
    global image_path, mask_path
    gr.Info(f"Set seed = {seed}")
    if image_paths is not None:
        input_image["image"] = load_image(image_paths).convert("RGB")
        input_image["mask"] = load_image(mask_paths).convert("RGB")

    size1, size2 = input_image["image"].convert("RGB").size

    icc_profile = input_image["image"].info.get("icc_profile")
    if icc_profile:
        gr.Info("Image contains an ICC profile; converting color space to sRGB...")
        srgb_profile = ImageCms.createProfile("sRGB")
        io_handle = io.BytesIO(icc_profile)
        src_profile = ImageCms.ImageCmsProfile(io_handle)
        input_image["image"] = ImageCms.profileToProfile(input_image["image"], src_profile, srgb_profile)
        input_image["image"].info.pop("icc_profile", None)

    # Resize so the short side is 1024, then snap both sides to multiples of 8.
    if size1 < size2:
        input_image["image"] = input_image["image"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
    else:
        input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))

    img = np.array(input_image["image"].convert("RGB"))
    W = int(np.shape(img)[0] - np.shape(img)[0] % 8)  # height, rounded down to a multiple of 8
    H = int(np.shape(img)[1] - np.shape(img)[1] % 8)  # width, rounded down to a multiple of 8
    input_image["image"] = input_image["image"].resize((H, W))
    input_image["mask"] = input_image["mask"].resize((H, W))

    if seed == -1:
        seed = random.randint(1, 2147483647)
    # Seed the global RNGs with the same value the generator uses (the original
    # code seeded them with a second, unrelated random number).
    set_seed(seed)

    result = pipe(
        prompt=prompt,
        control_image=input_image["image"].convert("RGB"),
        control_mask=input_image["mask"].convert("RGB"),
        width=H,
        height=W,
        num_inference_steps=ddim_steps,
        generator=torch.Generator(device).manual_seed(seed),
        guidance_scale=scale,
        max_sequence_length=512,
    ).images[0]

    # Build a visualization with the masked region tinted red.
    mask_np = np.array(input_image["mask"].convert("RGB"))
    red = np.array(input_image["image"]).astype("float")
    red[:, :, 0] = 180.0
    red[:, :, 1] = 0
    red[:, :, 2] = 0
    result_m = np.array(input_image["image"])
    result_m = Image.fromarray(
        (
            result_m.astype("float") * (1 - mask_np.astype("float") / 512.0)
            + mask_np.astype("float") / 512.0 * red
        ).astype("uint8")
    )

    dict_res = [input_image["image"], input_image["mask"], result_m, result]
    dict_out = [result]
    image_path = None
    mask_path = None
    return dict_out, dict_res
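# infer() adapts the UI inputs to predict(): when an example was clicked, the
# module-level image_path/mask_path carry the file paths; otherwise both are
# None and the user's sketch dict is used directly.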
def infer(
    input_image,
    ddim_steps,
    seed,
    scale,
    removal_prompt,
):
    img_path = image_path
    msk_path = mask_path
    return predict(
        input_image,
        removal_prompt,
        ddim_steps,
        seed,
        scale,
        img_path,
        msk_path,
    )
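# Gallery click handler: show the example with the masked region blacked out
# as a preview, and remember the file paths for the next predict() call.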
def process_example(image_paths, mask_paths):
    global image_path, mask_path
    image = Image.open(image_paths).convert("RGB")
    mask = Image.open(mask_paths).convert("L")
    black_background = Image.new("RGB", image.size, (0, 0, 0))
    masked_image = Image.composite(black_background, image, mask)
    image_path = image_paths
    mask_path = mask_paths
    return masked_image
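# Styling for the Blocks UI: purple accents, hover effects, and a framed
# inline example gallery.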
custom_css = """ | |
.contain { max-width: 1200px !important; } | |
.custom-image { | |
border: 2px dashed #7e22ce !important; | |
border-radius: 12px !important; | |
transition: all 0.3s ease !important; | |
} | |
.custom-image:hover { | |
border-color: #9333ea !important; | |
box-shadow: 0 4px 15px rgba(158, 109, 202, 0.2) !important; | |
} | |
.btn-primary { | |
background: linear-gradient(45deg, #7e22ce, #9333ea) !important; | |
border: none !important; | |
color: white !important; | |
border-radius: 8px !important; | |
} | |
#inline-examples { | |
border: 1px solid #e2e8f0 !important; | |
border-radius: 12px !important; | |
padding: 16px !important; | |
margin-top: 8px !important; | |
} | |
#inline-examples .thumbnail { | |
border-radius: 8px !important; | |
transition: transform 0.2s ease !important; | |
} | |
#inline-examples .thumbnail:hover { | |
transform: scale(1.05); | |
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); | |
} | |
.example-title h3 { | |
margin: 0 0 12px 0 !important; | |
color: #475569 !important; | |
font-size: 1.1em !important; | |
display: flex !important; | |
align-items: center !important; | |
} | |
.example-title h3::before { | |
content: "π"; | |
margin-right: 8px; | |
font-size: 1.2em; | |
} | |
""" | |
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="purple",
        font=[gr.themes.GoogleFont("Inter"), "sans-serif"],
    ),
    title="Omnieraser",
) as demo:
    base_model_path = "black-forest-labs/FLUX.1-dev"
    lora_path = "theSure/Omnieraser"
    load_model(base_model_path=base_model_path, lora_path=lora_path)
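    # Fixed generation hyperparameters, kept as hidden components so they can
    # be wired into the click handler alongside the visible controls.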
    ddim_steps = gr.Slider(visible=False, value=28)
    scale = gr.Slider(visible=False, value=3.5)
    seed = gr.Slider(visible=False, value=-1)
    removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")

    gr.Markdown("""
    <div align="center">
        <h1 style="font-size: 2.5em; margin-bottom: 0.5em;">🪄 Omnieraser</h1>
    </div>
    """)
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## 📥 Input Panel")
            with gr.Group():
                input_image = gr.Image(
                    type="pil",
                    tool="sketch",
                    label="Upload & Annotate",
                    height=400,
                    elem_id="custom-image",
                    interactive=True,
                )
                with gr.Row(variant="compact"):
                    run_button = gr.Button(
                        "🚀 Start Processing",
                        variant="primary",
                        size="lg",
                    )
                with gr.Group():
                    gr.Markdown("### ⚙️ Control Parameters")
                    seed = gr.Slider(
                        label="Random Seed",
                        minimum=-1,
                        maximum=2147483647,
                        value=1234,
                        step=1,
                        info="-1 for random generation",
                    )
            with gr.Column(variant="panel"):
                gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
                example = gr.Examples(
                    examples=image_examples,
                    inputs=[
                        gr.Image(label="Image", type="filepath", visible=False),
                        gr.Image(label="Mask", type="filepath", visible=False),
                    ],
                    outputs=[input_image],
                    fn=process_example,
                    run_on_click=True,
                    examples_per_page=10,
                    label="Click any example to load",
                    elem_id="inline-examples",
                )
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## 📤 Output Panel")
            with gr.Tabs():
                with gr.Tab("Final Result"):
                    inpaint_result = gr.Gallery(
                        label="Generated Image",
                        columns=2,
                        height=450,
                        preview=True,
                        object_fit="contain",
                    )
                with gr.Tab("Visualization Steps"):
                    gallery = gr.Gallery(
                        label="Workflow Steps",
                        columns=2,
                        height=450,
                        object_fit="contain",
                    )

    run_button.click(
        fn=infer,
        inputs=[
            input_image,
            ddim_steps,
            seed,
            scale,
            removal_prompt,
        ],
        outputs=[inpaint_result, gallery],
    )

demo.launch()