# Hugging Face Space page header captured together with the file contents:
# huntrezz's picture
# Update app.py
# 18c3385 verified
# raw
# history blame
# 2.26 kB
import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import gradio as gr
import torch.nn.utils.prune as prune
# Run on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load in the checkpoint's default precision: pruning runs on the CPU before
# the model is moved, and CPU support for half-precision kernels is limited.
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256")
model.eval()

# Apply global unstructured pruning: every Conv2d / Linear weight competes in
# one pool and the 40% with the smallest magnitude are zeroed.
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,  # Prune 40% of weights
)
# Fold the masks into the weights so the re-parametrisation hooks disappear.
for module, _ in parameters_to_prune:
    prune.remove(module, "weight")

if device.type == "cuda":
    # Dynamically quantized layers only execute on the CPU, so on the GPU the
    # speed-up comes from half precision instead.
    model = model.half().to(device)
else:
    # int8 dynamic quantization applies to Linear layers; Conv2d has no
    # dynamic-quantized counterpart, so it is not listed here.
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
# Image processor matching the checkpoint. NOTE(review): nothing in this file
# calls it; frames are resized and scaled by preprocess_image instead.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")

# 256-entry INFERNO lookup table. OpenCV produces colours in BGR order while
# the Gradio image output expects RGB, so the channels are swapped once here.
# The (256, 1, 3) uint8 shape returned by OpenCV is kept.
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
color_map = cv2.cvtColor(color_map, cv2.COLOR_BGR2RGB)
color_map = torch.from_numpy(color_map).to(device)

# NOTE(review): not referenced anywhere else in the visible file; kept so the
# module's names stay unchanged.
input_tensor = torch.zeros((1, 3, 128, 128), dtype=torch.float16, device=device)
def preprocess_image(image):
    """Resize a frame to 128x128 and scale it to the [0, 1] range.

    Args:
        image: Channel-last NumPy array (height, width, channels) with values
            on a 0-255 scale. Presumably an RGB frame from the webcam
            component -- confirm against the caller.

    Returns:
        A float16 tensor of shape (channels, 128, 128) on ``device``.
    """
    # torch.from_numpy rejects negatively strided arrays (e.g. flipped
    # frames); ascontiguousarray is a no-op for arrays that are already fine.
    frame = torch.from_numpy(np.ascontiguousarray(image))

    # Resize in float32: half-precision bilinear interpolation is not
    # available on every CPU build, and float32 avoids intermediate rounding.
    batch = frame.to(device, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    resized = torch.nn.functional.interpolate(
        batch, size=(128, 128), mode="bilinear", align_corners=False
    )

    # Cast once at the end so callers still receive float16.
    return (resized.squeeze(0) / 255.0).to(torch.float16)
# Fixed input buffer: each frame is copied into it and `g.replay()` refreshes
# `static_output` from it. Its dtype follows the model's parameters so the
# buffer always matches whatever precision the model ended up in.
static_input = torch.zeros(
    (1, 3, 128, 128), device=device, dtype=next(model.parameters()).dtype
)


class _EagerGraph:
    """Stand-in for torch.cuda.CUDAGraph when CUDA is unavailable.

    ``replay()`` runs the model eagerly and writes the new depth map into the
    existing ``static_output`` tensor, mirroring how a replayed graph
    overwrites its captured outputs.
    """

    def replay(self):
        with torch.no_grad():
            fresh = model(static_input)
        static_output.predicted_depth.copy_(fresh.predicted_depth)


# Autograd is not needed for inference; disabling it keeps the captured graph
# free of gradient bookkeeping.
with torch.no_grad():
    if device.type == "cuda":
        # Warm up on a side stream before capture so lazy initialisation does
        # not happen inside the captured region.
        warmup_stream = torch.cuda.Stream()
        warmup_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(warmup_stream):
            for _ in range(3):
                model(static_input)
        torch.cuda.current_stream().wait_stream(warmup_stream)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_output = model(static_input)
    else:
        # No CUDA: run once to allocate the output, then replay eagerly.
        static_output = model(static_input)
        g = _EagerGraph()
@torch.inference_mode()
def process_frame(image):
    """Turn one webcam frame into a colour-mapped depth image.

    Args:
        image: Frame as a NumPy array, or ``None`` when no frame is available.

    Returns:
        A uint8 NumPy array of shape (height, width, 3) holding the depth map
        rendered through ``color_map``, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None

    # Feed the frame through the fixed buffers: copy in, replay, read out.
    static_input.copy_(preprocess_image(image))
    g.replay()

    # Normalise in float32 so the min/max arithmetic does not lose precision.
    depth = static_output.predicted_depth.squeeze().float()
    lowest = depth.min()
    # A completely flat depth map would otherwise divide by zero.
    span = (depth.max() - lowest).clamp_min(torch.finfo(depth.dtype).eps)
    normalized = torch.nan_to_num((depth - lowest) / span, nan=0.0)

    # The index must be int64: PyTorch interprets a uint8 index tensor as a
    # boolean mask rather than as positions in the lookup table.
    palette = color_map.reshape(-1, 3)
    indices = (normalized * 255).long().clamp_(0, palette.shape[0] - 1)

    # Flattening the table to (N, 3) makes the lookup yield (H, W, 3).
    return palette[indices].cpu().numpy()
# Live webcam demo: each streamed frame goes through process_frame and the
# resulting image is displayed.
_webcam_feed = gr.Image(sources="webcam", streaming=True)
interface = gr.Interface(
    fn=process_frame,
    inputs=_webcam_feed,
    outputs="image",
    live=True,
)
interface.launch()