# Omnieraser / app.py (Hugging Face Space: theSure/Omnieraser)
import io
import os
import random
import shutil
import uuid

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image, ImageCms
from diffusers import FluxTransformer2DModel
from diffusers.utils import load_image

from pipeline_flux_control_removal import FluxControlRemovalPipeline

print(gr.__version__)
torch.set_grad_enabled(False)
device = "cuda"
print(device)
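
# Note: the `spaces` import and the `@spaces.GPU(duration=120)` decorators below
# target Hugging Face ZeroGPU Spaces, which allocate a GPU to each decorated call
# for at most the given number of seconds. Outside of Spaces you would likely
# drop the decorators and keep `device = "cuda"` as-is.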
image_path = mask_path = None
image_examples = [
    [
        "example/image/3c43156c-2b44-4ebf-9c47-7707ec60b166.png",
        "example/mask/3c43156c-2b44-4ebf-9c47-7707ec60b166.png"
    ],
    [
        "example/image/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png",
        "example/mask/0e5124d8-fe43-4b5c-819f-7212f23a6d2a.png"
    ],
    [
        "example/image/0f900fe8-6eab-4f85-8121-29cac9509b94.png",
        "example/mask/0f900fe8-6eab-4f85-8121-29cac9509b94.png"
    ],
    [
        "example/image/3ed1ee18-33b0-4964-b679-0e214a0d8848.png",
        "example/mask/3ed1ee18-33b0-4964-b679-0e214a0d8848.png"
    ],
    [
        "example/image/9a3b6af9-c733-46a4-88d4-d77604194102.png",
        "example/mask/9a3b6af9-c733-46a4-88d4-d77604194102.png"
    ],
    [
        "example/image/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png",
        "example/mask/87cdf3e2-0fa1-4d80-a228-cbb4aba3f44f.png"
    ],
    [
        "example/image/55dd199b-d99b-47a2-a691-edfd92233a6b.png",
        "example/mask/55dd199b-d99b-47a2-a691-edfd92233a6b.png"
    ]
]
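
# load_model swaps the transformer's patch embedder (x_embedder) for a wider
# Linear layer so the packed latents of the control image and control mask can
# be concatenated with the noisy latents: in_channels is quadrupled, the
# original weights are copied into the first slice, and the new columns start
# at zero so the base model's behavior is preserved until the LoRA is loaded.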
@spaces.GPU(duration=120)
def load_model(base_model_path, lora_path):
    global pipe
    transformer = FluxTransformer2DModel.from_pretrained(
        base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16
    )
    gr.Info("Model loading: 40%")
    # Enable image inputs: widen x_embedder to accept 4x the input channels.
    with torch.no_grad():
        initial_input_channels = transformer.config.in_channels
        new_linear = torch.nn.Linear(
            transformer.x_embedder.in_features * 4,
            transformer.x_embedder.out_features,
            bias=transformer.x_embedder.bias is not None,
            dtype=transformer.dtype,
            device=transformer.device,
        )
        new_linear.weight.zero_()
        new_linear.weight[:, :initial_input_channels].copy_(transformer.x_embedder.weight)
        if transformer.x_embedder.bias is not None:
            new_linear.bias.copy_(transformer.x_embedder.bias)
        transformer.x_embedder = new_linear
        transformer.register_to_config(in_channels=initial_input_channels * 4)

    pipe = FluxControlRemovalPipeline.from_pretrained(
        base_model_path,
        transformer=transformer,
        torch_dtype=torch.bfloat16
    ).to(device)
    pipe.transformer.to(torch.bfloat16)
    gr.Info("Model loading: 80%")
    gr.Info(f"Inject LoRA: {lora_path}")
    pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
    gr.Info("Model loading: 100%")
@spaces.GPU(duration=120)
def set_seed(seed):
    # Seed every RNG the pipeline touches (torch CPU/CUDA, numpy, stdlib random).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
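
# predict() runs one removal pass: example paths (if any) override the sketched
# input, an embedded ICC profile is converted to sRGB, the short side is scaled
# to 1024 px, both dimensions are snapped down to multiples of 8 for the VAE,
# and the pipeline is called with the image/mask pair as control inputs. It
# returns the final result plus intermediate visualizations (a red mask overlay).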
@spaces.GPU(duration=120)
def predict(
    input_image,
    prompt,
    ddim_steps,
    seed,
    scale,
    image_paths,
    mask_paths
):
    global image_path, mask_path
    gr.Info(f"Set seed = {seed}")
    # If an example was clicked, reload the image/mask pair from disk.
    if image_paths is not None:
        input_image["image"] = load_image(image_paths).convert("RGB")
        input_image["mask"] = load_image(mask_paths).convert("RGB")

    size1, size2 = input_image["image"].convert("RGB").size
    icc_profile = input_image["image"].info.get('icc_profile')
    if icc_profile:
        gr.Info("Image contains an ICC profile; converting color space to sRGB...")
        srgb_profile = ImageCms.createProfile("sRGB")
        io_handle = io.BytesIO(icc_profile)
        src_profile = ImageCms.ImageCmsProfile(io_handle)
        input_image["image"] = ImageCms.profileToProfile(input_image["image"], src_profile, srgb_profile)
        input_image["image"].info.pop('icc_profile', None)

    # Scale the short side to 1024 px, keeping the aspect ratio.
    if size1 < size2:
        input_image["image"] = input_image["image"].convert("RGB").resize((1024, int(size2 / size1 * 1024)))
    else:
        input_image["image"] = input_image["image"].convert("RGB").resize((int(size1 / size2 * 1024), 1024))

    # np arrays are (height, width, channels); snap both sides down to a multiple of 8.
    img = np.array(input_image["image"].convert("RGB"))
    height = int(img.shape[0] - img.shape[0] % 8)
    width = int(img.shape[1] - img.shape[1] % 8)
    input_image["image"] = input_image["image"].resize((width, height))
    input_image["mask"] = input_image["mask"].resize((width, height))

    if seed == -1:
        seed = random.randint(1, 2147483647)
    set_seed(seed)

    result = pipe(
        prompt=prompt,
        control_image=input_image["image"].convert("RGB"),
        control_mask=input_image["mask"].convert("RGB"),
        width=width,
        height=height,
        num_inference_steps=ddim_steps,
        generator=torch.Generator(device).manual_seed(seed),
        guidance_scale=scale,
        max_sequence_length=512,
    ).images[0]

    # Visualization: blend a flat red layer over the input, weighted by the mask
    # (mask values are 0-255; dividing by 512 caps the blend at ~50%).
    mask_np = np.array(input_image["mask"].convert("RGB"))
    red = np.array(input_image["image"]).astype("float")
    red[:, :, 0] = 180.0
    red[:, :, 1] = 0
    red[:, :, 2] = 0
    result_m = np.array(input_image["image"]).astype("float")
    result_m = Image.fromarray(
        (
            result_m * (1 - mask_np.astype("float") / 512.0)
            + mask_np.astype("float") / 512.0 * red
        ).astype("uint8")
    )

    dict_res = [input_image["image"], input_image["mask"], result_m, result]
    dict_out = [result]
    # Reset the example paths so the next run reads the sketched input again.
    image_path = None
    mask_path = None
    return dict_out, dict_res
def infer(
    input_image,
    ddim_steps,
    seed,
    scale,
    removal_prompt,
):
    img_path = image_path
    msk_path = mask_path
    return predict(
        input_image,
        removal_prompt,
        ddim_steps,
        seed,
        scale,
        img_path,
        msk_path
    )
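
# Example gallery wiring: clicking an example calls process_example(), which
# previews the masked image and stashes the file paths in the image_path /
# mask_path globals; infer() then forwards those paths so predict() reloads
# the original files instead of the preview shown in the sketch component.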
def process_example(image_paths, mask_paths):
    global image_path, mask_path
    image = Image.open(image_paths).convert("RGB")
    mask = Image.open(mask_paths).convert("L")
    # Preview for the input widget: black out the to-be-removed region.
    black_background = Image.new("RGB", image.size, (0, 0, 0))
    masked_image = Image.composite(black_background, image, mask)
    image_path = image_paths
    mask_path = mask_paths
    return masked_image
custom_css = """
.contain { max-width: 1200px !important; }
.custom-image {
    border: 2px dashed #7e22ce !important;
    border-radius: 12px !important;
    transition: all 0.3s ease !important;
}
.custom-image:hover {
    border-color: #9333ea !important;
    box-shadow: 0 4px 15px rgba(158, 109, 202, 0.2) !important;
}
.btn-primary {
    background: linear-gradient(45deg, #7e22ce, #9333ea) !important;
    border: none !important;
    color: white !important;
    border-radius: 8px !important;
}
#inline-examples {
    border: 1px solid #e2e8f0 !important;
    border-radius: 12px !important;
    padding: 16px !important;
    margin-top: 8px !important;
}
#inline-examples .thumbnail {
    border-radius: 8px !important;
    transition: transform 0.2s ease !important;
}
#inline-examples .thumbnail:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
}
.example-title h3 {
    margin: 0 0 12px 0 !important;
    color: #475569 !important;
    font-size: 1.1em !important;
    display: flex !important;
    align-items: center !important;
}
.example-title h3::before {
    content: "📚";
    margin-right: 8px;
    font-size: 1.2em;
}
"""
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="purple",
        font=[gr.themes.GoogleFont('Inter'), 'sans-serif']
    ),
    title="Omnieraser"
) as demo:
    # Load FLUX.1-dev and inject the Omnieraser LoRA once at startup.
    base_model_path = "black-forest-labs/FLUX.1-dev"
    lora_path = 'theSure/Omnieraser'
    load_model(base_model_path=base_model_path, lora_path=lora_path)

    # Fixed generation settings, hidden from the UI.
    ddim_steps = gr.Slider(visible=False, value=28)
    scale = gr.Slider(visible=False, value=3.5)
    seed = gr.Slider(visible=False, value=-1)
    removal_prompt = gr.Textbox(visible=False, value="There is nothing here.")

    gr.Markdown("""
<div align="center">
<h1 style="font-size: 2.5em; margin-bottom: 0.5em;">🪄 Omnieraser</h1>
</div>
""")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## 📥 Input Panel")
            with gr.Group():
                input_image = gr.Image(
                    type="pil",
                    tool="sketch",
                    label="Upload & Annotate",
                    height=400,
                    elem_id="custom-image",
                    interactive=True
                )
            with gr.Row(variant="compact"):
                run_button = gr.Button(
                    "🚀 Start Processing",
                    variant="primary",
                    size="lg"
                )
            with gr.Group():
                gr.Markdown("### ⚙️ Control Parameters")
                seed = gr.Slider(
                    label="Random Seed",
                    minimum=-1,
                    maximum=2147483647,
                    value=1234,
                    step=1,
                    info="-1 for random generation"
                )
            with gr.Column(variant="panel"):
                gr.Markdown("### 🖼️ Example Gallery", elem_classes=["example-title"])
                example = gr.Examples(
                    examples=image_examples,
                    inputs=[
                        gr.Image(label="Image", type="filepath", visible=False),
                        gr.Image(label="Mask", type="filepath", visible=False)
                    ],
                    outputs=[input_image],
                    fn=process_example,
                    run_on_click=True,
                    examples_per_page=10,
                    label="Click any example to load",
                    elem_id="inline-examples"
                )

        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("## 📤 Output Panel")
            with gr.Tabs():
                with gr.Tab("Final Result"):
                    inpaint_result = gr.Gallery(
                        label="Generated Image",
                        columns=2,
                        height=450,
                        preview=True,
                        object_fit="contain"
                    )
                with gr.Tab("Visualization Steps"):
                    gallery = gr.Gallery(
                        label="Workflow Steps",
                        columns=2,
                        height=450,
                        object_fit="contain"
                    )

    run_button.click(
        fn=infer,
        inputs=[
            input_image,
            ddim_steps,
            seed,
            scale,
            removal_prompt,
        ],
        outputs=[inpaint_result, gallery]
    )

demo.launch()
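
# To run this demo locally you would, roughly: install gradio, torch, diffusers,
# and the repository's pipeline_flux_control_removal module; have a CUDA GPU with
# enough memory for FLUX.1-dev in bfloat16; authenticate with Hugging Face for
# the gated black-forest-labs/FLUX.1-dev weights; then run `python app.py`.
# These requirements are assumptions; check the Space's requirements.txt for
# the pinned versions.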