Dreamspire committed
Commit 6a1c163 · 1 Parent(s): de76e75

new app.py

Files changed (2):
  1. app.py +523 -246
  2. app.py.bak +340 -0
app.py CHANGED
@@ -4,337 +4,614 @@ import sys
4
  from typing import Sequence, Mapping, Any, Union
5
  import torch
6
  import gradio as gr
7
- from PIL import Image
8
- from huggingface_hub import hf_hub_download
9
- import spaces
10
- from comfy import model_management
11
-
12
- hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
13
- hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
14
- hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
15
- hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
16
- hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
17
- hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
18
- t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")
19
-
20
- # Import all the necessary functions from the original script
21
  def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
22
  try:
23
  return obj[index]
24
  except KeyError:
25
  return obj["result"][index]
26
 
27
- # Add all the necessary setup functions from the original script
28
  def find_path(name: str, path: str = None) -> str:
29
  if path is None:
30
  path = os.getcwd()
 
 
31
  if name in os.listdir(path):
32
  path_name = os.path.join(path, name)
33
  print(f"{name} found: {path_name}")
34
  return path_name
 
 
35
  parent_directory = os.path.dirname(path)
 
 
36
  if parent_directory == path:
37
  return None
 
 
38
  return find_path(name, parent_directory)
39
 
 
40
  def add_comfyui_directory_to_sys_path() -> None:
 
 
 
41
  comfyui_path = find_path("ComfyUI")
42
  if comfyui_path is not None and os.path.isdir(comfyui_path):
43
  sys.path.append(comfyui_path)
44
  print(f"'{comfyui_path}' added to sys.path")
45
 
 
46
  def add_extra_model_paths() -> None:
 
 
 
47
  try:
48
  from main import load_extra_path_config
49
  except ImportError:
 
 
 
50
  from utils.extra_config import load_extra_path_config
 
51
  extra_model_paths = find_path("extra_model_paths.yaml")
 
52
  if extra_model_paths is not None:
53
  load_extra_path_config(extra_model_paths)
54
  else:
55
  print("Could not find the extra_model_paths config file.")
56
 
57
- # Initialize paths
58
  add_comfyui_directory_to_sys_path()
59
  add_extra_model_paths()
60
 
 
 
61
  def import_custom_nodes() -> None:
 
 
 
 
62
  import asyncio
63
  import execution
 
 
 
 
64
  from nodes import init_extra_nodes
65
  import server
66
  loop = asyncio.new_event_loop()
67
  asyncio.set_event_loop(loop)
68
  server_instance = server.PromptServer(loop)
69
  execution.PromptQueue(server_instance)
 
 
70
  init_extra_nodes()
 
71
 
72
- # Import all necessary nodes
73
- from nodes import (
74
- StyleModelLoader,
75
- VAEEncode,
76
- NODE_CLASS_MAPPINGS,
77
- LoadImage,
78
- CLIPVisionLoader,
79
- SaveImage,
80
- VAELoader,
81
- CLIPVisionEncode,
82
- DualCLIPLoader,
83
- EmptyLatentImage,
84
- VAEDecode,
85
- UNETLoader,
86
- CLIPTextEncode,
87
- )
88
 
89
- # Initialize all constant nodes and models in global context
 
90
  import_custom_nodes()
91
 
92
- # Global variables for preloaded models and constants
93
- #with torch.inference_mode():
94
- # Initialize constants
95
- intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
96
- CONST_1024 = intconstant.get_value(value=1024)
97
-
98
- # Load CLIP
99
- dualcliploader = DualCLIPLoader()
100
- CLIP_MODEL = dualcliploader.load_clip(
101
- clip_name1="t5/t5xxl_fp16.safetensors",
102
- clip_name2="clip_l.safetensors",
103
- type="flux",
104
  )
 
105
 
106
- # Load VAE
107
- vaeloader = VAELoader()
108
- VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
109
 
110
- # Load UNET
111
- unetloader = UNETLoader()
112
- UNET_MODEL = unetloader.load_unet(
113
- unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
114
  )
115
 
116
- # Load CLIP Vision
117
- clipvisionloader = CLIPVisionLoader()
118
- CLIP_VISION_MODEL = clipvisionloader.load_clip(
119
- clip_name="sigclip_vision_patch14_384.safetensors"
120
  )
121
 
122
- # Load Style Model
123
- stylemodelloader = StyleModelLoader()
124
- STYLE_MODEL = stylemodelloader.load_style_model(
125
- style_model_name="flux1-redux-dev.safetensors"
126
  )
127
 
128
- # Initialize samplers
129
- ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
130
- SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
131
 
132
- # Initialize depth model
133
- cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
134
- downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
135
- DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
136
- model="depth_anything_v2_vitl_fp32.safetensors"
137
  )
138
 
139
- cliptextencode = CLIPTextEncode()
140
- loadimage = LoadImage()
141
- vaeencode = VAEEncode()
142
- fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
143
- instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
144
- clipvisionencode = CLIPVisionEncode()
145
- stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
146
- emptylatentimage = EmptyLatentImage()
147
- basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
148
- basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
149
- randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
150
- samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
151
- vaedecode = VAEDecode()
152
- cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
153
- saveimage = SaveImage()
154
- getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
155
- depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
156
- imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
157
-
158
- model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
159
-
160
- model_management.load_models_gpu([
161
- loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
162
- ])
163
-
164
- @spaces.GPU
165
- def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5) -> str:
166
- """Main generation function that processes inputs and returns the path to the generated image."""
167
  with torch.inference_mode():
168
- # Set up CLIP
169
- clip_switch = cr_clip_input_switch.switch(
170
- Input=1,
171
- clip1=get_value_at_index(CLIP_MODEL, 0),
172
- clip2=get_value_at_index(CLIP_MODEL, 0),
173
  )
174
-
175
- # Encode text
176
- text_encoded = cliptextencode.encode(
177
- text=prompt,
178
- clip=get_value_at_index(clip_switch, 0),
179
  )
180
- empty_text = cliptextencode.encode(
181
- text="",
182
- clip=get_value_at_index(clip_switch, 0),
 
183
  )
184
-
185
- # Process structure image
186
- structure_img = loadimage.load_image(image=structure_image)
187
-
188
- # Resize image
189
- resized_img = imageresize.execute(
190
- width=get_value_at_index(CONST_1024, 0),
191
- height=get_value_at_index(CONST_1024, 0),
192
- interpolation="bicubic",
193
- method="keep proportion",
194
- condition="always",
195
- multiple_of=16,
196
- image=get_value_at_index(structure_img, 0),
197
  )
198
-
199
- # Get image size
200
- size_info = getimagesizeandcount.getsize(
201
- image=get_value_at_index(resized_img, 0)
202
  )
203
-
204
- # Encode VAE
205
- vae_encoded = vaeencode.encode(
206
- pixels=get_value_at_index(size_info, 0),
207
- vae=get_value_at_index(VAE_MODEL, 0),
208
  )
209
-
210
- # Process depth
211
- depth_processed = depthanything_v2.process(
212
- da_model=get_value_at_index(DEPTH_MODEL, 0),
213
- images=get_value_at_index(size_info, 0),
214
  )
215
-
216
- # Apply Flux guidance
217
- flux_guided = fluxguidance.append(
218
- guidance=depth_strength,
219
- conditioning=get_value_at_index(text_encoded, 0),
220
  )
221
-
222
- # Process style image
223
- style_img = loadimage.load_image(image=style_image)
224
-
225
- # Encode style with CLIP Vision
226
- style_encoded = clipvisionencode.encode(
227
- crop="center",
228
- clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
229
- image=get_value_at_index(style_img, 0),
230
  )
231
-
232
- # Set up conditioning
233
- conditioning = instructpixtopixconditioning.encode(
234
- positive=get_value_at_index(flux_guided, 0),
235
- negative=get_value_at_index(empty_text, 0),
236
- vae=get_value_at_index(VAE_MODEL, 0),
237
- pixels=get_value_at_index(depth_processed, 0),
238
  )
239
-
240
- # Apply style
241
- style_applied = stylemodelapplyadvanced.apply_stylemodel(
242
- strength=style_strength,
243
- conditioning=get_value_at_index(conditioning, 0),
244
- style_model=get_value_at_index(STYLE_MODEL, 0),
245
- clip_vision_output=get_value_at_index(style_encoded, 0),
 
 
246
  )
247
-
248
- # Set up empty latent
249
- empty_latent = emptylatentimage.generate(
250
- width=get_value_at_index(resized_img, 1),
251
- height=get_value_at_index(resized_img, 2),
252
- batch_size=1,
253
  )
254
-
255
- # Set up guidance
256
- guided = basicguider.get_guider(
257
- model=get_value_at_index(UNET_MODEL, 0),
258
- conditioning=get_value_at_index(style_applied, 0),
259
  )
260
-
261
- # Set up scheduler
262
- schedule = basicscheduler.get_sigmas(
263
- scheduler="simple",
264
- steps=28,
265
  denoise=1,
266
- model=get_value_at_index(UNET_MODEL, 0),
 
 
 
267
  )
268
-
269
- # Generate random noise
270
- noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
271
-
272
- # Sample
273
- sampled = samplercustomadvanced.sample(
274
- noise=get_value_at_index(noise, 0),
275
- guider=get_value_at_index(guided, 0),
276
- sampler=get_value_at_index(SAMPLER, 0),
277
- sigmas=get_value_at_index(schedule, 0),
278
- latent_image=get_value_at_index(empty_latent, 0),
279
  )
280
-
281
- # Decode VAE
282
- decoded = vaedecode.decode(
283
- samples=get_value_at_index(sampled, 0),
284
- vae=get_value_at_index(VAE_MODEL, 0),
285
  )
286
-
287
- # Save image
288
- prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
289
-
290
- saved = saveimage.save_images(
291
- filename_prefix=get_value_at_index(prefix, 0),
292
- images=get_value_at_index(decoded, 0),
293
  )
294
- saved_path = f"output/{saved['ui']['images'][0]['filename']}"
295
  return saved_path
296
 
297
- # Create Gradio interface
298
-
299
- examples = [
300
- ["", "mona.png", "receita-tacos.webp", 15, 0.6],
301
- ["a woman looking at a house catching fire on the background", "disaster_girl.png", "abaporu.jpg", 15, 0.15],
302
- ["istanbul aerial, dramatic photography", "natasha.png", "istambul.jpg", 15, 0.5],
303
- ]
304
-
305
- output_image = gr.Image(label="Generated Image")
306
-
307
- with gr.Blocks() as app:
308
- gr.Markdown("# FLUX Style Shaping")
309
- gr.Markdown("Flux[dev] Redux + Flux[dev] Depth ComfyUI workflow by [Nathan Shipley](https://x.com/CitizenPlain) running directly on Gradio. [workflow](https://gist.github.com/nathanshipley/7a9ac1901adde76feebe58d558026f68) - [how to convert your any comfy workflow to gradio](https://huggingface.co/blog/run-comfyui-workflows-on-spaces)")
310
- with gr.Row():
311
- with gr.Column():
312
- prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
313
- with gr.Row():
314
- with gr.Group():
315
- structure_image = gr.Image(label="Structure Image", type="filepath")
316
- depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
317
- with gr.Group():
318
- style_image = gr.Image(label="Style Image", type="filepath")
319
- style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
320
  generate_btn = gr.Button("Generate")
321
-
322
- gr.Examples(
323
- examples=examples,
324
- inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
325
- outputs=[output_image],
326
- fn=generate_image,
327
- cache_examples=True,
328
- cache_mode="lazy"
329
- )
330
 
331
- with gr.Column():
332
- output_image.render()
333
- generate_btn.click(
334
- fn=generate_image,
335
- inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
336
- outputs=[output_image]
337
- )
338
 
339
- if __name__ == "__main__":
340
- app.launch(share=True)
 
4
  from typing import Sequence, Mapping, Any, Union
5
  import torch
6
  import gradio as gr
7
+
8
+
9
  def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
10
+ """Returns the value at the given index of a sequence or mapping.
11
+
12
+ If the object is a sequence (like list or string), returns the value at the given index.
13
+ If the object is a mapping (like a dictionary), returns the value at the index-th key.
14
+
15
+ Some nodes return a dictionary; in these cases, we look for the "result" key.
16
+
17
+ Args:
18
+ obj (Union[Sequence, Mapping]): The object to retrieve the value from.
19
+ index (int): The index of the value to retrieve.
20
+
21
+ Returns:
22
+ Any: The value at the given index.
23
+
24
+ Raises:
25
+ IndexError: If the index is out of bounds for the object and the object is not a mapping.
26
+ """
27
  try:
28
  return obj[index]
29
  except KeyError:
30
  return obj["result"][index]
31
 
32
+
33
  def find_path(name: str, path: str = None) -> str:
34
+ """
35
+ Recursively looks at parent folders starting from the given path until it finds the given name.
36
+ Returns the path as a Path object if found, or None otherwise.
37
+ """
38
+ # If no path is given, use the current working directory
39
  if path is None:
40
  path = os.getcwd()
41
+
42
+ # Check if the current directory contains the name
43
  if name in os.listdir(path):
44
  path_name = os.path.join(path, name)
45
  print(f"{name} found: {path_name}")
46
  return path_name
47
+
48
+ # Get the parent directory
49
  parent_directory = os.path.dirname(path)
50
+
51
+ # If the parent directory is the same as the current directory, we've reached the root and stop the search
52
  if parent_directory == path:
53
  return None
54
+
55
+ # Recursively call the function with the parent directory
56
  return find_path(name, parent_directory)
57
 
58
+
59
  def add_comfyui_directory_to_sys_path() -> None:
60
+ """
61
+ Add 'ComfyUI' to the sys.path
62
+ """
63
  comfyui_path = find_path("ComfyUI")
64
  if comfyui_path is not None and os.path.isdir(comfyui_path):
65
  sys.path.append(comfyui_path)
66
  print(f"'{comfyui_path}' added to sys.path")
67
 
68
+
69
  def add_extra_model_paths() -> None:
70
+ """
71
+ Parse the optional extra_model_paths.yaml file and add the parsed paths to the sys.path.
72
+ """
73
  try:
74
  from main import load_extra_path_config
75
  except ImportError:
76
+ print(
77
+ "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead."
78
+ )
79
  from utils.extra_config import load_extra_path_config
80
+
81
  extra_model_paths = find_path("extra_model_paths.yaml")
82
+
83
  if extra_model_paths is not None:
84
  load_extra_path_config(extra_model_paths)
85
  else:
86
  print("Could not find the extra_model_paths config file.")
87
 
88
+
89
  add_comfyui_directory_to_sys_path()
90
  add_extra_model_paths()
91
 
92
+
93
+ # MODIFIED FUNCTION - THE CORE FIX IS HERE
94
  def import_custom_nodes() -> None:
95
+ """
96
+ This function now correctly mimics the necessary parts of ComfyUI's startup
97
+ to ensure all paths and nodes are initialized.
98
+ """
99
  import asyncio
100
  import execution
101
+
102
+
103
+ # Crucially, import the main module to access its functions
104
+ import main as comfyui_main
105
  from nodes import init_extra_nodes
106
  import server
107
+ # 1. Apply paths from extra_model_paths.yaml and command-line args (if any)
108
+ # This is the step that was missing and caused the 'model_paths' error.
109
+ comfyui_main.apply_custom_paths()
110
+ print("Applied custom paths from extra_model_paths.yaml")
111
+
112
+ # 2. Initialize the server and queue (needed as a dependency for some nodes)
113
+ # We create a new loop each time, as per the original request to keep logic inside the function.
114
  loop = asyncio.new_event_loop()
115
  asyncio.set_event_loop(loop)
116
  server_instance = server.PromptServer(loop)
117
  execution.PromptQueue(server_instance)
118
+
119
+ # 3. Initialize the custom nodes. This will now work because paths are set.
120
  init_extra_nodes()
121
+ print("Custom nodes initialized.")
122
 
123
 
124
+ from nodes import NODE_CLASS_MAPPINGS
125
+
126
  import_custom_nodes()
127
 
128
+
129
+ if "Florence2ModelLoader" in NODE_CLASS_MAPPINGS:
130
+ print("Manually initializing Florence2ModelLoader.INPUT_TYPES() to populate model paths.")
131
+ florence_class = NODE_CLASS_MAPPINGS["Florence2ModelLoader"]
132
+ florence_class.INPUT_TYPES()
133
+ # =========================================================================
134
+
135
+ checkpointloadersimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
136
+ checkpointloadersimple_50 = checkpointloadersimple.load_checkpoint(
137
+ ckpt_name="SD1.5/dreamshaper_8.safetensors"
138
+ )
139
+
140
+ cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
141
+
142
+ controlnetloader = NODE_CLASS_MAPPINGS["ControlNetLoader"]()
143
+ controlnetloader_73 = controlnetloader.load_controlnet(
144
+ control_net_name="SD1.5/control_v11p_sd15_openpose.pth"
145
+ )
146
+
147
+ loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
148
+
149
+
150
+ florence2modelloader = NODE_CLASS_MAPPINGS["Florence2ModelLoader"]()
151
+ florence2modelloader_204 = florence2modelloader.loadmodel(
152
+ model="Florence-2-base",
153
+ precision="fp16",
154
+ attention="sdpa",
155
+ convert_to_safetensors=False,
156
  )
157
+ florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]()
158
 
 
 
 
159
 
160
+ checkpointloadersimple_319 = checkpointloadersimple.load_checkpoint(
161
+ ckpt_name="SD1.5/dreamshaper_8Inpainting.safetensors"
 
 
162
  )
163
 
164
+ loraloader = NODE_CLASS_MAPPINGS["LoraLoader"]()
165
+ loraloader_338 = loraloader.load_lora(
166
+ lora_name="add_detail.safetensors",
167
+ strength_model=1,
168
+ strength_clip=1,
169
+ model=get_value_at_index(checkpointloadersimple_319, 0),
170
+ clip=get_value_at_index(checkpointloadersimple_319, 1),
171
+ )
172
+
173
+
174
+
175
+ loraloader_353 = loraloader.load_lora(
176
+ lora_name="BaldifierW2.safetensors",
177
+ strength_model=2,
178
+ strength_clip=1,
179
+ model=get_value_at_index(loraloader_338, 0),
180
+ clip=get_value_at_index(loraloader_338, 1),
181
  )
182
 
183
+ controlnetloader_389 = controlnetloader.load_controlnet(
184
+ control_net_name="SD1.5/control_v11p_sd15_openpose.pth"
 
 
185
  )
186
 
187
+ dwpreprocessor = NODE_CLASS_MAPPINGS["DWPreprocessor"]()
188
+ controlnetapplyadvanced = NODE_CLASS_MAPPINGS["ControlNetApplyAdvanced"]()
189
+
190
+ layerutility_imagescalebyaspectratio_v2 = NODE_CLASS_MAPPINGS[
191
+ "LayerUtility: ImageScaleByAspectRatio V2"
192
+ ]()
193
+
194
+
195
+ layermask_personmaskultra_v2 = NODE_CLASS_MAPPINGS[
196
+ "LayerMask: PersonMaskUltra V2"
197
+ ]()
198
+
199
+ growmask = NODE_CLASS_MAPPINGS["GrowMask"]()
200
 
201
+ inpaintmodelconditioning = NODE_CLASS_MAPPINGS["InpaintModelConditioning"]()
202
+ ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
203
+
204
+
205
+ vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
206
+ vaeencode = NODE_CLASS_MAPPINGS["VAEEncode"]()
207
+ faceanalysismodels = NODE_CLASS_MAPPINGS["FaceAnalysisModels"]()
208
+ faceanalysismodels_506 = faceanalysismodels.load_models(
209
+ library="insightface", provider="CPU"
210
+ )
211
+ upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
212
+ upscalemodelloader_835 = upscalemodelloader.load_model(
213
+ model_name="4x-UltraSharp.pth"
214
  )
215
+ ipadapterunifiedloader = NODE_CLASS_MAPPINGS["IPAdapterUnifiedLoader"]()
216
+ ipadapteradvanced = NODE_CLASS_MAPPINGS["IPAdapterAdvanced"]()
217
+ facesegmentation = NODE_CLASS_MAPPINGS["FaceSegmentation"]()
218
+ layerutility_imageblend_v2 = NODE_CLASS_MAPPINGS[
219
+ "LayerUtility: ImageBlend V2"
220
+ ]()
221
+ image_comparer_rgthree = NODE_CLASS_MAPPINGS["Image Comparer (rgthree)"]()
222
+ saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
223
+ imageupscalewithmodel = NODE_CLASS_MAPPINGS["ImageUpscaleWithModel"]()
224
+
225
+
226
+ # def main():
227
+ def generate_image(model_image, hairstyle_template_image):
228
+
229
+
230
+
231
 
232
  with torch.inference_mode():
233
+ cliptextencode_52 = cliptextencode.encode(
234
+ text="multiple_hands, multiple_legs, multiple_girls\nlow quality, blurry, out of focus, distorted, unrealistic, extra limbs, missing limbs, deformed hands, deformed fingers, extra fingers, long neck, unnatural face, bad anatomy, bad proportions, poorly drawn face, poorly drawn eyes, asymmetrical eyes, extra eyes, extra head, floating objects, watermark, text, logo, jpeg artifacts, overexposed, underexposed, harsh lighting, bad posture, strange angles, unnatural expressions, oversaturated colors, messy hair, unrealistic skin texture, wrinkles inappropriately placed, incorrect shading, pixelation, complex background, busy background, detailed background, crowded scene, clutter, messy elements, unnecessary objects, overlapping objects, intricate patterns, vibrant colors, high contrast, graffiti, shadows, reflections, multiple layers, unrealistic lighting, overexposed areas.",
235
+ clip=get_value_at_index(checkpointloadersimple_50, 1),
 
 
236
  )
237
+
238
+
239
+
240
+
241
+ loadimage_144 = loadimage.load_image(image=hairstyle_template_image)
242
+
243
+
244
+
245
+ florence2run_203 = florence2run.encode(
246
+ text_input="",
247
+ task="more_detailed_caption",
248
+ fill_mask=True,
249
+ keep_model_loaded=False,
250
+ max_new_tokens=1024,
251
+ num_beams=3,
252
+ do_sample=True,
253
+ output_mask_select="",
254
+ seed=random.randint(1, 2**64),
255
+ image=get_value_at_index(loadimage_144, 0),
256
+ florence2_model=get_value_at_index(florence2modelloader_204, 0),
257
  )
258
+
259
+ cliptextencode_188 = cliptextencode.encode(
260
+ text=get_value_at_index(florence2run_203, 2),
261
+ clip=get_value_at_index(checkpointloadersimple_50, 1),
262
  )
263
+
264
+
265
+
266
+
267
+
268
+
269
+
270
+ cliptextencode_836 = cliptextencode.encode(
271
+ text=" Bald, no hair, small head, small head, nothing around, no light, no highlights, no sunlight,Smooth forehead,No wrinkles",
272
+ clip=get_value_at_index(loraloader_353, 1),
 
 
 
273
  )
274
+
275
+ cliptextencode_321 = cliptextencode.encode(
276
+ text="wrinkles,Big forehead, big head, big back of the head,multiple_hands, multiple_legs, multiple_girls\nlow quality, blurry, out of focus, distorted, unrealistic, extra limbs, missing limbs, deformed hands, deformed fingers, extra fingers, long neck, unnatural face, bad anatomy, bad proportions, poorly drawn face, poorly drawn eyes, asymmetrical eyes, extra eyes, extra head, floating objects, watermark, text, logo, jpeg artifacts, overexposed, underexposed, harsh lighting, bad posture, strange angles, unnatural expressions, oversaturated colors, messy hair, unrealistic skin texture, wrinkles inappropriately placed, incorrect shading, pixelation, complex background, busy background, detailed background, crowded scene, clutter, messy elements, unnecessary objects, overlapping objects, intricate patterns, vibrant colors, high contrast, graffiti, shadows, reflections, multiple layers, unrealistic lighting, overexposed areas.",
277
+ clip=get_value_at_index(loraloader_353, 1),
278
  )
279
+
280
+
281
+ loadimage_317 = loadimage.load_image(image=model_image)
282
+
283
+
284
+ dwpreprocessor_390 = dwpreprocessor.estimate_pose(
285
+ detect_hand="enable",
286
+ detect_body="enable",
287
+ detect_face="enable",
288
+ resolution=768,
289
+ bbox_detector="yolox_l.onnx",
290
+ pose_estimator="dw-ll_ucoco_384_bs5.torchscript.pt",
291
+ scale_stick_for_xinsr_cn="disable",
292
+ image=get_value_at_index(loadimage_317, 0),
293
  )
294
+
295
+
296
+ controlnetapplyadvanced_388 = controlnetapplyadvanced.apply_controlnet(
297
+ strength=1,
298
+ start_percent=0,
299
+ end_percent=1,
300
+ positive=get_value_at_index(cliptextencode_836, 0),
301
+ negative=get_value_at_index(cliptextencode_321, 0),
302
+ control_net=get_value_at_index(controlnetloader_389, 0),
303
+ image=get_value_at_index(dwpreprocessor_390, 0),
304
+ vae=get_value_at_index(checkpointloadersimple_319, 2),
305
  )
306
+
307
+
308
+ layerutility_imagescalebyaspectratio_v2_331 = (
309
+ layerutility_imagescalebyaspectratio_v2.image_scale_by_aspect_ratio(
310
+ aspect_ratio="original",
311
+ proportional_width=1,
312
+ proportional_height=1,
313
+ fit="letterbox",
314
+ method="lanczos",
315
+ round_to_multiple="8",
316
+ scale_to_side="longest",
317
+ scale_to_length=768,
318
+ background_color="#000000",
319
+ image=get_value_at_index(loadimage_317, 0),
320
+ mask=get_value_at_index(loadimage_317, 1),
321
+ )
322
  )
323
+
324
+
325
+ layermask_personmaskultra_v2_327 = (
326
+ layermask_personmaskultra_v2.person_mask_ultra_v2(
327
+ face=False,
328
+ hair=True,
329
+ body=False,
330
+ clothes=False,
331
+ accessories=False,
332
+ background=False,
333
+ confidence=0.4,
334
+ detail_method="VITMatte",
335
+ detail_erode=6,
336
+ detail_dilate=6,
337
+ black_point=0.01,
338
+ white_point=0.99,
339
+ process_detail=True,
340
+ device="cuda",
341
+ max_megapixels=2,
342
+ images=get_value_at_index(
343
+ layerutility_imagescalebyaspectratio_v2_331, 0
344
+ ),
345
+ )
346
  )
347
+
348
+
349
+ growmask_502 = growmask.expand_mask(
350
+ expand=20,
351
+ tapered_corners=True,
352
+ mask=get_value_at_index(layermask_personmaskultra_v2_327, 1),
 
353
  )
354
+
355
+
356
+ inpaintmodelconditioning_330 = inpaintmodelconditioning.encode(
357
+ noise_mask=True,
358
+ positive=get_value_at_index(controlnetapplyadvanced_388, 0),
359
+ negative=get_value_at_index(controlnetapplyadvanced_388, 1),
360
+ vae=get_value_at_index(checkpointloadersimple_319, 2),
361
+ pixels=get_value_at_index(layerutility_imagescalebyaspectratio_v2_331, 0),
362
+ mask=get_value_at_index(growmask_502, 0),
363
  )
364
+
365
+
366
+ ksampler_318 = ksampler.sample(
367
+ seed=random.randint(1, 2**64),
368
+ steps=10,
369
+ cfg=2.5,
370
+ sampler_name="euler_ancestral",
371
+ scheduler="normal",
372
+ denoise=1,
373
+ model=get_value_at_index(loraloader_353, 0),
374
+ positive=get_value_at_index(inpaintmodelconditioning_330, 0),
375
+ negative=get_value_at_index(inpaintmodelconditioning_330, 1),
376
+ latent_image=get_value_at_index(inpaintmodelconditioning_330, 2),
377
  )
378
+
379
+
380
+ vaedecode_322 = vaedecode.decode(
381
+ samples=get_value_at_index(ksampler_318, 0),
382
+ vae=get_value_at_index(checkpointloadersimple_319, 2),
383
  )
384
+
385
+
386
+ vaeencode_191 = vaeencode.encode(
387
+ pixels=get_value_at_index(vaedecode_322, 0),
388
+ vae=get_value_at_index(checkpointloadersimple_50, 2),
389
+ )
390
+
391
+
392
+
393
+
394
+ faceanalysismodels_840 = faceanalysismodels.load_models(
395
+ library="insightface", provider="CUDA"
396
+ )
397
+
398
+
399
+
400
+ # for q in range(1):
401
+ ipadapterunifiedloader_90 = ipadapterunifiedloader.load_models(
402
+ preset="PLUS (high strength)",
403
+ model=get_value_at_index(checkpointloadersimple_50, 0),
404
+ )
405
+
406
+ layerutility_imagescalebyaspectratio_v2_187 = (
407
+ layerutility_imagescalebyaspectratio_v2.image_scale_by_aspect_ratio(
408
+ aspect_ratio="original",
409
+ proportional_width=132,
410
+ proportional_height=1,
411
+ fit="letterbox",
412
+ method="lanczos",
413
+ round_to_multiple="8",
414
+ scale_to_side="longest",
415
+ scale_to_length=768,
416
+ background_color="#000000",
417
+ image=get_value_at_index(loadimage_144, 0),
418
+ )
419
+ )
420
+
421
+ ipadapteradvanced_85 = ipadapteradvanced.apply_ipadapter(
422
+ weight=1,
423
+ weight_type="strong style transfer",
424
+ combine_embeds="concat",
425
+ start_at=0,
426
+ end_at=1,
427
+ embeds_scaling="V only",
428
+ model=get_value_at_index(ipadapterunifiedloader_90, 0),
429
+ ipadapter=get_value_at_index(ipadapterunifiedloader_90, 1),
430
+ image=get_value_at_index(
431
+ layerutility_imagescalebyaspectratio_v2_187, 0
432
+ ),
433
+ )
434
+
435
+ dwpreprocessor_72 = dwpreprocessor.estimate_pose(
436
+ detect_hand="enable",
437
+ detect_body="enable",
438
+ detect_face="enable",
439
+ resolution=1024,
440
+ bbox_detector="yolox_l.onnx",
441
+ pose_estimator="dw-ll_ucoco_384_bs5.torchscript.pt",
442
+ scale_stick_for_xinsr_cn="disable",
443
+ image=get_value_at_index(vaedecode_322, 0),
444
+ )
445
+
446
+ controlnetapplyadvanced_189 = controlnetapplyadvanced.apply_controlnet(
447
+ strength=1,
448
+ start_percent=0,
449
+ end_percent=1,
450
+ positive=get_value_at_index(cliptextencode_188, 0),
451
+ negative=get_value_at_index(cliptextencode_52, 0),
452
+ control_net=get_value_at_index(controlnetloader_73, 0),
453
+ image=get_value_at_index(dwpreprocessor_72, 0),
454
+ vae=get_value_at_index(checkpointloadersimple_50, 2),
455
+ )
456
+
457
+ ksampler_45 = ksampler.sample(
458
+ seed=random.randint(1, 2**64),
459
+ steps=15,
460
+ cfg=1,
461
+ sampler_name="dpmpp_2m",
462
+ scheduler="karras",
463
  denoise=1,
464
+ model=get_value_at_index(ipadapteradvanced_85, 0),
465
+ positive=get_value_at_index(controlnetapplyadvanced_189, 0),
466
+ negative=get_value_at_index(controlnetapplyadvanced_189, 1),
467
+ latent_image=get_value_at_index(vaeencode_191, 0),
468
  )
469
+
470
+ vaedecode_67 = vaedecode.decode(
471
+ samples=get_value_at_index(ksampler_45, 0),
472
+ vae=get_value_at_index(checkpointloadersimple_50, 2),
473
  )
474
+
475
+ layermask_personmaskultra_v2_192 = (
476
+ layermask_personmaskultra_v2.person_mask_ultra_v2(
477
+ face=False,
478
+ hair=True,
479
+ body=False,
480
+ clothes=False,
481
+ accessories=False,
482
+ background=False,
483
+ confidence=0.4,
484
+ detail_method="VITMatte",
485
+ detail_erode=6,
486
+ detail_dilate=6,
487
+ black_point=0.01,
488
+ white_point=0.99,
489
+ process_detail=True,
490
+ device="cuda",
491
+ max_megapixels=2,
492
+ images=get_value_at_index(vaedecode_67, 0),
493
+ )
494
  )
495
+
496
+ facesegmentation_505 = facesegmentation.segment(
497
+ area="face+forehead (if available)",
498
+ grow=-5,
499
+ grow_tapered=False,
500
+ blur=41,
501
+ analysis_models=get_value_at_index(faceanalysismodels_506, 0),
502
+ image=get_value_at_index(
503
+ layerutility_imagescalebyaspectratio_v2_331, 0
504
+ ),
505
+ )
506
+
507
+ growmask_396 = growmask.expand_mask(
508
+ expand=0,
509
+ tapered_corners=True,
510
+ mask=get_value_at_index(facesegmentation_505, 0),
511
+ )
512
+
513
+ layerutility_imageblend_v2_399 = layerutility_imageblend_v2.image_blend_v2(
514
+ invert_mask=True,
515
+ blend_mode="normal",
516
+ opacity=100,
517
+ background_image=get_value_at_index(
518
+ layerutility_imagescalebyaspectratio_v2_331, 0
519
+ ),
520
+ layer_image=get_value_at_index(vaedecode_322, 0),
521
+ layer_mask=get_value_at_index(growmask_396, 0),
522
+ )
523
+
524
+ layerutility_imageblend_v2_314 = layerutility_imageblend_v2.image_blend_v2(
525
+ invert_mask=True,
526
+ blend_mode="normal",
527
+ opacity=100,
528
+ background_image=get_value_at_index(layerutility_imageblend_v2_399, 0),
529
+ layer_image=get_value_at_index(layermask_personmaskultra_v2_192, 0),
530
+ )
531
+
532
+ image_comparer_rgthree_486 = image_comparer_rgthree.compare_images(
533
+ image_a=get_value_at_index(layerutility_imageblend_v2_314, 0),
534
+ image_b=get_value_at_index(
535
+ layerutility_imagescalebyaspectratio_v2_331, 0
536
+ ),
537
+ )
538
+
539
+ saveimage_680 = saveimage.save_images(
540
+ filename_prefix="ComfyUI",
541
+ images=get_value_at_index(layerutility_imageblend_v2_314, 0),
542
+ )
543
+
544
+ saved_path = f"output/{saveimage_680['ui']['images'][0]['filename']}"
545
+
546
+
547
+ facesegmentation_838 = facesegmentation.segment(
548
+ area="face+forehead (if available)",
549
+ grow=0,
550
+ grow_tapered=False,
551
+ blur=13,
552
+ analysis_models=get_value_at_index(faceanalysismodels_840, 0),
553
+ image=get_value_at_index(layerutility_imageblend_v2_399, 0),
554
+ )
555
+
556
+ growmask_839 = growmask.expand_mask(
557
+ expand=0,
558
+ tapered_corners=True,
559
+ mask=get_value_at_index(facesegmentation_838, 0),
560
+ )
561
+
562
+ layerutility_imageblend_v2_686 = layerutility_imageblend_v2.image_blend_v2(
563
+ invert_mask=False,
564
+ blend_mode="normal",
565
+ opacity=100,
566
+ background_image=get_value_at_index(layerutility_imageblend_v2_314, 0),
567
+ layer_image=get_value_at_index(layerutility_imageblend_v2_399, 0),
568
+ layer_mask=get_value_at_index(growmask_839, 0),
569
+ )
570
+
571
+ image_comparer_rgthree_820 = image_comparer_rgthree.compare_images(
572
+ image_a=get_value_at_index(layerutility_imageblend_v2_399, 0),
573
+ image_b=get_value_at_index(
574
+ layerutility_imagescalebyaspectratio_v2_331, 0
575
+ ),
576
+ )
577
+
578
+ imageupscalewithmodel_831 = imageupscalewithmodel.upscale(
579
+ upscale_model=get_value_at_index(upscalemodelloader_835, 0),
580
+ image=get_value_at_index(layerutility_imageblend_v2_686, 0),
581
  )
582
+
583
  return saved_path
584
 
585
+
586
+ if __name__ == "__main__":
587
+ # main()
588
+
589
+ with gr.Blocks() as app:
590
+ gr.Markdown("# Swap Hairstyle")
591
+
592
+ with gr.Row():
593
+ # Add inputs
594
+ with gr.Column():
595
+ with gr.Row():
596
+ # First group: the model image
597
+ with gr.Group():
598
+ model_image = gr.Image(label="Model Image", type="filepath")
599
+ # Second group: the hairstyle template image
600
+ with gr.Group():
601
+ hairstyle_template_image = gr.Image(label="Hairstyle Template Image", type="filepath")
602
+
603
+ with gr.Column():
604
+ # Output image
605
+ output_image = gr.Image(label="Generated Image")
606
+
607
+ with gr.Row():
608
  generate_btn = gr.Button("Generate")
609
+
610
+ generate_btn.click(
611
+ fn=generate_image,
612
+ inputs=[model_image, hairstyle_template_image],
613
+ outputs=[output_image]
614
+ )
 
 
 
615
 
616
+ app.launch(share=True)
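
For reference, a minimal standalone sketch of the get_value_at_index helper that both the old and new app.py use to unpack node outputs; the tuple and dict values below are illustrative placeholders, not outputs of real ComfyUI nodes.

```python
from typing import Any, Mapping, Sequence, Union


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    # Tuple/list-style node outputs are indexed directly; mapping-style
    # outputs fall back to the "result" entry, mirroring app.py.
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


# Illustrative placeholder outputs:
tuple_output = ("MODEL", "CLIP", "VAE")        # e.g. the shape of a checkpoint loader result
dict_output = {"result": ("LATENT", "MASK")}   # e.g. a node that wraps its outputs in "result"

print(get_value_at_index(tuple_output, 1))  # CLIP
print(get_value_at_index(dict_output, 0))   # LATENT
```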
app.py.bak ADDED
@@ -0,0 +1,340 @@
1
+ import os
2
+ import random
3
+ import sys
4
+ from typing import Sequence, Mapping, Any, Union
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from huggingface_hub import hf_hub_download
9
+ import spaces
10
+ from comfy import model_management
11
+
12
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
13
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
14
+ hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
15
+ hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
16
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
17
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
18
+ t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")
19
+
20
+ # Import all the necessary functions from the original script
21
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
22
+ try:
23
+ return obj[index]
24
+ except KeyError:
25
+ return obj["result"][index]
26
+
27
+ # Add all the necessary setup functions from the original script
28
+ def find_path(name: str, path: str = None) -> str:
29
+ if path is None:
30
+ path = os.getcwd()
31
+ if name in os.listdir(path):
32
+ path_name = os.path.join(path, name)
33
+ print(f"{name} found: {path_name}")
34
+ return path_name
35
+ parent_directory = os.path.dirname(path)
36
+ if parent_directory == path:
37
+ return None
38
+ return find_path(name, parent_directory)
39
+
40
+ def add_comfyui_directory_to_sys_path() -> None:
41
+ comfyui_path = find_path("ComfyUI")
42
+ if comfyui_path is not None and os.path.isdir(comfyui_path):
43
+ sys.path.append(comfyui_path)
44
+ print(f"'{comfyui_path}' added to sys.path")
45
+
46
+ def add_extra_model_paths() -> None:
47
+ try:
48
+ from main import load_extra_path_config
49
+ except ImportError:
50
+ from utils.extra_config import load_extra_path_config
51
+ extra_model_paths = find_path("extra_model_paths.yaml")
52
+ if extra_model_paths is not None:
53
+ load_extra_path_config(extra_model_paths)
54
+ else:
55
+ print("Could not find the extra_model_paths config file.")
56
+
57
+ # Initialize paths
58
+ add_comfyui_directory_to_sys_path()
59
+ add_extra_model_paths()
60
+
61
+ def import_custom_nodes() -> None:
62
+ import asyncio
63
+ import execution
64
+ from nodes import init_extra_nodes
65
+ import server
66
+ loop = asyncio.new_event_loop()
67
+ asyncio.set_event_loop(loop)
68
+ server_instance = server.PromptServer(loop)
69
+ execution.PromptQueue(server_instance)
70
+ init_extra_nodes()
71
+
72
+ # Import all necessary nodes
73
+ from nodes import (
74
+ StyleModelLoader,
75
+ VAEEncode,
76
+ NODE_CLASS_MAPPINGS,
77
+ LoadImage,
78
+ CLIPVisionLoader,
79
+ SaveImage,
80
+ VAELoader,
81
+ CLIPVisionEncode,
82
+ DualCLIPLoader,
83
+ EmptyLatentImage,
84
+ VAEDecode,
85
+ UNETLoader,
86
+ CLIPTextEncode,
87
+ )
88
+
89
+ # Initialize all constant nodes and models in global context
90
+ import_custom_nodes()
91
+
92
+ # Global variables for preloaded models and constants
93
+ #with torch.inference_mode():
94
+ # Initialize constants
95
+ intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
96
+ CONST_1024 = intconstant.get_value(value=1024)
97
+
98
+ # Load CLIP
99
+ dualcliploader = DualCLIPLoader()
100
+ CLIP_MODEL = dualcliploader.load_clip(
101
+ clip_name1="t5/t5xxl_fp16.safetensors",
102
+ clip_name2="clip_l.safetensors",
103
+ type="flux",
104
+ )
105
+
106
+ # Load VAE
107
+ vaeloader = VAELoader()
108
+ VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
109
+
110
+ # Load UNET
111
+ unetloader = UNETLoader()
112
+ UNET_MODEL = unetloader.load_unet(
113
+ unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
114
+ )
115
+
116
+ # Load CLIP Vision
117
+ clipvisionloader = CLIPVisionLoader()
118
+ CLIP_VISION_MODEL = clipvisionloader.load_clip(
119
+ clip_name="sigclip_vision_patch14_384.safetensors"
120
+ )
121
+
122
+ # Load Style Model
123
+ stylemodelloader = StyleModelLoader()
124
+ STYLE_MODEL = stylemodelloader.load_style_model(
125
+ style_model_name="flux1-redux-dev.safetensors"
126
+ )
127
+
128
+ # Initialize samplers
129
+ ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
130
+ SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
131
+
132
+ # Initialize depth model
133
+ cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
134
+ downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
135
+ DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
136
+ model="depth_anything_v2_vitl_fp32.safetensors"
137
+ )
138
+
139
+ cliptextencode = CLIPTextEncode()
140
+ loadimage = LoadImage()
141
+ vaeencode = VAEEncode()
142
+ fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
143
+ instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
144
+ clipvisionencode = CLIPVisionEncode()
145
+ stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
146
+ emptylatentimage = EmptyLatentImage()
147
+ basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
148
+ basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
149
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
150
+ samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
151
+ vaedecode = VAEDecode()
152
+ cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
153
+ saveimage = SaveImage()
154
+ getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
155
+ depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
156
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
157
+
158
+ model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
159
+
160
+ model_management.load_models_gpu([
161
+ loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
162
+ ])
163
+
164
+ @spaces.GPU
165
+ def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5) -> str:
166
+ """Main generation function that processes inputs and returns the path to the generated image."""
167
+ with torch.inference_mode():
168
+ # Set up CLIP
169
+ clip_switch = cr_clip_input_switch.switch(
170
+ Input=1,
171
+ clip1=get_value_at_index(CLIP_MODEL, 0),
172
+ clip2=get_value_at_index(CLIP_MODEL, 0),
173
+ )
174
+
175
+ # Encode text
176
+ text_encoded = cliptextencode.encode(
177
+ text=prompt,
178
+ clip=get_value_at_index(clip_switch, 0),
179
+ )
180
+ empty_text = cliptextencode.encode(
181
+ text="",
182
+ clip=get_value_at_index(clip_switch, 0),
183
+ )
184
+
185
+ # Process structure image
186
+ structure_img = loadimage.load_image(image=structure_image)
187
+
188
+ # Resize image
189
+ resized_img = imageresize.execute(
190
+ width=get_value_at_index(CONST_1024, 0),
191
+ height=get_value_at_index(CONST_1024, 0),
192
+ interpolation="bicubic",
193
+ method="keep proportion",
194
+ condition="always",
195
+ multiple_of=16,
196
+ image=get_value_at_index(structure_img, 0),
197
+ )
198
+
199
+ # Get image size
200
+ size_info = getimagesizeandcount.getsize(
201
+ image=get_value_at_index(resized_img, 0)
202
+ )
203
+
204
+ # Encode VAE
205
+ vae_encoded = vaeencode.encode(
206
+ pixels=get_value_at_index(size_info, 0),
207
+ vae=get_value_at_index(VAE_MODEL, 0),
208
+ )
209
+
210
+ # Process depth
211
+ depth_processed = depthanything_v2.process(
212
+ da_model=get_value_at_index(DEPTH_MODEL, 0),
213
+ images=get_value_at_index(size_info, 0),
214
+ )
215
+
216
+ # Apply Flux guidance
217
+ flux_guided = fluxguidance.append(
218
+ guidance=depth_strength,
219
+ conditioning=get_value_at_index(text_encoded, 0),
220
+ )
221
+
222
+ # Process style image
223
+ style_img = loadimage.load_image(image=style_image)
224
+
225
+ # Encode style with CLIP Vision
226
+ style_encoded = clipvisionencode.encode(
227
+ crop="center",
228
+ clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
229
+ image=get_value_at_index(style_img, 0),
230
+ )
231
+
232
+ # Set up conditioning
233
+ conditioning = instructpixtopixconditioning.encode(
234
+ positive=get_value_at_index(flux_guided, 0),
235
+ negative=get_value_at_index(empty_text, 0),
236
+ vae=get_value_at_index(VAE_MODEL, 0),
237
+ pixels=get_value_at_index(depth_processed, 0),
238
+ )
239
+
240
+ # Apply style
241
+ style_applied = stylemodelapplyadvanced.apply_stylemodel(
242
+ strength=style_strength,
243
+ conditioning=get_value_at_index(conditioning, 0),
244
+ style_model=get_value_at_index(STYLE_MODEL, 0),
245
+ clip_vision_output=get_value_at_index(style_encoded, 0),
246
+ )
247
+
248
+ # Set up empty latent
249
+ empty_latent = emptylatentimage.generate(
250
+ width=get_value_at_index(resized_img, 1),
251
+ height=get_value_at_index(resized_img, 2),
252
+ batch_size=1,
253
+ )
254
+
255
+ # Set up guidance
256
+ guided = basicguider.get_guider(
257
+ model=get_value_at_index(UNET_MODEL, 0),
258
+ conditioning=get_value_at_index(style_applied, 0),
259
+ )
260
+
261
+ # Set up scheduler
262
+ schedule = basicscheduler.get_sigmas(
263
+ scheduler="simple",
264
+ steps=28,
265
+ denoise=1,
266
+ model=get_value_at_index(UNET_MODEL, 0),
267
+ )
268
+
269
+ # Generate random noise
270
+ noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
271
+
272
+ # Sample
273
+ sampled = samplercustomadvanced.sample(
274
+ noise=get_value_at_index(noise, 0),
275
+ guider=get_value_at_index(guided, 0),
276
+ sampler=get_value_at_index(SAMPLER, 0),
277
+ sigmas=get_value_at_index(schedule, 0),
278
+ latent_image=get_value_at_index(empty_latent, 0),
279
+ )
280
+
281
+ # Decode VAE
282
+ decoded = vaedecode.decode(
283
+ samples=get_value_at_index(sampled, 0),
284
+ vae=get_value_at_index(VAE_MODEL, 0),
285
+ )
286
+
287
+ # Save image
288
+ prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
289
+
290
+ saved = saveimage.save_images(
291
+ filename_prefix=get_value_at_index(prefix, 0),
292
+ images=get_value_at_index(decoded, 0),
293
+ )
294
+ saved_path = f"output/{saved['ui']['images'][0]['filename']}"
295
+ return saved_path
296
+
297
+ # Create Gradio interface
298
+
299
+ examples = [
300
+ ["", "mona.png", "receita-tacos.webp", 15, 0.6],
301
+ ["a woman looking at a house catching fire on the background", "disaster_girl.png", "abaporu.jpg", 15, 0.15],
302
+ ["istanbul aerial, dramatic photography", "natasha.png", "istambul.jpg", 15, 0.5],
303
+ ]
304
+
305
+ output_image = gr.Image(label="Generated Image")
306
+
307
+ with gr.Blocks() as app:
308
+ gr.Markdown("# FLUX Style Shaping")
309
+ gr.Markdown("Flux[dev] Redux + Flux[dev] Depth ComfyUI workflow by [Nathan Shipley](https://x.com/CitizenPlain) running directly on Gradio. [workflow](https://gist.github.com/nathanshipley/7a9ac1901adde76feebe58d558026f68) - [how to convert your any comfy workflow to gradio](https://huggingface.co/blog/run-comfyui-workflows-on-spaces)")
310
+ with gr.Row():
311
+ with gr.Column():
312
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
313
+ with gr.Row():
314
+ with gr.Group():
315
+ structure_image = gr.Image(label="Structure Image", type="filepath")
316
+ depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
317
+ with gr.Group():
318
+ style_image = gr.Image(label="Style Image", type="filepath")
319
+ style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
320
+ generate_btn = gr.Button("Generate")
321
+
322
+ gr.Examples(
323
+ examples=examples,
324
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
325
+ outputs=[output_image],
326
+ fn=generate_image,
327
+ cache_examples=True,
328
+ cache_mode="lazy"
329
+ )
330
+
331
+ with gr.Column():
332
+ output_image.render()
333
+ generate_btn.click(
334
+ fn=generate_image,
335
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
336
+ outputs=[output_image]
337
+ )
338
+
339
+ if __name__ == "__main__":
340
+ app.launch(share=True)
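
Assuming ComfyUI, the custom node packs, and the checkpoints referenced in the new app.py are already installed, the generate_image entry point can also be exercised without the Gradio UI. A minimal sketch with hypothetical input file names:

```python
# Hypothetical usage sketch: importing app runs the model setup at module level
# but does not launch Gradio (the UI sits behind the __main__ guard in app.py).
from app import generate_image

result_path = generate_image(
    model_image="inputs/model.jpg",                   # hypothetical: person whose hairstyle is replaced
    hairstyle_template_image="inputs/hairstyle.jpg",  # hypothetical: reference hairstyle
)
print(f"Generated image written to: {result_path}")
```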