import gradio as gr
import spaces
from PIL import Image, ImageDraw, ImageOps
import base64
import json
from io import BytesIO
import torch.nn.functional as F
from typing import List
from dataclasses import dataclass, field
from dreamfuse_inference import DreamFuseInference, InferenceConfig
import numpy as np
import os
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import torch
import subprocess

# NOTE(review): shell=True + rm -rf on a fixed path clears stale ZeroGPU offload
# state at startup. The path is a constant (no user input), but keep it that way.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# Images produced by the most recent model run; read by DreamblendGUI.save_image.
# Presumably populated by the inference pipeline elsewhere — TODO confirm.
generated_images = []

# Background-removal model (briaai/RMBG-2.0), loaded once at import time onto CUDA.
RMBG_model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
RMBG_model = RMBG_model.to("cuda")

# Preprocessing for RMBG: fixed 1024x1024 input, ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


@spaces.GPU
def remove_bg(image):
    """Return a grayscale (mode "L"-convertible) foreground mask for *image*.

    Runs RMBG-2.0 on a 1024x1024 resize of the input and scales the predicted
    mask back to the original image size.
    """
    im = image.convert("RGB")
    input_tensor = transform(im).unsqueeze(0).to("cuda")
    with torch.no_grad():
        # Model returns a list of predictions; the last entry is the final mask.
        preds = RMBG_model(input_tensor)[-1].sigmoid().cpu()[0].squeeze()
    mask = transforms.ToPILImage()(preds).resize(im.size)
    return mask


