DreamFuse / app.py
LL3RD's picture
test
f96f677
raw
history blame
22.6 kB
import gradio as gr
import spaces
from PIL import Image, ImageDraw, ImageOps
import base64, json
from io import BytesIO
import torch.nn.functional as F
import json
from typing import List
from dataclasses import dataclass, field
from dreamfuse_inference import DreamFuseInference, InferenceConfig
import numpy as np
import os
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import torch
import subprocess
# Wipe the ZeroGPU offload directory before any model is loaded. A shell is
# used so that the `*` glob is expanded.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# Images read by DreamblendGUI.save_image (indices 0-7). Nothing in this file
# fills the list. NOTE(review): confirm which code is expected to populate it.
generated_images = []

# Background-removal model used by remove_bg(), placed on the GPU at import time.
_rmbg_checkpoint = 'briaai/RMBG-2.0'
RMBG_model = AutoModelForImageSegmentation.from_pretrained(_rmbg_checkpoint, trust_remote_code=True)
RMBG_model = RMBG_model.to("cuda")

# Preprocessing for RMBG inputs: resize to 1024x1024, convert to a tensor,
# then normalise with the per-channel mean/std below.
_rmbg_steps = [
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_rmbg_steps)
@spaces.GPU
def remove_bg(image):
    """Predict a foreground mask for *image* with the RMBG model.

    Args:
        image: PIL image of any mode. It is converted to RGB before inference.

    Returns:
        A PIL image built from the model's sigmoid output, resized back to the
        input image's width and height.
    """
    rgb_image = image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to("cuda")
    with torch.no_grad():
        # Keep the model's last output, map it through a sigmoid, move it to
        # the CPU, take the first batch item and drop singleton dimensions.
        last_output = RMBG_model(batch)[-1]
        probabilities = last_output.sigmoid().cpu()[0].squeeze()
    to_pil = transforms.ToPILImage()
    return to_pil(probabilities).resize(rgb_image.size)
