barreloflube commited on
Commit
898faaf
·
1 Parent(s): e63972b

Refactor import statements in models.py and load_models.py

Browse files
Files changed (1) hide show
  1. tabs/images/models.py +23 -25
tabs/images/models.py CHANGED
@@ -1,6 +1,5 @@
1
  from typing import List, Optional, Dict, Any
2
 
3
- import gradio as gr
4
  from pydantic import BaseModel, field_validator
5
  from PIL import Image
6
 
@@ -8,12 +7,12 @@ from config import Config as appConfig
8
 
9
 
10
  class ControlNetReq(BaseModel):
11
- controlnets: List[str] # ["canny", "tile", "depth", "scribble"]
12
  control_images: List[Image.Image]
13
  controlnet_conditioning_scale: List[float]
14
-
15
  class Config:
16
- arbitrary_types_allowed=True
17
 
18
 
19
  class BaseReq(BaseModel):
@@ -23,7 +22,7 @@ class BaseReq(BaseModel):
23
  fast_generation: Optional[bool] = True
24
  loras: Optional[list] = []
25
  embeddings: Optional[list] = None
26
- resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
27
  scheduler: Optional[str] = "euler_fl"
28
  height: int = 1024
29
  width: int = 1024
@@ -36,39 +35,38 @@ class BaseReq(BaseModel):
36
  vae: bool = True
37
  controlnet_config: Optional[ControlNetReq] = None
38
  custom_addons: Optional[Dict[Any, Any]] = None
39
-
40
  class Config:
41
- arbitrary_types_allowed=True
42
-
43
- @field_validator('model', 'negative_prompt', 'embeddings', 'clip_skip', 'controlnet_config')
44
  def check_model(cls, values):
45
  for m in appConfig.IMAGES_MODELS:
46
- gr.Info(f"{m.get('repo_id')} {values.get('model')}")
47
- if m.get('repo_id') == values.get('model'):
48
  loader = m.get('loader')
49
-
50
- if loader == "flux" and values.get('negative_prompt'):
51
- raise ValueError("Negative prompt is not supported for Flux models.")
52
- if loader == "flux" and values.get('embeddings'):
53
- raise ValueError("Embeddings are not supported for Flux models.")
54
- if loader == "flux" and values.get('clip_skip'):
55
- raise ValueError("Clip skip is not supported for Flux models.")
56
- if loader == "flux" and values.get('controlnet_config'):
57
- if "scribble" in values.get('controlnet_config').controlnets:
58
- raise ValueError("Scribble is not supported for Flux models.")
59
  return values
60
 
61
 
62
  class BaseImg2ImgReq(BaseReq):
63
  image: Image.Image
64
  strength: float = 1.0
65
-
66
  class Config:
67
- arbitrary_types_allowed=True
68
 
69
 
70
  class BaseInpaintReq(BaseImg2ImgReq):
71
  mask_image: Image.Image
72
-
73
  class Config:
74
- arbitrary_types_allowed=True
 
1
  from typing import List, Optional, Dict, Any
2
 
 
3
  from pydantic import BaseModel, field_validator
4
  from PIL import Image
5
 
 
7
 
8
 
9
  class ControlNetReq(BaseModel):
10
+ controlnets: List[str] # ["canny", "tile", "depth", "scribble"]
11
  control_images: List[Image.Image]
12
  controlnet_conditioning_scale: List[float]
13
+
14
  class Config:
15
+ arbitrary_types_allowed = True
16
 
17
 
18
  class BaseReq(BaseModel):
 
22
  fast_generation: Optional[bool] = True
23
  loras: Optional[list] = []
24
  embeddings: Optional[list] = None
25
+ resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
26
  scheduler: Optional[str] = "euler_fl"
27
  height: int = 1024
28
  width: int = 1024
 
35
  vae: bool = True
36
  controlnet_config: Optional[ControlNetReq] = None
37
  custom_addons: Optional[Dict[Any, Any]] = None
38
+
39
  class Config:
40
+ arbitrary_types_allowed = True
41
+
42
+ @field_validator('model', 'negative_prompt', 'embeddings', 'clip_skip', 'controlnet_config', mode='before')
43
  def check_model(cls, values):
44
  for m in appConfig.IMAGES_MODELS:
45
+ if isinstance(m, dict) and m.get('repo_id') == values.get('model'):
 
46
  loader = m.get('loader')
47
+
48
+ if loader == "flux" and values.get('negative_prompt'):
49
+ raise ValueError("Negative prompt is not supported for Flux models.")
50
+ if loader == "flux" and values.get('embeddings'):
51
+ raise ValueError("Embeddings are not supported for Flux models.")
52
+ if loader == "flux" and values.get('clip_skip'):
53
+ raise ValueError("Clip skip is not supported for Flux models.")
54
+ if loader == "flux" and values.get('controlnet_config'):
55
+ if "scribble" in values.get('controlnet_config').controlnets:
56
+ raise ValueError("Scribble is not supported for Flux models.")
57
  return values
58
 
59
 
60
  class BaseImg2ImgReq(BaseReq):
61
  image: Image.Image
62
  strength: float = 1.0
63
+
64
  class Config:
65
+ arbitrary_types_allowed = True
66
 
67
 
68
  class BaseInpaintReq(BaseImg2ImgReq):
69
  mask_image: Image.Image
70
+
71
  class Config:
72
+ arbitrary_types_allowed = True