prithivMLmods committed (verified)
Commit 8f771be · Parent: b1b843c

Update app.py

Files changed (1): app.py (+72 -0)
app.py CHANGED
 
@@ -10,6 +10,10 @@ import numpy as np
 import time
 import zipfile
 import os
+import requests
+from urllib.parse import urlparse
+import tempfile
+import shutil
 
 # Description for the app
 DESCRIPTION = """## Qwen Image Hpc/."""
 
@@ -44,6 +48,45 @@ aspect_ratios = {
     "3:4": (1140, 1472)
 }
 
+def load_lora_opt(pipe, lora_input):
+    lora_input = lora_input.strip()
+    if not lora_input:
+        return
+
+    # If it's just an ID like "author/model"
+    if "/" in lora_input and not lora_input.startswith("http"):
+        pipe.load_lora_weights(lora_input)
+        return
+
+    if lora_input.startswith("http"):
+        url = lora_input
+
+        # Repo page (no blob/resolve)
+        if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
+            repo_id = urlparse(url).path.strip("/")
+            pipe.load_lora_weights(repo_id)
+            return
+
+        # Blob link → convert to resolve link
+        if "/blob/" in url:
+            url = url.replace("/blob/", "/resolve/")
+
+        # Download direct file
+        tmp_dir = tempfile.mkdtemp()
+        local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
+
+        try:
+            print(f"Downloading LoRA from {url}...")
+            resp = requests.get(url, stream=True)
+            resp.raise_for_status()
+            with open(local_path, "wb") as f:
+                for chunk in resp.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            print(f"Saved LoRA to {local_path}")
+            pipe.load_lora_weights(local_path)
+        finally:
+            shutil.rmtree(tmp_dir, ignore_errors=True)
+
 # Generation function for Qwen/Qwen-Image
 @spaces.GPU(duration=120)
 def generate_qwen(
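For reference, a minimal sketch of how the three accepted input forms flow through load_lora_opt. The dummy pipe and example repo names below are hypothetical placeholders, not part of the commit; the /blob/ case is left commented out because it triggers a real download:

# Minimal sketch: exercise load_lora_opt's branches without a real pipeline.
# _DummyPipe and the example inputs are placeholders for illustration only.
class _DummyPipe:
    def load_lora_weights(self, path_or_id):
        print(f"load_lora_weights({path_or_id!r})")

_pipe = _DummyPipe()
load_lora_opt(_pipe, "author/some-lora")                         # bare repo ID, loaded as-is
load_lora_opt(_pipe, "https://huggingface.co/author/some-lora")  # repo page URL, reduced to "author/some-lora"
# load_lora_opt(_pipe, "https://huggingface.co/author/some-lora/blob/main/lora.safetensors")
# ^ a /blob/ URL is rewritten to /resolve/, streamed to a temp file, loaded, then cleaned up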
 
@@ -57,6 +100,8 @@ def generate_qwen(
     num_inference_steps: int = 50,
     num_images: int = 1,
     zip_images: bool = False,
+    lora_input: str = "",
+    lora_scale: float = 1.0,
     progress=gr.Progress(track_tqdm=True),
 ):
     if randomize_seed:
 
@@ -64,6 +109,16 @@ def generate_qwen(
     generator = torch.Generator(device).manual_seed(seed)
 
     start_time = time.time()
+
+    pipe_qwen.unload_lora_weights()
+    use_lora = False
+    if lora_input and lora_input.strip() != "":
+        load_lora_opt(pipe_qwen, lora_input)
+        use_lora = True
+
+    kwargs = {}
+    if use_lora:
+        kwargs["cross_attention_kwargs"] = {"scale": lora_scale}
 
     images = pipe_qwen(
         prompt=prompt,
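One caveat worth flagging: cross_attention_kwargs={"scale": ...} is the UNet-era Stable Diffusion scaling hook, and it is not certain the Qwen-Image pipeline's __call__ accepts it. A hedged alternative, assuming a PEFT-backed diffusers install where the pipeline exposes the adapter API, is to scale the loaded adapter itself:

# Hedged sketch (assumes diffusers' PEFT backend is available):
# scale the adapter directly instead of passing cross_attention_kwargs at call time.
if use_lora:
    active = pipe_qwen.get_active_adapters()  # adapter names assigned at load time
    pipe_qwen.set_adapters(active, adapter_weights=[lora_scale] * len(active))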
 
@@ -75,6 +130,7 @@ def generate_qwen(
         num_images_per_prompt=num_images,
         generator=generator,
         output_type="pil",
+        **kwargs,
     ).images
 
     end_time = time.time()
 
@@ -105,6 +161,8 @@ def generate(
     num_inference_steps: int,
     num_images: int,
     zip_images: bool,
+    lora_input: str,
+    lora_scale: float,
     progress=gr.Progress(track_tqdm=True),
 ):
     final_negative_prompt = negative_prompt if use_negative_prompt else ""
 
@@ -119,6 +177,8 @@ def generate(
         num_inference_steps=num_inference_steps,
         num_images=num_images,
         zip_images=zip_images,
+        lora_input=lora_input,
+        lora_scale=lora_scale,
         progress=progress,
     )
 
 
@@ -165,6 +225,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
             choices=list(aspect_ratios.keys()),
             value="1:1",
         )
+        with gr.Row():
+            lora = gr.Textbox(label="Qwen Image LoRA (optional)", info="Insert a LoRA repo ID, repo URL, or file URL")
         with gr.Accordion("Additional Options", open=False):
             use_negative_prompt = gr.Checkbox(
                 label="Use negative prompt",
 
@@ -223,6 +285,14 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 value=1,
             )
             zip_images = gr.Checkbox(label="Zip generated images", value=False)
+            with gr.Row():
+                lora_scale = gr.Slider(
+                    label="LoRA Scale",
+                    minimum=0,
+                    maximum=2,
+                    step=0.01,
+                    value=1,
+                )
 
             gr.Markdown("### Output Information")
             seed_display = gr.Textbox(label="Seed used", interactive=False)
 
@@ -263,6 +333,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
             num_inference_steps,
             num_images,
             zip_images,
+            lora,
+            lora_scale,
         ],
         outputs=[result, seed_display, generation_time, zip_file],
         api_name="run",
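Note that the handler passes inputs positionally, so lora and lora_scale must sit in this list in the same order as the lora_input and lora_scale parameters in generate's signature; a mismatch would silently shift arguments.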