import gradio as gr
import spaces
from PIL import Image, ImageDraw, ImageOps
import base64, json
from io import BytesIO
import torch.nn.functional as F
from typing import List
from dataclasses import dataclass, field
from dreamfuse_inference import DreamFuseInference, InferenceConfig
import numpy as np
import os
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import torch
import subprocess

# Clear any stale ZeroGPU offload cache before loading models.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

generated_images = []

# Background-removal model (RMBG-2.0) used to matte the foreground image.
RMBG_model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
RMBG_model = RMBG_model.to("cuda")
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def remove_bg(image):
    """Predict a foreground matte for `image` and return it as a PIL mask at the original size."""
    im = image.convert("RGB")
    input_tensor = transform(im).unsqueeze(0).to("cuda")
    with torch.no_grad():
        preds = RMBG_model(input_tensor)[-1].sigmoid().cpu()[0].squeeze()
    mask = transforms.ToPILImage()(preds).resize(im.size)
    return mask
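
# Minimal usage sketch for remove_bg (illustrative only; DreamblendGUI.on_upload below applies
# the same matte in the real flow). The example path is just an assumption:
#
#   fg = Image.open("./examples/9_01.png")
#   rgba = fg.convert("RGBA")
#   rgba.putalpha(remove_bg(fg).convert("L"))  # the predicted matte becomes the alpha channel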
class DreamblendGUI:
    def __init__(self):
        self.examples = [
            ["./examples/9_02.png",
             "./examples/9_01.png"],
        ]
        self.examples = [[Image.open(x) for x in example] for example in self.examples]
        self.css_style = self._get_css_style()
        self.js_script = self._get_js_script()
    def _get_css_style(self):
        return """
        body {
            background: transparent;
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            color: #fff;
        }
        .gradio-container {
            max-width: 1200px;
            margin: auto;
            background: transparent;
            border-radius: 10px;
            padding: 20px;
            box-shadow: 0px 2px 8px rgba(255,255,255,0.1);
        }
        h1, h2 {
            text-align: center;
            color: #fff;
        }
        #canvas_preview {
            border: 2px dashed rgba(255,255,255,0.5);
            padding: 10px;
            background: transparent;
            border-radius: 8px;
        }
        .gr-button {
            background-color: #007bff;
            border: none;
            color: #fff;
            padding: 10px 20px;
            border-radius: 5px;
            font-size: 16px;
            cursor: pointer;
        }
        .gr-button:hover {
            background-color: #0056b3;
        }
        #small-examples {
            max-width: 200px !important;
            width: 200px !important;
            float: left;
            margin-right: 20px;
        }
        """
    def _get_js_script(self):
        return r"""
        async () => {
            window.updateTransformation = function() {
                const img = document.getElementById('draggable-img');
                const container = document.getElementById('canvas-container');
                if (!img || !container) return;
                const left = parseFloat(img.style.left) || 0;
                const top = parseFloat(img.style.top) || 0;
                const canvasSize = 400;
                const data_original_width = parseFloat(img.getAttribute('data-original-width'));
                const data_original_height = parseFloat(img.getAttribute('data-original-height'));
                const bgWidth = parseFloat(container.dataset.bgWidth);
                const bgHeight = parseFloat(container.dataset.bgHeight);
                const scale_ratio = img.clientWidth / data_original_width;
                const transformation = {
                    drag_left: left,
                    drag_top: top,
                    drag_width: img.clientWidth,
                    drag_height: img.clientHeight,
                    data_original_width: data_original_width,
                    data_original_height: data_original_height,
                    scale_ratio: scale_ratio
                };
                const transInput = document.querySelector("#transformation_info textarea");
                if (transInput) {
                    const newValue = JSON.stringify(transformation);
                    const nativeSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, 'value').set;
                    nativeSetter.call(transInput, newValue);
                    transInput.dispatchEvent(new Event('input', { bubbles: true }));
                    console.log("Transformation info updated: ", newValue);
                } else {
                    console.log("Could not find the transformation_info textarea element");
                }
            };
            globalThis.initializeDrag = () => {
                console.log("Initializing drag and scale handlers...");
                const observer = new MutationObserver(() => {
                    const img = document.getElementById('draggable-img');
                    const container = document.getElementById('canvas-container');
                    const slider = document.getElementById('scale-slider');
                    if (img && container && slider) {
                        observer.disconnect();
                        console.log("Binding drag and scale events...");
                        img.ondragstart = (e) => { e.preventDefault(); return false; };
                        let offsetX = 0, offsetY = 0;
                        let isDragging = false;
                        let scaleAnchor = null;
                        img.addEventListener('mousedown', (e) => {
                            isDragging = true;
                            img.style.cursor = 'grabbing';
                            const imgRect = img.getBoundingClientRect();
                            offsetX = e.clientX - imgRect.left;
                            offsetY = e.clientY - imgRect.top;
                            img.style.transform = "none";
                            img.style.left = img.offsetLeft + "px";
                            img.style.top = img.offsetTop + "px";
                            console.log("mousedown: left=", img.style.left, "top=", img.style.top);
                        });
                        document.addEventListener('mousemove', (e) => {
                            if (!isDragging) return;
                            e.preventDefault();
                            const containerRect = container.getBoundingClientRect();
                            // Current drag position, relative to the container.
                            let left = e.clientX - containerRect.left - offsetX;
                            let top = e.clientY - containerRect.top - offsetY;
                            // Allowed drag range: at least 1/8 of the image must stay inside the
                            // container on each side, so horizontally the minimum is
                            // -img.clientWidth * 7/8 and the maximum is
                            // containerRect.width - img.clientWidth * 1/8.
                            const minLeft = -img.clientWidth * (7/8);
                            const maxLeft = containerRect.width - img.clientWidth * (1/8);
                            // The vertical range is analogous:
                            // minimum = -img.clientHeight * 7/8,
                            // maximum = containerRect.height - img.clientHeight * 1/8.
                            const minTop = -img.clientHeight * (7/8);
                            const maxTop = containerRect.height - img.clientHeight * (1/8);
                            // Clamp to the allowed range.
                            if (left < minLeft) left = minLeft;
                            if (left > maxLeft) left = maxLeft;
                            if (top < minTop) top = minTop;
                            if (top > maxTop) top = maxTop;
                            img.style.left = left + "px";
                            img.style.top = top + "px";
                        });
                        window.addEventListener('mouseup', (e) => {
                            if (isDragging) {
                                isDragging = false;
                                img.style.cursor = 'grab';
                                const containerRect = container.getBoundingClientRect();
                                const bgWidth = parseFloat(container.dataset.bgWidth);
                                const bgHeight = parseFloat(container.dataset.bgHeight);
                                const offsetLeft = (containerRect.width - bgWidth) / 2;
                                const offsetTop = (containerRect.height - bgHeight) / 2;
                                const absoluteLeft = parseFloat(img.style.left);
                                const absoluteTop = parseFloat(img.style.top);
                                const relativeX = absoluteLeft - offsetLeft;
                                const relativeY = absoluteTop - offsetTop;
                                document.getElementById("coordinate").textContent =
                                    `Foreground position: (x=${relativeX.toFixed(2)}, y=${relativeY.toFixed(2)})`;
                                updateTransformation();
                            }
                            scaleAnchor = null;
                        });
                        slider.addEventListener('mousedown', (e) => {
                            const containerRect = container.getBoundingClientRect();
                            const imgRect = img.getBoundingClientRect();
                            scaleAnchor = {
                                x: imgRect.left + imgRect.width/2 - containerRect.left,
                                y: imgRect.top + imgRect.height/2 - containerRect.top
                            };
                            console.log("Slider mousedown, captured scaleAnchor: ", scaleAnchor);
                        });
                        slider.addEventListener('input', (e) => {
                            const scale = parseFloat(e.target.value);
                            const originalWidth = parseFloat(img.getAttribute('data-original-width'));
                            const originalHeight = parseFloat(img.getAttribute('data-original-height'));
                            const newWidth = originalWidth * scale;
                            const newHeight = originalHeight * scale;
                            const containerRect = container.getBoundingClientRect();
                            let centerX, centerY;
                            if (scaleAnchor) {
                                centerX = scaleAnchor.x;
                                centerY = scaleAnchor.y;
                            } else {
                                const imgRect = img.getBoundingClientRect();
                                centerX = imgRect.left + imgRect.width/2 - containerRect.left;
                                centerY = imgRect.top + imgRect.height/2 - containerRect.top;
                            }
                            const newLeft = centerX - newWidth/2;
                            const newTop = centerY - newHeight/2;
                            img.style.width = newWidth + "px";
                            img.style.height = newHeight + "px";
                            img.style.left = newLeft + "px";
                            img.style.top = newTop + "px";
                            console.log("slider: scale=", scale, "newWidth=", newWidth, "newHeight=", newHeight);
                            updateTransformation();
                        });
                        slider.addEventListener('mouseup', (e) => {
                            scaleAnchor = null;
                        });
                    }
                });
                observer.observe(document.body, { childList: true, subtree: true });
            };
        }
        """
    def get_next_sequence(self, folder_path):
        # List every file name in the folder.
        filenames = os.listdir(folder_path)
        # Extract the sequence prefix from each name (assumed to be the leading digits before "_").
        sequences = [int(name.split('_')[0]) for name in filenames if name.split('_')[0].isdigit()]
        # Find the largest sequence number seen so far.
        max_sequence = max(sequences, default=-1)
        # Return the next sequence number, zero-padded to three digits (e.g. "002").
        return f"{max_sequence + 1:03d}"
    def pil_to_base64(self, img):
        """Convert a PIL Image to a base64 data URI; PNG keeps the alpha channel."""
        if img is None:
            return ""
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        buffered = BytesIO()
        img.save(buffered, format="PNG", optimize=True)
        img_bytes = buffered.getvalue()
        base64_str = base64.b64encode(img_bytes).decode()
        return f"data:image/png;base64,{base64_str}"
    def resize_background_image(self, img, max_size=400):
        """Proportionally resize the background so its longest side is at most max_size (400)."""
        if img is None:
            return None
        w, h = img.size
        if w > max_size or h > max_size:
            ratio = min(max_size / w, max_size / h)
            new_w, new_h = int(w * ratio), int(h * ratio)
            img = img.resize((new_w, new_h), Image.LANCZOS)
        return img

    def resize_draggable_image(self, img, max_size=400):
        """Proportionally resize the foreground so its longest side is at most max_size (400)."""
        if img is None:
            return None
        w, h = img.size
        if w > max_size or h > max_size:
            ratio = min(max_size / w, max_size / h)
            new_w, new_h = int(w * ratio), int(h * ratio)
            img = img.resize((new_w, new_h), Image.LANCZOS)
        return img
    def generate_html(self, background_img_b64, bg_width, bg_height, draggable_img_b64, draggable_width, draggable_height, canvas_size=400):
        """Build the preview HTML page with the drag/scale canvas."""
        html_code = f"""
        <html>
        <head>
        <style>
            body {{
                margin: 0;
                padding: 0;
                text-align: center;
                font-family: sans-serif;
                background: transparent;
                color: #fff;
            }}
            h2 {{
                margin-top: 1rem;
            }}
            #scale-control {{
                margin: 1rem auto;
                width: 400px;
                text-align: left;
            }}
            #scale-control label {{
                font-size: 1rem;
                margin-right: 0.5rem;
            }}
            #canvas-container {{
                position: relative;
                width: {canvas_size}px;
                height: {canvas_size}px;
                margin: 0 auto;
                border: 1px dashed rgba(255,255,255,0.5);
                overflow: hidden;
                background-image: url('{background_img_b64}');
                background-repeat: no-repeat;
                background-position: center;
                background-size: contain;
                border-radius: 8px;
            }}
            #draggable-img {{
                position: absolute;
                cursor: grab;
                left: 50%;
                top: 50%;
                transform: translate(-50%, -50%);
                background-color: transparent;
            }}
            #coordinate {{
                color: #fff;
                margin-top: 1rem;
                font-weight: bold;
            }}
        </style>
        </head>
        <body>
            <h2>Drag the foreground image (scaling supported)</h2>
            <div id="scale-control">
                <label for="scale-slider">Foreground scale:</label>
                <input type="range" id="scale-slider" min="0.1" max="2" step="0.01" value="1">
            </div>
            <div id="canvas-container" data-bg-width="{bg_width}" data-bg-height="{bg_height}">
                <img id="draggable-img"
                     src="{draggable_img_b64}"
                     alt="Draggable Image"
                     draggable="false"
                     data-original-width="{draggable_width}"
                     data-original-height="{draggable_height}"
                />
            </div>
            <p id="coordinate">Foreground position: (x=?, y=?)</p>
        </body>
        </html>
        """
        return html_code
    def on_upload(self, background_img, draggable_img):
        """Handle uploads: matte the foreground, resize both images, and build the drag canvas."""
        if background_img is None or draggable_img is None:
            # Return a value for both outputs (html_out, modified_fg_state).
            return "<p style='color:red;'>Please upload a background image and a draggable foreground image first.</p>", None
        if draggable_img.mode != "RGB":
            draggable_img = draggable_img.convert("RGB")
        draggable_img_mask = remove_bg(draggable_img)
        alpha_channel = draggable_img_mask.convert("L")
        draggable_img = draggable_img.convert("RGBA")
        draggable_img.putalpha(alpha_channel)
        resized_bg = self.resize_background_image(background_img, max_size=400)
        bg_w, bg_h = resized_bg.size
        resized_fg = self.resize_draggable_image(draggable_img, max_size=400)
        draggable_width, draggable_height = resized_fg.size
        background_img_b64 = self.pil_to_base64(resized_bg)
        draggable_img_b64 = self.pil_to_base64(resized_fg)
        return self.generate_html(
            background_img_b64, bg_w, bg_h,
            draggable_img_b64, draggable_width, draggable_height,
            canvas_size=400
        ), draggable_img
    def save_image(self, save_path="/mnt/bn/hjj-humanseg-lq/SubjectDriven/DreamFuse/debug"):
        global generated_images
        save_name = self.get_next_sequence(save_path)
        generated_images[0].save(os.path.join(save_path, f"{save_name}_0_ori.png"))
        generated_images[1].save(os.path.join(save_path, f"{save_name}_0.png"))
        generated_images[2].save(os.path.join(save_path, f"{save_name}_1.png"))
        generated_images[3].save(os.path.join(save_path, f"{save_name}_2.png"))
        generated_images[4].save(os.path.join(save_path, f"{save_name}_0_mask.png"))
        generated_images[5].save(os.path.join(save_path, f"{save_name}_0_mask_scale.png"))
        generated_images[6].save(os.path.join(save_path, f"{save_name}_0_scale.png"))
        generated_images[7].save(os.path.join(save_path, f"{save_name}_2_pasted.png"))
    def create_gui(self):
        """Build the Gradio interface."""
        config = InferenceConfig()
        config.lora_id = 'LL3RD/DreamFuse'
        pipeline = DreamFuseInference(config)
        pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)
        with gr.Blocks(css=self.css_style) as demo:
            modified_fg_state = gr.State()
            gr.Markdown("# Dreamblend-GUI-dirtydata")
            gr.Markdown("Upload a background and a foreground image to get a draggable/scalable composite preview; seed settings and prompt text input are also supported.")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Upload Images")
                    background_img_in = gr.Image(label="Background image", type="pil", height=240, width=240)
                    draggable_img_in = gr.Image(label="Foreground image", type="pil", image_mode="RGBA", height=240, width=240)
                    generate_btn = gr.Button("Create draggable canvas")
                    with gr.Row():
                        gr.Examples(
                            examples=[self.examples[0]],
                            inputs=[background_img_in, draggable_img_in],
                            elem_id="small-examples"
                        )
                with gr.Column(scale=1):
                    gr.Markdown("### Preview")
                    html_out = gr.HTML(label="Preview and drag", elem_id="canvas_preview")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Settings")
                    seed_slider = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    cfg_slider = gr.Slider(minimum=1, maximum=10, step=0.1, label="CFG", value=3.5)
                    size_select = gr.Radio(
                        choices=["512", "768", "1024"],
                        value="512",
                        label="Generation quality (512 = low, 1024 = high)",
                    )
                    prompt_text = gr.Textbox(label="Prompt", placeholder="Enter a text prompt", value="")
                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1)
                    enable_gui = gr.Checkbox(label="Enable GUI", value=True)
                    enable_truecfg = gr.Checkbox(label="Enable TrueCFG", value=False)
                    enable_save = gr.Button("Save images (internal testing)", visible=True)
                with gr.Column(scale=1):
                    gr.Markdown("### Model Output")
                    model_generate_btn = gr.Button("Generate")
                    transformation_text = gr.Textbox(label="Transformation Info", elem_id="transformation_info", visible=False)
                    model_output = gr.Image(label="Model output", type="pil")

            # Event bindings
            enable_save.click(fn=self.save_image, inputs=None, outputs=None)
            generate_btn.click(
                fn=self.on_upload,
                inputs=[background_img_in, draggable_img_in],
                outputs=[html_out, modified_fg_state],
            )
            model_generate_btn.click(
                fn=pipeline.gradio_generate,
                inputs=[background_img_in, modified_fg_state, transformation_text, seed_slider,
                        prompt_text, enable_gui, cfg_slider, size_select, text_strength, enable_truecfg],
                outputs=model_output
            )
            # Initialize the drag/scale handlers once the page has loaded.
            demo.load(None, None, None, js=self.js_script)
            generate_btn.click(fn=None, inputs=None, outputs=None, js="initializeDrag")
        return demo
if __name__ == "__main__": | |
gui = DreamblendGUI() | |
demo = gui.create_gui() | |
demo.queue() | |
demo.launch() | |
# demo.launch(server_port=7789, ssr_mode=False) | |
# demo.launch(server_name="[::]", share=True) | |