class DreamblendGUI:
    """Gradio front end for DreamFuse.

    Users upload a background and a foreground, position and scale the
    background-removed foreground on an HTML canvas, then run the model.
    """

    # Filename suffix for each entry of the module-level ``generated_images``
    # list, in index order. Used by ``save_image``.
    _SAVE_SUFFIXES = (
        "_0_ori",
        "_0",
        "_1",
        "_2",
        "_0_mask",
        "_0_mask_scale",
        "_0_scale",
        "_2_pasted",
    )

    def __init__(self):
        """Load the bundled example pair and build the page CSS/JS."""
        example_paths = [
            ["./examples/9_02.png", "./examples/9_01.png"],
        ]
        # Each example is [background, foreground], matching the gr.Examples inputs.
        self.examples = [[Image.open(path) for path in pair] for pair in example_paths]
        self.css_style = self._get_css_style()
        self.js_script = self._get_js_script()

    def _get_css_style(self):
        """Return the page-level CSS passed to ``gr.Blocks``."""
        return """
body {
background: transparent;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
color: #fff;
}
.gradio-container {
max-width: 1200px;
margin: auto;
background: transparent;
border-radius: 10px;
padding: 20px;
box-shadow: 0px 2px 8px rgba(255,255,255,0.1);
}
h1, h2 {
text-align: center;
color: #fff;
}
#canvas_preview {
border: 2px dashed rgba(255,255,255,0.5);
padding: 10px;
background: transparent;
border-radius: 8px;
}
.gr-button {
background-color: #007bff;
border: none;
color: #fff;
padding: 10px 20px;
border-radius: 5px;
font-size: 16px;
cursor: pointer;
}
.gr-button:hover {
background-color: #0056b3;
}
#small-examples {
max-width: 200px !important;
width: 200px !important;
float: left;
margin-right: 20px;
}
"""

    def _get_js_script(self):
        """Return the JS run on page load, which defines the canvas helpers.

        ``window.updateTransformation`` writes the foreground's position, size
        and scale as JSON into the hidden ``#transformation_info`` textarea.

        ``globalThis.initializeDrag`` waits, via a MutationObserver, for a
        not-yet-bound preview canvas. It then attaches the drag and
        scale-slider handlers. Each canvas is tagged with
        ``data-drag-bound``, so a re-generate does not bind to the stale
        canvas that is still in the DOM.
        """
        return r"""
async () => {
window.updateTransformation = function() {
const img = document.getElementById('draggable-img');
const container = document.getElementById('canvas-container');
if (!img || !container) return;
const left = parseFloat(img.style.left) || 0;
const top = parseFloat(img.style.top) || 0;
const canvasSize = 400;
const data_original_width = parseFloat(img.getAttribute('data-original-width'));
const data_original_height = parseFloat(img.getAttribute('data-original-height'));
const bgWidth = parseFloat(container.dataset.bgWidth);
const bgHeight = parseFloat(container.dataset.bgHeight);
const scale_ratio = img.clientWidth / data_original_width;
const transformation = {
drag_left: left,
drag_top: top,
drag_width: img.clientWidth,
drag_height: img.clientHeight,
data_original_width: data_original_width,
data_original_height: data_original_height,
scale_ratio: scale_ratio
};
const transInput = document.querySelector("#transformation_info textarea");
if(transInput){
const newValue = JSON.stringify(transformation);
const nativeSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, 'value').set;
nativeSetter.call(transInput, newValue);
transInput.dispatchEvent(new Event('input', { bubbles: true }));
console.log("Transformation info updated: ", newValue);
} else {
console.log("找不到 transformation_info 的 textarea 元素");
}
};
globalThis.initializeDrag = () => {
console.log("初始化拖拽与缩放功能...");
const observer = new MutationObserver(() => {
const img = document.getElementById('draggable-img');
const container = document.getElementById('canvas-container');
const slider = document.getElementById('scale-slider');
// Ignore a canvas that is already bound: on re-generate the old one stays in the DOM until replaced.
if (img && container && slider && !img.dataset.dragBound) {
observer.disconnect();
img.dataset.dragBound = "1";
console.log("绑定拖拽与缩放事件...");
img.ondragstart = (e) => { e.preventDefault(); return false; };
let offsetX = 0, offsetY = 0;
let isDragging = false;
let scaleAnchor = null;
img.addEventListener('mousedown', (e) => {
isDragging = true;
img.style.cursor = 'grabbing';
const imgRect = img.getBoundingClientRect();
offsetX = e.clientX - imgRect.left;
offsetY = e.clientY - imgRect.top;
// Keep the image where it is on screen while dropping the CSS centering transform.
const startRect = container.getBoundingClientRect();
img.style.transform = "none";
img.style.left = (imgRect.left - startRect.left - container.clientLeft) + "px";
img.style.top = (imgRect.top - startRect.top - container.clientTop) + "px";
console.log("mousedown: left=", img.style.left, "top=", img.style.top);
});
document.addEventListener('mousemove', (e) => {
if (!isDragging) return;
e.preventDefault();
const containerRect = container.getBoundingClientRect();
// 计算当前拖拽后的坐标(基于容器)
let left = e.clientX - containerRect.left - offsetX;
let top = e.clientY - containerRect.top - offsetY;
// 允许的拖拽范围:
// 水平方向允许最少超出图像一半:最小值为 -img.clientWidth * (7/8)
// 水平方向允许最多超出一半:最大值为 containerRect.width - img.clientWidth * (1/8)
const minLeft = -img.clientWidth * (7/8);
const maxLeft = containerRect.width - img.clientWidth * (1/8);
// 垂直方向允许范围:
// 最小值为 -img.clientHeight * (7/8)
// 最大值为 containerRect.height - img.clientHeight * (1/8)
const minTop = -img.clientHeight * (7/8);
const maxTop = containerRect.height - img.clientHeight * (1/8);
// 限制范围
if (left < minLeft) left = minLeft;
if (left > maxLeft) left = maxLeft;
if (top < minTop) top = minTop;
if (top > maxTop) top = maxTop;
img.style.left = left + "px";
img.style.top = top + "px";
});
window.addEventListener('mouseup', (e) => {
if (isDragging) {
isDragging = false;
img.style.cursor = 'grab';
const containerRect = container.getBoundingClientRect();
const bgWidth = parseFloat(container.dataset.bgWidth);
const bgHeight = parseFloat(container.dataset.bgHeight);
const offsetLeft = (containerRect.width - bgWidth) / 2;
const offsetTop = (containerRect.height - bgHeight) / 2;
const absoluteLeft = parseFloat(img.style.left);
const absoluteTop = parseFloat(img.style.top);
const relativeX = absoluteLeft - offsetLeft;
const relativeY = absoluteTop - offsetTop;
document.getElementById("coordinate").textContent =
`前景图坐标: (x=${relativeX.toFixed(2)}, y=${relativeY.toFixed(2)})`;
updateTransformation();
}
scaleAnchor = null;
});
slider.addEventListener('mousedown', (e) => {
const containerRect = container.getBoundingClientRect();
const imgRect = img.getBoundingClientRect();
scaleAnchor = {
x: imgRect.left + imgRect.width/2 - containerRect.left,
y: imgRect.top + imgRect.height/2 - containerRect.top
};
console.log("Slider mousedown, captured scaleAnchor: ", scaleAnchor);
});
slider.addEventListener('input', (e) => {
const scale = parseFloat(e.target.value);
const originalWidth = parseFloat(img.getAttribute('data-original-width'));
const originalHeight = parseFloat(img.getAttribute('data-original-height'));
const newWidth = originalWidth * scale;
const newHeight = originalHeight * scale;
const containerRect = container.getBoundingClientRect();
let centerX, centerY;
if (scaleAnchor) {
centerX = scaleAnchor.x;
centerY = scaleAnchor.y;
} else {
const imgRect = img.getBoundingClientRect();
centerX = imgRect.left + imgRect.width/2 - containerRect.left;
centerY = imgRect.top + imgRect.height/2 - containerRect.top;
}
const newLeft = centerX - newWidth/2;
const newTop = centerY - newHeight/2;
// Drop the initial translate(-50%,-50%) so left/top match the on-screen position.
img.style.transform = "none";
img.style.width = newWidth + "px";
img.style.height = newHeight + "px";
img.style.left = newLeft + "px";
img.style.top = newTop + "px";
console.log("slider: scale=", scale, "newWidth=", newWidth, "newHeight=", newHeight);
updateTransformation();
});
slider.addEventListener('mouseup', (e) => {
scaleAnchor = null;
});
}
});
observer.observe(document.body, { childList: true, subtree: true });
};
}
"""

    def get_next_sequence(self, folder_path):
        """Return the next free zero-padded numeric filename prefix in *folder_path*.

        A file counts when the part of its name before the first ``_`` is
        entirely decimal digits (e.g. ``007_0.png``). The result is one past
        the largest such number, formatted with at least three digits. An
        empty folder gives ``"000"``.

        Raises:
            FileNotFoundError: if *folder_path* does not exist.
        """
        prefixes = (name.split('_')[0] for name in os.listdir(folder_path))
        # isdecimal (not isdigit) so chars like "²" cannot reach int() and raise.
        max_sequence = max((int(p) for p in prefixes if p.isdecimal()), default=-1)
        return f"{max_sequence + 1:03d}"

    def pil_to_base64(self, img):
        """Encode a PIL image as a PNG ``data:`` URL, keeping the alpha channel.

        Returns an empty string for ``None``. Non-RGBA images are converted to
        RGBA before encoding.
        """
        if img is None:
            return ""
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        buffered = BytesIO()
        img.save(buffered, format="PNG", optimize=True)
        base64_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{base64_str}"

    @staticmethod
    def _fit_within(img, max_size, allow_upscale):
        """Scale *img*, keeping its aspect ratio, so its longest side fits *max_size*.

        Larger images are always shrunk. Smaller ones are enlarged only when
        *allow_upscale* is true. Returns ``None`` for ``None``.
        """
        if img is None:
            return None
        w, h = img.size
        ratio = min(max_size / w, max_size / h)
        if ratio < 1 or (allow_upscale and ratio > 1):
            # max(1, ...) keeps very thin images from collapsing to a 0-px side.
            new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
            img = img.resize(new_size, Image.LANCZOS)
        return img

    def resize_background_image(self, img, max_size=400):
        """Scale the background, keeping its aspect ratio, so its longest side equals *max_size*.

        Small backgrounds are enlarged as well as large ones shrunk. This
        matches the canvas's ``background-size: contain``, so the
        ``data-bg-*`` sizes reported to the JS are the displayed sizes.
        """
        return self._fit_within(img, max_size, allow_upscale=True)

    def resize_draggable_image(self, img, max_size=400):
        """Shrink the foreground, keeping its aspect ratio, so its longest side is at most *max_size*."""
        return self._fit_within(img, max_size, allow_upscale=False)

    def generate_html(self, background_img_b64, bg_width, bg_height, draggable_img_b64, draggable_width, draggable_height, canvas_size=400):
        """Return the drag/scale preview page.

        The page contains:
        - a ``canvas_size`` x ``canvas_size`` container showing the background
          with ``background-size: contain``;
        - an absolutely positioned, initially centred foreground ``<img>``;
        - a 0.1-2x scale slider.

        The background and foreground sizes are written to ``data-*``
        attributes, which the page JS reads.
        """
        html_code = f"""
<html>
<head>
<style>
body {{
margin: 0;
padding: 0;
text-align: center;
font-family: sans-serif;
background: transparent;
color: #fff;
}}
h2 {{
margin-top: 1rem;
}}
#scale-control {{
margin: 1rem auto;
width: {canvas_size}px;
text-align: left;
}}
#scale-control label {{
font-size: 1rem;
margin-right: 0.5rem;
}}
#canvas-container {{
position: relative;
width: {canvas_size}px;
height: {canvas_size}px;
margin: 0 auto;
border: 1px dashed rgba(255,255,255,0.5);
overflow: hidden;
background-image: url('{background_img_b64}');
background-repeat: no-repeat;
background-position: center;
background-size: contain;
border-radius: 8px;
}}
#draggable-img {{
position: absolute;
cursor: grab;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
background-color: transparent;
}}
#coordinate {{
color: #fff;
margin-top: 1rem;
font-weight: bold;
}}
</style>
</head>
<body>
<h2>拖拽前景图(支持缩放)</h2>
<div id="scale-control">
<label for="scale-slider">前景图缩放:</label>
<input type="range" id="scale-slider" min="0.1" max="2" step="0.01" value="1">
</div>
<div id="canvas-container" data-bg-width="{bg_width}" data-bg-height="{bg_height}">
<img id="draggable-img"
src="{draggable_img_b64}"
alt="Draggable Image"
draggable="false"
data-original-width="{draggable_width}"
data-original-height="{draggable_height}"
/>
</div>
<p id="coordinate">前景图坐标: (x=?, y=?)</p>
</body>
</html>
"""
        return html_code

    def on_upload(self, background_img, draggable_img):
        """Build the preview canvas from the two uploaded images.

        Processing steps:
        - convert the foreground to RGB;
        - set its alpha to the mask from ``remove_bg``;
        - fit both images to the 400 px canvas for the preview.

        Returns:
            ``(html, foreground)``. ``html`` is the preview page and
            ``foreground`` is the un-resized RGBA foreground, stored in the
            ``modified_fg_state`` output. When either upload is missing the
            result is ``(error_html, None)``, which still matches the two
            outputs the click handler declares.
        """
        if background_img is None or draggable_img is None:
            return "<p style='color:red;'>请先上传背景图片和可拖拽图片。</p>", None
        if draggable_img.mode != "RGB":
            draggable_img = draggable_img.convert("RGB")
        # Replace any uploaded transparency with the predicted foreground mask.
        alpha_channel = remove_bg(draggable_img).convert("L")
        draggable_img = draggable_img.convert("RGBA")
        draggable_img.putalpha(alpha_channel)
        resized_bg = self.resize_background_image(background_img, max_size=400)
        bg_w, bg_h = resized_bg.size
        resized_fg = self.resize_draggable_image(draggable_img, max_size=400)
        draggable_width, draggable_height = resized_fg.size
        html = self.generate_html(
            self.pil_to_base64(resized_bg), bg_w, bg_h,
            self.pil_to_base64(resized_fg), draggable_width, draggable_height,
            canvas_size=400,
        )
        return html, draggable_img

    def save_image(self, save_path="/mnt/bn/hjj-humanseg-lq/SubjectDriven/DreamFuse/debug"):
        """Save the module-level ``generated_images`` into *save_path*.

        This is an internal debugging helper for the "保存图片 (内部测试)"
        button. Each file is named ``<seq><suffix>.png``, where ``<seq>``
        comes from ``get_next_sequence`` and the suffixes are listed in
        ``_SAVE_SUFFIXES``. The directory is created if it is missing.

        Prints a message and returns without saving when fewer images than
        suffixes are available or the directory cannot be created, rather
        than raising inside the Gradio callback.
        """
        images = generated_images
        expected = len(self._SAVE_SUFFIXES)
        if len(images) < expected:
            print(f"save_image: expected {expected} generated images, got {len(images)}; nothing saved.")
            return
        try:
            os.makedirs(save_path, exist_ok=True)
        except OSError as err:
            print(f"save_image: cannot create {save_path!r}: {err}")
            return
        save_name = self.get_next_sequence(save_path)
        for image, suffix in zip(images, self._SAVE_SUFFIXES):
            image.save(os.path.join(save_path, f"{save_name}{suffix}.png"))

    def create_gui(self):
        """Build and return the Gradio Blocks app, including the DreamFuse pipeline."""
        config = InferenceConfig()
        config.lora_id = 'LL3RD/DreamFuse'
        pipeline = DreamFuseInference(config)
        # Request a 120 s ZeroGPU allocation for each generation call.
        pipeline.gradio_generate = spaces.GPU(duration=120)(pipeline.gradio_generate)
        with gr.Blocks(css=self.css_style) as demo:
            # Holds the background-removed RGBA foreground returned by on_upload.
            modified_fg_state = gr.State()
            gr.Markdown("# Dreamblend-GUI-dirtydata")
            gr.Markdown("通过上传背景图与前景图生成带有可拖拽/缩放预览的合成图像,同时支持 Seed 设置和 Prompt 文本输入。")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 上传图片")
                    background_img_in = gr.Image(label="背景图片", type="pil", height=240, width=240)
                    draggable_img_in = gr.Image(label="前景图片", type="pil", image_mode="RGBA", height=240, width=240)
                    generate_btn = gr.Button("生成可拖拽画布")
                    with gr.Row():
                        gr.Examples(
                            examples=[self.examples[0]],
                            inputs=[background_img_in, draggable_img_in],
                            elem_id="small-examples"
                        )
                with gr.Column(scale=1):
                    gr.Markdown("### 预览区域")
                    html_out = gr.HTML(label="预览与拖拽", elem_id="canvas_preview")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### 参数设置")
                    seed_slider = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    cfg_slider = gr.Slider(minimum=1, maximum=10, step=0.1, label="CFG", value=3.5)
                    size_select = gr.Radio(
                        choices=["512", "768", "1024"],
                        value="512",
                        label="生成质量(512-差 1024-好)",
                    )
                    prompt_text = gr.Textbox(label="Prompt", placeholder="输入文本提示", value="")
                    text_strength = gr.Slider(minimum=1, maximum=10, step=1, label="Text Strength", value=1)
                    enable_gui = gr.Checkbox(label="启用GUI", value=True)
                    enable_truecfg = gr.Checkbox(label="启用TrueCFG", value=False)
                    enable_save = gr.Button("保存图片 (内部测试)", visible=True)
                with gr.Column(scale=1):
                    gr.Markdown("### 模型生成结果")
                    model_generate_btn = gr.Button("模型生成")
                    # Filled by the canvas JS (updateTransformation) with the foreground placement as JSON.
                    transformation_text = gr.Textbox(label="Transformation Info", elem_id="transformation_info", visible=False)
                    model_output = gr.Image(label="模型输出", type="pil")
            # Event wiring.
            enable_save.click(fn=self.save_image, inputs=None, outputs=None)
            generate_btn.click(
                fn=self.on_upload,
                inputs=[background_img_in, draggable_img_in],
                outputs=[html_out, modified_fg_state],
            )
            model_generate_btn.click(
                fn=pipeline.gradio_generate,
                inputs=[background_img_in, modified_fg_state, transformation_text, seed_slider,
                        prompt_text, enable_gui, cfg_slider, size_select, text_strength, enable_truecfg],
                outputs=model_output
            )
            # Define the JS helpers on page load; bind them to each newly generated canvas.
            demo.load(None, None, None, js=self.js_script)
            generate_btn.click(fn=None, inputs=None, outputs=None, js="initializeDrag")
        return demo
if __name__ == "__main__":
    # Build the interface, enable request queuing, then serve it.
    gui = DreamblendGUI()
    demo = gui.create_gui()
    demo.queue()
    demo.launch()
    # Local-run alternatives:
    #   demo.launch(server_port=7789, ssr_mode=False)
    #   demo.launch(server_name="[::]", share=True)