amoghrrao committed on
Commit 15ad333 · verified · 1 Parent(s): 3f5877b

Update app.py

Files changed (1)
  1. app.py +19 -47
app.py CHANGED
@@ -7,64 +7,42 @@ from torchvision import transforms
 from transformers import AutoProcessor, AutoModelForImageSegmentation, AutoModelForDepthEstimation
 
 def load_segmentation_model():
-    try:
-        print("Loading segmentation model...")
-        model_name = "ZhengPeng7/BiRefNet"
-        model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
-        model.to(device)
-        print("Segmentation model loaded successfully.")
-        return model
-    except Exception as e:
-        print(f"Error loading segmentation model: {e}")
-        return None
+    model_name = "ZhengPeng7/BiRefNet"
+    model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
+    return model
 
 def load_depth_model():
-    try:
-        print("Loading depth estimation model...")
-        model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
-        processor = AutoProcessor.from_pretrained(model_name)
-        model = AutoModelForDepthEstimation.from_pretrained(model_name)
-        model.to(device)
-        print("Depth estimation model loaded successfully.")
-        return processor, model
-    except Exception as e:
-        print(f"Error loading depth estimation model: {e}")
-        return None, None
+    model_name = "depth-anything/Depth-Anything-V2-Metric-Indoor-Base-hf"
+    processor = AutoProcessor.from_pretrained(model_name)
+    model = AutoModelForDepthEstimation.from_pretrained(model_name)
+    return processor, model
 
 def process_segmentation_image(image):
     transform = transforms.Compose([
         transforms.Resize((512, 512)),
         transforms.ToTensor(),
     ])
-    input_tensor = transform(image).unsqueeze(0).to(device)
+    input_tensor = transform(image).unsqueeze(0)
     return image, input_tensor
 
 def process_depth_image(image, processor):
     image = image.resize((512, 512))
-    inputs = processor(images=image, return_tensors="pt").to(device)
+    inputs = processor(images=image, return_tensors="pt")
     return image, inputs
 
 def segment_image(image, input_tensor, model):
-    try:
-        with torch.no_grad():
-            outputs = model(input_tensor)
-        output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
-        mask = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
-        mask = (mask > 0.5).astype(np.uint8) * 255
-        return mask
-    except Exception as e:
-        print(f"Error during segmentation: {e}")
-        return np.zeros((512, 512), dtype=np.uint8)
+    with torch.no_grad():
+        outputs = model(input_tensor)
+    output_tensor = outputs[0] if isinstance(outputs, list) else outputs.logits
+    mask = torch.sigmoid(output_tensor).squeeze().cpu().numpy()
+    mask = (mask > 0.5).astype(np.uint8) * 255
+    return mask
 
 def estimate_depth(inputs, model):
-    try:
-        with torch.no_grad():
-            outputs = model(**inputs)
-        depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
-        return depth_map
-    except Exception as e:
-        print(f"Error during depth estimation: {e}")
-        return np.zeros((512, 512), dtype=np.float32)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
+    return depth_map
 
 def normalize_depth_map(depth_map):
     min_val = np.min(depth_map)
@@ -95,9 +73,6 @@ def process_image_pipeline(image):
     segmentation_model = load_segmentation_model()
     depth_processor, depth_model = load_depth_model()
 
-    if segmentation_model is None or depth_model is None:
-        return Image.fromarray(np.zeros((512, 512), dtype=np.uint8)), image, image
-
     _, input_tensor = process_segmentation_image(image)
     _, inputs = process_depth_image(image, depth_processor)
 
@@ -108,9 +83,6 @@
 
     return Image.fromarray(segmentation_mask), blurred_image, gaussian_blur_image
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
 iface = gr.Interface(
     fn=process_image_pipeline,
     inputs=gr.Image(type="pil"),
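
For reference, a minimal usage sketch of the simplified helpers after this commit. The function names and return values are taken from the diff above; the import of app.py and the input filename are illustrative assumptions, not part of the commit (note that importing app.py also builds the Gradio interface at module level):

    from PIL import Image
    from app import (  # hypothetical import; assumes app.py is on the path
        load_segmentation_model, load_depth_model,
        process_segmentation_image, process_depth_image,
        segment_image, estimate_depth,
    )

    # After this commit there is no try/except and no .to(device), so models
    # and tensors stay on PyTorch's default device (CPU unless configured).
    seg_model = load_segmentation_model()
    processor, depth_model = load_depth_model()

    image = Image.open("example.jpg").convert("RGB")  # hypothetical input file

    _, input_tensor = process_segmentation_image(image)  # 1x3x512x512 float tensor
    _, inputs = process_depth_image(image, processor)    # processor BatchFeature

    mask = segment_image(image, input_tensor, seg_model)  # uint8 mask, values {0, 255}
    depth = estimate_depth(inputs, depth_model)           # raw depth map as a NumPy array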
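
Because the commit also drops the module-level device selection, inference now runs on PyTorch's default device (the CPU). If GPU execution is still wanted, a hedged sketch of re-adding explicit placement on top of the new helpers, not part of this commit:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seg_model = load_segmentation_model().to(device)   # nn.Module.to returns the module
    processor, depth_model = load_depth_model()
    depth_model.to(device)                             # moves parameters in place

    _, input_tensor = process_segmentation_image(image)
    input_tensor = input_tensor.to(device)             # tensor .to() is not in place; reassign

    _, inputs = process_depth_image(image, processor)
    inputs = inputs.to(device)                         # transformers BatchFeature supports .to()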