import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import gradio as gr
import torch.quantization
import torch.nn.utils.prune as prune
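# Run on the GPU when available; note that the dynamic quantization below is CPU-only.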
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
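# Load a compact DPT depth-estimation model (SwinV2-tiny backbone, 256px) in fp32.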
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32)
model.eval()
# Apply global unstructured L1 pruning across all Conv2d and Linear layers.
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,  # prune 30% of weights globally by L1 magnitude
)
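# Make the pruning permanent: fold the masks into the weights and remove the
# pruning reparametrization before quantizing.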
for module, _ in parameters_to_prune:
    prune.remove(module, "weight")
# Apply dynamic int8 quantization after pruning. quantize_dynamic only covers
# nn.Linear (Conv2d has no dynamic quantized kernel), and the resulting int8
# ops run on CPU only, so skip quantization when a GPU is used.
if device.type == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
model = model.to(device)
def preprocess_image(image):
    # cv2.resize takes (width, height); convert HWC uint8 to CHW float32 in [0, 1].
    resized = cv2.resize(image, (128, 72), interpolation=cv2.INTER_AREA)
    return resized.transpose(2, 0, 1).astype(np.float32) / 255.0
@torch.inference_mode()
def process_frame(image):
    if image is None:
        return None
    preprocessed = preprocess_image(image)
    input_tensor = torch.from_numpy(preprocessed).unsqueeze(0).to(device)
    predicted_depth = model(input_tensor).predicted_depth
    depth_map = predicted_depth.squeeze().cpu().numpy()
    # Normalize to [0, 255]; the epsilon guards against a constant depth map.
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-6)
    depth_map = (depth_map * 255).astype(np.uint8)
    depth_map_colored = cv2.applyColorMap(depth_map, cv2.COLORMAP_INFERNO)
    # OpenCV colormaps are BGR; Gradio displays RGB.
    return cv2.cvtColor(depth_map_colored, cv2.COLOR_BGR2RGB)
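# Stream webcam frames through the model and show the colorized depth map live.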
interface = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True),
    outputs="image",
    live=True,
)
interface.launch()