charliebaby2023 commited on
Commit
ff71cd5
·
verified ·
1 Parent(s): a6c670c

Create handle_models.py

Browse files
Files changed (1) hide show
  1. handle_models.py +60 -0
handle_models.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def load_fn(models):
3
+ global models_load
4
+ models_load = {}
5
+ for model in models:
6
+ if model not in models_load.keys():
7
+ try:
8
+ m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
9
+ except Exception as error:
10
+ print(error)
11
+ m = gr.Interface(lambda: None, ['text'], ['image'])
12
+ models_load.update({model: m})
13
+
14
+ async def infer(model_str, prompt, nprompt="", height=0, width=0, steps=0, cfg=0, seed=-1, timeout=inference_timeout):
15
+ kwargs = {}
16
+ if height > 0: kwargs["height"] = height
17
+ if width > 0: kwargs["width"] = width
18
+ if steps > 0: kwargs["num_inference_steps"] = steps
19
+ if cfg > 0: cfg = kwargs["guidance_scale"] = cfg
20
+ if seed == -1:
21
+ theSeed = randomize_seed()
22
+ else:
23
+ theSeed = seed
24
+ kwargs["seed"] = theSeed
25
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, negative_prompt=nprompt, **kwargs, token=HF_TOKEN))
26
+ await asyncio.sleep(0)
27
+ try:
28
+ result = await asyncio.wait_for(task, timeout=timeout)
29
+ except asyncio.TimeoutError as e:
30
+ print(e)
31
+ print(f"infer: Task timed out: {model_str}")
32
+ if not task.done(): task.cancel()
33
+ result = None
34
+ raise Exception(f"Task timed out: {model_str}") from e
35
+ except Exception as e:
36
+ print(e)
37
+ print(f"infer: exception: {model_str}")
38
+ if not task.done(): task.cancel()
39
+ result = None
40
+ raise Exception() from e
41
+ if task.done() and result is not None and not isinstance(result, tuple):
42
+ with lock:
43
+ png_path = model_str.replace("/", "_") + " - " + get_current_time() + "_" + str(theSeed) + ".png"
44
+ image = save_image(result, png_path, model_str, prompt, nprompt, height, width, steps, cfg, theSeed)
45
+ return image
46
+ return None
47
+
48
+ def gen_fn(model_str, prompt, nprompt="", height=0, width=0, steps=0, cfg=0, seed=-1):
49
+ try:
50
+ loop = asyncio.new_event_loop()
51
+ result = loop.run_until_complete(infer(model_str, prompt, nprompt,
52
+ height, width, steps, cfg, seed, inference_timeout))
53
+ except (Exception, asyncio.CancelledError) as e:
54
+ print(e)
55
+ print(f"gen_fn: Task aborted: {model_str}")
56
+ result = None
57
+ raise gr.Error(f"Task aborted: {model_str}, Error: {e}")
58
+ finally:
59
+ loop.close()
60
+ return result