huntrezz committed on
Commit
6380ca2
·
verified ·
1 Parent(s): ba195a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -22
app.py CHANGED
@@ -12,14 +12,13 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32)
13
  model.eval()
14
 
15
- # Apply global unstructured pruning
16
  parameters_to_prune = [
17
  (module, "weight") for module in filter(lambda m: isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)), model.modules())
18
  ]
19
  prune.global_unstructured(
20
  parameters_to_prune,
21
  pruning_method=prune.L1Unstructured,
22
- amount=0.4, # Prune 40% of weights
23
  )
24
 
25
  for module, _ in parameters_to_prune:
@@ -33,48 +32,41 @@ model = model.to(device)
33
 
34
  processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
35
 
36
- color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
37
- color_map = torch.from_numpy(color_map).to(device)
38
-
39
  def preprocess_image(image):
40
- image = cv2.resize(image, (128, 72))
41
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
42
  return image / 255.0
43
 
44
- def plot_depth_map(depth_map):
45
- fig = plt.figure(figsize=(16, 9)) # Set figure size to 16:9
46
  ax = fig.add_subplot(111, projection='3d')
47
  x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
48
  ax.plot_surface(x, y, depth_map, cmap='viridis')
49
- ax.view_init(azim=90, elev=0) # Rotate the view to face forward
50
  ax.set_zlim(0, 1)
51
  plt.close(fig)
52
- return fig
 
 
 
 
 
53
 
54
  @torch.inference_mode()
55
- def process_frame(image):
56
  if image is None:
57
  return None
58
  preprocessed = preprocess_image(image)
59
  predicted_depth = model(preprocessed).predicted_depth
60
  depth_map = predicted_depth.squeeze().cpu().numpy()
61
 
62
- # Normalize depth map
63
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
64
 
65
- # Create 3D plot
66
- fig = plot_depth_map(depth_map)
67
-
68
- # Convert plot to image
69
- fig.canvas.draw()
70
- img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
71
- img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
72
-
73
- return img
74
 
75
  interface = gr.Interface(
76
  fn=process_frame,
77
- inputs=gr.Image(sources="webcam", streaming=True),
78
  outputs="image",
79
  live=True
80
  )
 
12
  model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32)
13
  model.eval()
14
 
 
15
  parameters_to_prune = [
16
  (module, "weight") for module in filter(lambda m: isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)), model.modules())
17
  ]
18
  prune.global_unstructured(
19
  parameters_to_prune,
20
  pruning_method=prune.L1Unstructured,
21
+ amount=0.4,
22
  )
23
 
24
  for module, _ in parameters_to_prune:
 
32
 
33
  processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
34
 
 
 
 
35
  def preprocess_image(image):
36
+ image = cv2.resize(image, (128, 128))
37
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
38
  return image / 255.0
39
 
40
def plot_depth_map(depth_map, azimuth):
    """Render a depth map as a 3-D surface and return the plot as an RGB image.

    Args:
        depth_map: 2-D array of surface heights; the z-axis is limited to [0, 1].
        azimuth: Azimuth angle (degrees) forwarded to ``Axes3D.view_init``.

    Returns:
        A ``uint8`` array of shape (height, width, 3) with the rendered figure.
    """
    fig = plt.figure(figsize=(16, 9))
    try:
        ax = fig.add_subplot(111, projection='3d')
        x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
        ax.plot_surface(x, y, depth_map, cmap='viridis')
        ax.view_init(elev=90, azim=azimuth)  # Look down onto the depth map
        ax.set_zlim(0, 1)

        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which is deprecated since
        # Matplotlib 3.8 and removed in 3.10. Drop the alpha channel and copy,
        # because the buffer is only a view onto the canvas memory.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()
    finally:
        # Close only after the pixels are read, and even when rendering fails,
        # so figures do not accumulate across streamed frames.
        plt.close(fig)
54
 
55
  @torch.inference_mode()
56
+ def process_frame(image, azimuth):
57
  if image is None:
58
  return None
59
  preprocessed = preprocess_image(image)
60
  predicted_depth = model(preprocessed).predicted_depth
61
  depth_map = predicted_depth.squeeze().cpu().numpy()
62
 
 
63
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
64
 
65
+ return plot_depth_map(depth_map, azimuth)
 
 
 
 
 
 
 
 
66
 
67
  interface = gr.Interface(
68
  fn=process_frame,
69
+ inputs=[gr.Image(sources="webcam", streaming=True), gr.Slider(0, 360, step=1)],
70
  outputs="image",
71
  live=True
72
  )