import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import gradio as gr
import torch.nn.utils.prune as prune
import matplotlib
matplotlib.use("Agg")  # headless backend so figures can be rasterized off-screen
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3-D projection on older Matplotlib)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32)
model.eval()
# Apply global unstructured pruning to all Conv2d/Linear weights
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # prune 40% of weights by L1 magnitude
)
# Make the pruning permanent by removing the masks and re-parametrization
for module, _ in parameters_to_prune:
    prune.remove(module, "weight")
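# Optional sanity check (illustrative, not part of the original app): with
# L1 unstructured pruning the removed weights are exact zeros, so global
# sparsity can be measured directly and should come out near the requested 40%.
zeros = sum(int((module.weight == 0).sum()) for module, _ in parameters_to_prune)
total = sum(module.weight.numel() for module, _ in parameters_to_prune)
print(f"Global sparsity after pruning: {zeros / total:.1%}")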
# Dynamic quantization converts nn.Linear modules to int8 (nn.Conv2d is not
# supported by quantize_dynamic and would be left in float32 anyway).
# Dynamically quantized modules run on CPU only, so skip this step on GPU.
if device.type == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
model = model.to(device)
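# Optional check (illustrative): count how many Linear layers were replaced
# by dynamically quantized counterparts. The class path below assumes the
# torch.nn.quantized.dynamic namespace available in recent PyTorch releases.
if device.type == "cpu":
    n_quant = sum(
        isinstance(m, torch.nn.quantized.dynamic.Linear) for m in model.modules()
    )
    print(f"Dynamically quantized Linear layers: {n_quant}")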
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
# Precomputed inferno colour LUT kept on the device (not used by the 3-D
# surface view below, which colours the plot with Matplotlib's viridis)
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
color_map = torch.from_numpy(color_map).to(device)
def preprocess_image(image):
    # The Swin-v2 backbone of this checkpoint expects 256x256 inputs
    image = cv2.resize(image, (256, 256))
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
    return image / 255.0
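# Alternative (sketch, untested here): the DPTImageProcessor created above
# could replace the manual resize/scale pipeline and would also apply the
# checkpoint's own normalization statistics:
#   inputs = processor(images=image, return_tensors="pt").to(device)
#   predicted_depth = model(**inputs).predicted_depth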
def plot_depth_map(depth_map):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
    ax.plot_surface(x, y, depth_map, cmap="viridis")
    ax.set_zlim(0, 1)
    return fig  # caller is responsible for rendering and closing the figure
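# Performance note (assumption): plot_surface over the full depth grid can
# dominate per-frame latency; subsampling, e.g. depth_map[::4, ::4], is a
# common way to keep a live 3-D view responsive.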
@torch.inference_mode()
def process_frame(image):
    if image is None:
        return None
    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).predicted_depth
    depth_map = predicted_depth.squeeze().cpu().numpy()
    # Normalize depth map to [0, 1] (epsilon guards against a constant map)
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
    # Create 3D plot
    fig = plot_depth_map(depth_map)
    # Rasterize the figure and convert it to an RGB array (tostring_rgb was
    # removed in Matplotlib 3.8; buffer_rgba is the supported replacement
    # on the Agg canvas)
    fig.canvas.draw()
    img = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    plt.close(fig)  # free the figure now that we have the pixels
    return img
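# Quick smoke test (illustrative, commented out): run one synthetic frame
# through the pipeline before wiring up the webcam stream.
# _test = process_frame(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
# assert _test is not None and _test.ndim == 3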
interface = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True),
    outputs="image",
    live=True,
)
interface.launch()
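# Note: on Hugging Face Spaces the bare launch() above is sufficient; when
# running locally, launch(share=True) additionally exposes a temporary
# public URL (optional).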