charliebaby2023 committed on
Commit
288b3a4
·
verified ·
1 Parent(s): 12fa10e

Update handle_models.py

Browse files
Files changed (1) hide show
  1. handle_models.py +71 -0
handle_models.py CHANGED
@@ -14,6 +14,76 @@ def get_current_time():
14
  current_time = now.strftime("%y-%m-%d %H:%M:%S")
15
  return current_time
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def load_fn(models,HF_TOKEN):
18
  global models_load
19
  models_load = {}
@@ -78,3 +148,4 @@ def gen_fn(model_str, prompt, nprompt="", height=0, width=0, steps=0, cfg=0, see
78
  finally:
79
  loop.close()
80
  return result
 
 
14
  current_time = now.strftime("%y-%m-%d %H:%M:%S")
15
  return current_time
16
 
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
def load_fn(models, HF_TOKEN):
    """Build the global ``models_load`` registry of inference callables.

    For each model id in *models*, load its Gradio interface via
    ``gr_Interface_load`` and store the interface's ``fn`` callable keyed by
    the model id. If loading fails, the error is printed and a no-op callable
    (returns ``None`` for any keyword arguments) is stored instead, so later
    lookups never raise.

    Args:
        models: iterable of model repo ids (e.g. ``"user/model"``).
        HF_TOKEN: Hugging Face token forwarded to ``gr_Interface_load``.
    """
    global models_load
    models_load = {}
    for name in models:
        if name in models_load:
            # Duplicate entry in the input list — already registered.
            continue
        try:
            iface = gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)
            # Store the callable, not the whole interface object.
            models_load[name] = iface.fn
        except Exception as err:
            print(err)
            # Best-effort fallback: a stub that accepts any kwargs and
            # yields None, so downstream calls degrade gracefully.
            models_load[name] = lambda **kwargs: None
36
async def infer(model_str, prompt, nprompt="", height=0, width=0, steps=0, cfg=0, seed=-1, timeout=120, hf_token=None):
    """Run one image generation on the model registered under *model_str*.

    Builds the generation kwargs (only forwarding height/width/steps/cfg when
    positive), resolves the seed (random when ``seed == -1``), then executes
    the model callable in a worker thread with a timeout.

    Args:
        model_str: key into the global ``models_load`` registry.
        prompt / nprompt: positive and negative prompts.
        height, width, steps, cfg: forwarded only when > 0.
        seed: RNG seed; -1 picks a random one via ``randomize_seed()``.
        timeout: seconds to wait for the model before cancelling.
        hf_token: optional token forwarded as the ``token`` kwarg.

    Returns:
        The saved image (via ``save_image``) on success, or ``None`` when the
        model produced no usable (non-tuple) result.

    Raises:
        Exception: on timeout, on an unregistered model, or on any model
            failure (the original error is chained as ``__cause__``).
    """
    print(f"{prompt}\n{model_str}\n{timeout}\n")
    kwargs = {}
    if height > 0: kwargs["height"] = height
    if width > 0: kwargs["width"] = width
    if steps > 0: kwargs["num_inference_steps"] = steps
    if cfg > 0: kwargs["guidance_scale"] = cfg
    kwargs["negative_prompt"] = nprompt

    theSeed = randomize_seed() if seed == -1 else seed
    kwargs["seed"] = theSeed
    if hf_token:
        kwargs["token"] = hf_token

    # Look up the callable BEFORE creating the task: in the original code a
    # KeyError here fell into the generic handler, which then referenced the
    # never-bound `task` variable and raised UnboundLocalError, masking the
    # real cause ("model not loaded").
    try:
        model_fn = models_load[model_str]
    except KeyError as e:
        print(f"Exception: {model_str} -> {e}")
        raise Exception(f"Inference failed: {model_str}") from e

    task = asyncio.create_task(asyncio.to_thread(model_fn, prompt=prompt, **kwargs))
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError as e:
        print(f"Timeout: {model_str}")
        if not task.done(): task.cancel()
        raise Exception(f"Timeout: {model_str}") from e
    except Exception as e:
        print(f"Exception: {model_str} -> {e}")
        if not task.done(): task.cancel()
        raise Exception(f"Inference failed: {model_str}") from e

    # A tuple result signals an unusable/multi-part response; only a single
    # image-like object is persisted. `lock` serializes file writes.
    if result is not None and not isinstance(result, tuple):
        with lock:
            png_path = model_str.replace("/", "_") + " - " + get_current_time() + "_" + str(theSeed) + ".png"
            image = save_image(result, png_path, model_str, prompt, nprompt, height, width, steps, cfg, theSeed)
            return image
    return None
69
def gen_fn(model_str, prompt, nprompt="", height=0, width=0, steps=0, cfg=0, seed=-1, inference_timeout2=120):
    """Synchronous wrapper around :func:`infer` for Gradio event handlers.

    Runs the async ``infer`` coroutine to completion on a private event loop
    and converts any failure into a user-visible ``gr.Error``.

    Args:
        model_str, prompt, nprompt, height, width, steps, cfg, seed:
            forwarded unchanged to ``infer``.
        inference_timeout2: per-call timeout in seconds passed to ``infer``.

    Returns:
        The image returned by ``infer`` (or ``None`` when it produced none).

    Raises:
        gr.Error: wrapping any exception raised during inference.
    """
    # Create the loop BEFORE the try: in the original code a failure in
    # new_event_loop() would reach `finally: loop.close()` with `loop`
    # unbound, raising UnboundLocalError and masking the real error.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(model_str, prompt, nprompt,
                                               height, width, steps, cfg, seed, inference_timeout2, HF_TOKEN))
    except Exception as e:
        print(f"gen_fn: Task aborted: {model_str} -> {e}")
        raise gr.Error(f"Task aborted: {model_str}, Error: {e}")
    finally:
        # Always release the private loop's resources.
        loop.close()
    return result
80
+
81
+
82
+
83
+
84
+
85
+
86
+ '''
87
  def load_fn(models,HF_TOKEN):
88
  global models_load
89
  models_load = {}
 
148
  finally:
149
  loop.close()
150
  return result
151
+ '''