Meng Chen
commited on
Commit
·
d48d98c
1
Parent(s):
b8264a8
update handler
Browse files- handler.py +10 -5
handler.py
CHANGED
@@ -11,8 +11,9 @@ class EndpointHandler():
|
|
11 |
# Preload all the elements you are going to need at inference.
|
12 |
# pseudo:
|
13 |
# self.model= load_model(path)
|
|
|
14 |
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
15 |
-
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
16 |
self.depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
|
17 |
|
18 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
@@ -23,14 +24,18 @@ class EndpointHandler():
|
|
23 |
Return:
|
24 |
A :obj:`list` | `dict`: will be serialized and returned
|
25 |
"""
|
26 |
-
if "
|
|
|
|
|
|
|
|
|
27 |
return [{"error": "Missing 'image' or 'text' key in input data"}]
|
28 |
|
29 |
try:
|
30 |
# Decode base64 image
|
31 |
-
image = self.decode_image(
|
32 |
-
prompts =
|
33 |
-
|
34 |
# Preprocess input
|
35 |
inputs = self.processor(
|
36 |
text=prompts,
|
|
|
11 |
# Preload all the elements you are going to need at inference.
|
12 |
# pseudo:
|
13 |
# self.model= load_model(path)
|
14 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
16 |
+
self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(self.device)
|
17 |
self.depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
|
18 |
|
19 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
24 |
Return:
|
25 |
A :obj:`list` | `dict`: will be serialized and returned
|
26 |
"""
|
27 |
+
if "inputs" not in data:
|
28 |
+
return [{"error": "Missing 'inputs' key"}]
|
29 |
+
|
30 |
+
inputs_data = data["inputs"]
|
31 |
+
if "image" not in inputs_data or "text" not in inputs_data:
|
32 |
return [{"error": "Missing 'image' or 'text' key in input data"}]
|
33 |
|
34 |
try:
|
35 |
# Decode base64 image
|
36 |
+
image = self.decode_image(inputs_data["image"])
|
37 |
+
prompts = inputs_data["text"]
|
38 |
+
|
39 |
# Preprocess input
|
40 |
inputs = self.processor(
|
41 |
text=prompts,
|