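# NOTE: the following lines are page residue from the hosting site's file
# viewer (not part of the program); kept here as comments so the file parses.
#   huntrezz's picture
#   Update app.py
#   6380ca2 verified
#   raw
#   history blame
#   2.29 kB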
import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import gradio as gr
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# ---------------------------------------------------------------------------
# Model setup: load the DPT depth model, prune it, and (on CPU only) apply
# dynamic int8 quantization before placing it on the selected device.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DPTForDepthEstimation.from_pretrained(
    "Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32
)
model.eval()

# Every Conv2d / Linear layer contributes its "weight" tensor to the pruning pool.
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]

# Global L1 pruning: zero the 40% smallest-magnitude weights across all
# selected layers taken together (not 40% per layer).
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,
)

# Bake the pruning masks into the weights and drop the re-parametrization so
# the modules are plain Conv2d / Linear layers again.
for module, _ in parameters_to_prune:
    prune.remove(module, "weight")

# Dynamic quantization yields modules backed by CPU-only quantized kernels.
# Previously the quantized model was moved to CUDA whenever a GPU was present,
# which fails at inference time; quantize only when we actually run on CPU.
if device.type == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
    )

model = model.to(device)

processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
def preprocess_image(image, size=(128, 128)):
    """Resize a frame and convert it to a batched, channel-first float tensor.

    Args:
        image: Channel-last image array (H x W x C); the ``permute(2, 0, 1)``
            below requires exactly three dimensions.
        size: Target ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(128, 128)``, the value that used to be hard-coded.

    Returns:
        A ``float32`` tensor of shape ``(1, C, height, width)`` on ``device``
        whose values are the resized pixels divided by 255 (i.e. in [0, 1]
        when the input is 8-bit — assumed, not checked here).
    """
    resized = cv2.resize(image, tuple(size))
    # HWC -> CHW, then add the leading batch dimension the model expects.
    tensor = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0)
    tensor = tensor.float().to(device)
    return tensor / 255.0
def plot_depth_map(depth_map, azimuth):
    """Render a depth map as a 3-D surface and return the plot as an RGB image.

    Args:
        depth_map: 2-D array of depth values. The z-axis is clamped to
            ``[0, 1]``, so callers are expected to pass normalized depth.
        azimuth: Azimuth angle (degrees) for the camera, forwarded to
            ``Axes3D.view_init``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        figure.
    """
    fig = plt.figure(figsize=(16, 9))
    try:
        ax = fig.add_subplot(111, projection="3d")
        x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
        ax.plot_surface(x, y, depth_map, cmap="viridis")
        ax.view_init(elev=90, azim=azimuth)  # Look down onto the depth map
        ax.set_zlim(0, 1)

        # Render first, then read the pixels. ``tostring_rgb`` was deprecated
        # in Matplotlib 3.8 and later removed; ``buffer_rgba`` is the supported
        # replacement and already has shape (height, width, 4).
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result does not alias the canvas buffer
        # of a figure that is about to be closed.
        return rgba[..., :3].copy()
    finally:
        # Always release the figure, even if rendering raised, so repeated
        # calls on a video stream do not accumulate open figures.
        plt.close(fig)
@torch.inference_mode()
def process_frame(image, azimuth):
    """Estimate depth for one frame and return a rendered 3-D surface image.

    Args:
        image: Input frame as an array, or ``None`` when no frame is available
            yet.
        azimuth: Viewing azimuth in degrees, forwarded to ``plot_depth_map``.

    Returns:
        The image produced by ``plot_depth_map``, or ``None`` if ``image`` is
        ``None``.
    """
    if image is None:
        return None

    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).predicted_depth
    depth_map = predicted_depth.squeeze().cpu().numpy()

    # Min-max normalize to [0, 1]. A constant depth map has zero range, which
    # previously produced 0/0 = NaN everywhere; fall back to an all-zero map
    # in that case (this branch also catches a NaN range).
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(depth_map)

    return plot_depth_map(depth_map, azimuth)
# Gradio UI: a streaming webcam image and an azimuth slider (0-360, step 1)
# feed process_frame; its return value is shown as an image.
webcam_input = gr.Image(sources="webcam", streaming=True)
azimuth_slider = gr.Slider(0, 360, step=1)

interface = gr.Interface(
    fn=process_frame,
    inputs=[webcam_input, azimuth_slider],
    outputs="image",
    live=True,
)

# Start the web server for the app.
interface.launch()