Meng Chen committed on
Commit
d48d98c
·
1 Parent(s): b8264a8

update handler

Browse files
Files changed (1) hide show
  1. handler.py +10 -5
handler.py CHANGED
@@ -11,8 +11,9 @@ class EndpointHandler():
11
  # Preload all the elements you are going to need at inference.
12
  # pseudo:
13
  # self.model= load_model(path)
 
14
  self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
15
- self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
16
  self.depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -23,14 +24,18 @@ class EndpointHandler():
23
  Return:
24
  A :obj:`list` | `dict`: will be serialized and returned
25
  """
26
- if "image" not in data or "text" not in data:
 
 
 
 
27
  return [{"error": "Missing 'image' or 'text' key in input data"}]
28
 
29
  try:
30
  # Decode base64 image
31
- image = self.decode_image(data["image"])
32
- prompts = data["text"].split(",")
33
-
34
  # Preprocess input
35
  inputs = self.processor(
36
  text=prompts,
 
11
  # Preload all the elements you are going to need at inference.
12
  # pseudo:
13
  # self.model= load_model(path)
14
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
  self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
16
+ self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(self.device)
17
  self.depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
18
 
19
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
24
  Return:
25
  A :obj:`list` | `dict`: will be serialized and returned
26
  """
27
+ if "inputs" not in data:
28
+ return [{"error": "Missing 'inputs' key"}]
29
+
30
+ inputs_data = data["inputs"]
31
+ if "image" not in inputs_data or "text" not in inputs_data:
32
  return [{"error": "Missing 'image' or 'text' key in input data"}]
33
 
34
  try:
35
  # Decode base64 image
36
+ image = self.decode_image(inputs_data["image"])
37
+ prompts = inputs_data["text"]
38
+
39
  # Preprocess input
40
  inputs = self.processor(
41
  text=prompts,