Meng Chen committed on
Commit
11b246c
·
1 Parent(s): 7ea8ba1

add handler

Browse files
Files changed (1) hide show
  1. handler.py +80 -0
handler.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import holidays
3
+ from transformers import pipeline,CLIPSegProcessor, CLIPSegForImageSegmentation
4
+ from PIL import Image
5
+ import torch
6
+ import base64
7
+ import io
8
+ import numpy as np
9
+
10
class EndpointHandler():
    """Inference handler offering text-prompted segmentation and a depth helper.

    Calling an instance runs CLIPSeg on a base64-encoded image for one or more
    comma-separated text prompts and returns one grayscale mask per prompt.
    ``process_depth`` is a separate helper that is not invoked by ``__call__``.
    """

    def __init__(self, path=""):
        """Load every model needed at inference time.

        Args:
            path: Not used by this implementation; all weights are loaded by
                their hub identifiers. Kept so the constructor signature stays
                unchanged.
        """
        # Choose the device once so the model and its inputs always agree.
        # Falls back to CPU when CUDA is not available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
        # Inputs are moved to self.device in __call__, so the model must live
        # on that same device or the forward pass fails with a device mismatch.
        self.model.to(self.device)

        self.depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment an image according to comma-separated text prompts.

        Args:
            data: Request payload. Must contain ``"image"`` (a base64-encoded
                image) and ``"text"`` (prompts separated by commas).

        Returns:
            On success, a list with one ``{"seg_image": PIL.Image}`` entry per
            prompt, in prompt order; each mask is min-max scaled to 0-255.
            On failure, ``[{"error": <message>}]``.
        """
        if "image" not in data or "text" not in data:
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        try:
            # Decode base64 image
            image = self.decode_image(data["image"])
            prompts = data["text"].split(",")

            # The same image is paired with every prompt.
            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)

            # Run inference
            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits.cpu().numpy()
            # Assumes the two trailing axes are (height, width) and any leading
            # axis indexes prompts - TODO confirm against the installed
            # transformers version. Flattening the leading axes yields a
            # uniform (num_masks, H, W) array whether or not a batch axis is
            # present for a single prompt.
            masks = logits.reshape((-1,) + logits.shape[-2:])

            # One 2-D image per prompt; a stacked 3-D array cannot be turned
            # into a single grayscale image.
            return [
                {"seg_image": Image.fromarray(self._normalize_to_uint8(mask))}
                for mask in masks
            ]

        except Exception as e:
            # Request boundary: report the failure to the caller instead of
            # letting the exception propagate.
            return [{"error": str(e)}]

    # helper functions
    @staticmethod
    def _normalize_to_uint8(values):
        """Min-max scale a numpy array to the 0-255 range as ``uint8``.

        The small epsilon keeps the division defined when all values are
        equal, in which case the result is all zeros.
        """
        low = values.min()
        scaled = (values - low) / (values.max() - low + 1e-6)
        return (scaled * 255).astype(np.uint8)

    def decode_image(self, image_data: str) -> "Image.Image":
        """Decodes a base64-encoded image into an RGB PIL image."""
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def process_depth(self, image):
        """Estimate a depth map for ``image`` and return it as a grayscale image.

        Args:
            image: A PIL image, or a numpy array (converted to a ``uint8``
                PIL image before inference).

        Returns:
            A PIL image holding the depth map min-max scaled to 0-255.
        """
        print("Processing depth")
        print(type(image))
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])

        # Normalize to 0-255
        return Image.fromarray(self._normalize_to_uint8(depth_map))
80
+