class DreamblendGUI:
    """Gradio GUI: upload a background + foreground, drag/scale the foreground
    on a preview canvas, then run the DreamFuse pipeline to composite them."""

    def __init__(self):
        # Example image pairs shown in the gr.Examples widget.
        self.examples = [
            ["./examples/9_02.png", "./examples/9_01.png"],
        ]
        self.examples = [[Image.open(x) for x in example] for example in self.examples]
        self.css_style = self._get_css_style()
        self.js_script = self._get_js_script()

    def _get_css_style(self):
        """CSS injected into the Blocks app (dark theme, canvas border, buttons)."""
        return """
        body {
            background: transparent;
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            color: #fff;
        }
        .gradio-container {
            max-width: 1200px;
            margin: auto;
            background: transparent;
            border-radius: 10px;
            padding: 20px;
            box-shadow: 0px 2px 8px rgba(255,255,255,0.1);
        }
        h1, h2 { text-align: center; color: #fff; }
        #canvas_preview {
            border: 2px dashed rgba(255,255,255,0.5);
            padding: 10px;
            background: transparent;
            border-radius: 8px;
        }
        .gr-button {
            background-color: #007bff;
            border: none;
            color: #fff;
            padding: 10px 20px;
            border-radius: 5px;
            font-size: 16px;
            cursor: pointer;
        }
        .gr-button:hover { background-color: #0056b3; }
        #small-examples {
            max-width: 200px !important;
            width: 200px !important;
            float: left;
            margin-right: 20px;
        }
        """

    def _get_js_script(self):
        """JS injected on page load: drag/scale handling for the preview image.

        Defines window.updateTransformation (writes the current drag/scale state
        into the hidden #transformation_info textarea as JSON) and
        globalThis.initializeDrag (binds mouse drag + slider scaling once the
        canvas elements appear in the DOM).
        """
        return r"""
        async () => {
            window.updateTransformation = function() {
                const img = document.getElementById('draggable-img');
                const container = document.getElementById('canvas-container');
                if (!img || !container) return;
                const left = parseFloat(img.style.left) || 0;
                const top = parseFloat(img.style.top) || 0;
                const canvasSize = 400;
                const data_original_width = parseFloat(img.getAttribute('data-original-width'));
                const data_original_height = parseFloat(img.getAttribute('data-original-height'));
                const bgWidth = parseFloat(container.dataset.bgWidth);
                const bgHeight = parseFloat(container.dataset.bgHeight);
                const scale_ratio = img.clientWidth / data_original_width;
                const transformation = {
                    drag_left: left,
                    drag_top: top,
                    drag_width: img.clientWidth,
                    drag_height: img.clientHeight,
                    data_original_width: data_original_width,
                    data_original_height: data_original_height,
                    scale_ratio: scale_ratio
                };
                const transInput = document.querySelector("#transformation_info textarea");
                if(transInput){
                    const newValue = JSON.stringify(transformation);
                    const nativeSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, 'value').set;
                    nativeSetter.call(transInput, newValue);
                    transInput.dispatchEvent(new Event('input', { bubbles: true }));
                    console.log("Transformation info updated: ", newValue);
                } else {
                    console.log("找不到 transformation_info 的 textarea 元素");
                }
            };
            globalThis.initializeDrag = () => {
                console.log("初始化拖拽与缩放功能...");
                const observer = new MutationObserver(() => {
                    const img = document.getElementById('draggable-img');
                    const container = document.getElementById('canvas-container');
                    const slider = document.getElementById('scale-slider');
                    if (img && container && slider) {
                        observer.disconnect();
                        console.log("绑定拖拽与缩放事件...");
                        img.ondragstart = (e) => { e.preventDefault(); return false; };
                        let offsetX = 0, offsetY = 0;
                        let isDragging = false;
                        let scaleAnchor = null;
                        img.addEventListener('mousedown', (e) => {
                            isDragging = true;
                            img.style.cursor = 'grabbing';
                            const imgRect = img.getBoundingClientRect();
                            offsetX = e.clientX - imgRect.left;
                            offsetY = e.clientY - imgRect.top;
                            img.style.transform = "none";
                            img.style.left = img.offsetLeft + "px";
                            img.style.top = img.offsetTop + "px";
                            console.log("mousedown: left=", img.style.left, "top=", img.style.top);
                        });
                        document.addEventListener('mousemove', (e) => {
                            if (!isDragging) return;
                            e.preventDefault();
                            const containerRect = container.getBoundingClientRect();
                            // 计算当前拖拽后的坐标(基于容器)
                            let left = e.clientX - containerRect.left - offsetX;
                            let top = e.clientY - containerRect.top - offsetY;
                            // 允许的拖拽范围:
                            // 水平方向允许最少超出图像一半:最小值为 -img.clientWidth * (7/8)
                            // 水平方向允许最多超出一半:最大值为 containerRect.width - img.clientWidth * (1/8)
                            const minLeft = -img.clientWidth * (7/8);
                            const maxLeft = containerRect.width - img.clientWidth * (1/8);
                            // 垂直方向允许范围:
                            // 最小值为 -img.clientHeight * (7/8)
                            // 最大值为 containerRect.height - img.clientHeight * (1/8)
                            const minTop = -img.clientHeight * (7/8);
                            const maxTop = containerRect.height - img.clientHeight * (1/8);
                            // 限制范围
                            if (left < minLeft) left = minLeft;
                            if (left > maxLeft) left = maxLeft;
                            if (top < minTop) top = minTop;
                            if (top > maxTop) top = maxTop;
                            img.style.left = left + "px";
                            img.style.top = top + "px";
                        });
                        window.addEventListener('mouseup', (e) => {
                            if (isDragging) {
                                isDragging = false;
                                img.style.cursor = 'grab';
                                const containerRect = container.getBoundingClientRect();
                                const bgWidth = parseFloat(container.dataset.bgWidth);
                                const bgHeight = parseFloat(container.dataset.bgHeight);
                                const offsetLeft = (containerRect.width - bgWidth) / 2;
                                const offsetTop = (containerRect.height - bgHeight) / 2;
                                const absoluteLeft = parseFloat(img.style.left);
                                const absoluteTop = parseFloat(img.style.top);
                                const relativeX = absoluteLeft - offsetLeft;
                                const relativeY = absoluteTop - offsetTop;
                                document.getElementById("coordinate").textContent = `前景图坐标: (x=${relativeX.toFixed(2)}, y=${relativeY.toFixed(2)})`;
                                updateTransformation();
                            }
                            scaleAnchor = null;
                        });
                        slider.addEventListener('mousedown', (e) => {
                            const containerRect = container.getBoundingClientRect();
                            const imgRect = img.getBoundingClientRect();
                            scaleAnchor = {
                                x: imgRect.left + imgRect.width/2 - containerRect.left,
                                y: imgRect.top + imgRect.height/2 - containerRect.top
                            };
                            console.log("Slider mousedown, captured scaleAnchor: ", scaleAnchor);
                        });
                        slider.addEventListener('input', (e) => {
                            const scale = parseFloat(e.target.value);
                            const originalWidth = parseFloat(img.getAttribute('data-original-width'));
                            const originalHeight = parseFloat(img.getAttribute('data-original-height'));
                            const newWidth = originalWidth * scale;
                            const newHeight = originalHeight * scale;
                            const containerRect = container.getBoundingClientRect();
                            let centerX, centerY;
                            if (scaleAnchor) {
                                centerX = scaleAnchor.x;
                                centerY = scaleAnchor.y;
                            } else {
                                const imgRect = img.getBoundingClientRect();
                                centerX = imgRect.left + imgRect.width/2 - containerRect.left;
                                centerY = imgRect.top + imgRect.height/2 - containerRect.top;
                            }
                            const newLeft = centerX - newWidth/2;
                            const newTop = centerY - newHeight/2;
                            img.style.width = newWidth + "px";
                            img.style.height = newHeight + "px";
                            img.style.left = newLeft + "px";
                            img.style.top = newTop + "px";
                            console.log("slider: scale=", scale, "newWidth=", newWidth, "newHeight=", newHeight);
                            updateTransformation();
                        });
                        slider.addEventListener('mouseup', (e) => {
                            scaleAnchor = null;
                        });
                    }
                });
                observer.observe(document.body, { childList: true, subtree: true });
            };
        }
        """

    def get_next_sequence(self, folder_path):
        """Return the next free three-digit sequence prefix (e.g. "002") for
        files in *folder_path*, whose names are expected to start with digits
        followed by an underscore."""
        filenames = os.listdir(folder_path)
        # The sequence number is the leading digit run before the first '_'.
        sequences = [int(name.split('_')[0]) for name in filenames if name.split('_')[0].isdigit()]
        # default=-1 so an empty folder yields "000".
        max_sequence = max(sequences, default=-1)
        return f"{max_sequence + 1:03d}"

    def pil_to_base64(self, img):
        """Convert a PIL Image to a base64 data URI; PNG keeps the alpha channel.

        Returns "" when *img* is None.
        """
        if img is None:
            return ""
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        buffered = BytesIO()
        img.save(buffered, format="PNG", optimize=True)
        base64_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{base64_str}"

    @staticmethod
    def _resize_to_fit(img, max_size):
        """Proportionally shrink *img* so its longest side is <= max_size.

        Shared implementation for the two public resize methods below, which
        previously had identical copy-pasted bodies.
        """
        if img is None:
            return None
        w, h = img.size
        if w > max_size or h > max_size:
            ratio = min(max_size / w, max_size / h)
            img = img.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS)
        return img

    def resize_background_image(self, img, max_size=400):
        """Scale the background image so its longest side is max_size (400)."""
        return self._resize_to_fit(img, max_size)

    def resize_draggable_image(self, img, max_size=400):
        """Scale the foreground image so its longest side is at most max_size (400)."""
        return self._resize_to_fit(img, max_size)

    def generate_html(self, background_img_b64, bg_width, bg_height,
                      draggable_img_b64, draggable_width, draggable_height,
                      canvas_size=400):
        """Build the preview HTML shown in the #canvas_preview component.

        NOTE(review): the template below looks truncated — the JS expects
        elements #canvas-container, #draggable-img, #scale-slider and
        #coordinate, which do not appear here. Verify against the original
        source; the visible literal is reproduced unchanged.
        """
        html_code = f"""
前景图坐标: (x=?, y=?)
"""
        return html_code

    def on_upload(self, background_img, draggable_img):
        """Handle "generate canvas": mask the foreground with RMBG, resize both
        images to fit a 400px canvas, and return (preview_html, rgba_foreground).
        """
        if background_img is None or draggable_img is None:
            # Fix: the click handler declares two outputs (html_out,
            # modified_fg_state); the original returned a lone string here.
            return "请先上传背景图片和可拖拽图片。", None
        if draggable_img.mode != "RGB":
            draggable_img = draggable_img.convert("RGB")
        # Background removal: use the predicted mask as the alpha channel.
        draggable_img_mask = remove_bg(draggable_img)
        alpha_channel = draggable_img_mask.convert("L")
        draggable_img = draggable_img.convert("RGBA")
        draggable_img.putalpha(alpha_channel)
        resized_bg = self.resize_background_image(background_img, max_size=400)
        bg_w, bg_h = resized_bg.size
        resized_fg = self.resize_draggable_image(draggable_img, max_size=400)
        draggable_width, draggable_height = resized_fg.size
        background_img_b64 = self.pil_to_base64(resized_bg)
        draggable_img_b64 = self.pil_to_base64(resized_fg)
        return self.generate_html(
            background_img_b64, bg_w, bg_h,
            draggable_img_b64, draggable_width, draggable_height,
            canvas_size=400
        ), draggable_img

    def save_image(self, save_path="/mnt/bn/hjj-humanseg-lq/SubjectDriven/DreamFuse/debug"):
        """Dump the eight images from the last generation to *save_path* under
        the next free sequence prefix (internal debugging helper)."""
        save_name = self.get_next_sequence(save_path)
        # Fixed order matches how generated_images is populated elsewhere.
        suffixes = ["0_ori", "0", "1", "2", "0_mask", "0_mask_scale", "0_scale", "2_pasted"]
        for img, suffix in zip(generated_images, suffixes):
            img.save(os.path.join(save_path, f"{save_name}_{suffix}.png"))

    def create_gui(self):
        """Create and return the Gradio Blocks demo (layout + event wiring)."""
        config = InferenceConfig()
        config.lora_id = 'LL3RD/DreamFuse'
        pipeline = DreamFuseInference(config)
        # Fix: keyword was misspelled "duratioin", so the 120 s ZeroGPU duration
        # was never applied to the generate call.
        pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)
        with gr.Blocks(css=self.css_style) as demo:
            # Holds the RGBA foreground produced by on_upload for later generation.
            modified_fg_state = gr.State()
            gr.Markdown("# Dreamblend-GUI-dirtydata")
            gr.Markdown("通过上传背景图与前景图生成带有可拖拽/缩放预览的合成图像,同时支持 Seed 设置和 Prompt 文本输入。")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 上传图片")
                    background_img_in = gr.Image(label="背景图片", type="pil", height=240, width=240)
                    draggable_img_in = gr.Image(label="前景图片", type="pil", image_mode="RGBA", height=240, width=240)
                    generate_btn = gr.Button("生成可拖拽画布")
                    with gr.Row():
                        gr.Examples(
                            examples=[self.examples[0]],
                            inputs=[background_img_in, draggable_img_in],
                            elem_id="small-examples"
                        )
                with gr.Column(scale=1):
                    gr.Markdown("### 预览区域")
                    html_out = gr.HTML(label="预览与拖拽", elem_id="canvas_preview")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 参数设置")
                    seed_slider = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    cfg_slider = gr.Slider(minimum=1, maximum=10, step=0.1, label="CFG", value=3.5)
                    size_select = gr.Radio(
                        choices=["512", "768", "1024"],
                        value="512",
                        label="生成质量(512-差 1024-好)",
                    )
                    prompt_text = gr.Textbox(label="Prompt", placeholder="输入文本提示", value="")
                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1)
                    enable_gui = gr.Checkbox(label="启用GUI", value=True)
                    enable_truecfg = gr.Checkbox(label="启用TrueCFG", value=False)
                    enable_save = gr.Button("保存图片 (内部测试)", visible=True)
                with gr.Column(scale=1):
                    gr.Markdown("### 模型生成结果")
                    model_generate_btn = gr.Button("模型生成")
                    # Hidden textarea: the injected JS writes drag/scale JSON here.
                    transformation_text = gr.Textbox(label="Transformation Info", elem_id="transformation_info", visible=False)
                    model_output = gr.Image(label="模型输出", type="pil")
            # Event wiring.
            enable_save.click(fn=self.save_image, inputs=None, outputs=None)
            generate_btn.click(
                fn=self.on_upload,
                inputs=[background_img_in, draggable_img_in],
                outputs=[html_out, modified_fg_state],
            )
            model_generate_btn.click(
                fn=pipeline.gradio_generate,
                inputs=[background_img_in, modified_fg_state, transformation_text, seed_slider,
                        prompt_text, enable_gui, cfg_slider, size_select, text_strength, enable_truecfg],
                outputs=model_output
            )
            # Initialize drag/scale handlers after page load and after each canvas build.
            demo.load(None, None, None, js=self.js_script)
            generate_btn.click(fn=None, inputs=None, outputs=None, js="initializeDrag")
        return demo


if __name__ == "__main__":
    gui = DreamblendGUI()
    demo = gui.create_gui()
    demo.queue()
    demo.launch()
    # demo.launch(server_port=7789, ssr_mode=False)
    # demo.launch(server_name="[::]", share=True)