"""Stream webcam frames and their estimated depth maps to WebSocket clients.

Frames are read from camera 0; every ``frame_skip``-th frame is run through the
"Intel/dpt-swinv2-tiny-256" depth-estimation model, and the resulting 8-bit
depth map plus the RGB frame are broadcast as JSON to every client connected
to ws://localhost:8765.
"""

import asyncio
import json
import time
import warnings

import cv2
import numpy as np
import torch
import websockets
from transformers import DPTForDepthEstimation, DPTImageProcessor

warnings.filterwarnings(
    "ignore",
    message="It looks like you are trying to rescale already rescaled images.",
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use half precision only on CUDA: CPU builds of PyTorch may not implement
# float16 kernels for every op the model needs, so use float32 there.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model = DPTForDepthEstimation.from_pretrained(
    "Intel/dpt-swinv2-tiny-256", torch_dtype=model_dtype
).to(device)
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
cap = cv2.VideoCapture(0)


def resize_image(image, target_size=(256, 256)):
    """Return ``image`` resized with ``cv2.resize`` to ``target_size`` (width, height)."""
    return cv2.resize(image, target_size)


def manual_normalize(depth_map):
    """Min-max scale ``depth_map`` to uint8 in [0, 255].

    Returns an all-zero uint8 array of the same shape when the input is
    constant (min == max), since the scaling is undefined there.
    """
    min_val = np.min(depth_map)
    max_val = np.max(depth_map)
    if min_val != max_val:
        normalized = (depth_map - min_val) / (max_val - min_val)
        return (normalized * 255).astype(np.uint8)
    else:
        return np.zeros_like(depth_map, dtype=np.uint8)


# Run depth inference only on every 4th captured frame.
frame_skip = 4
# NOTE(review): not referenced anywhere in this file; kept for external users.
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
# Currently open WebSocket connections; maintained by handler() and broadcast().
connected = set()


async def broadcast(message):
    """Send ``message`` to every connected client, dropping closed connections.

    Iterates over a snapshot of ``connected``: the set may change while this
    coroutine is suspended in ``send`` (e.g. ``handler`` removing a socket),
    and mutating a set during iteration raises ``RuntimeError``.
    """
    stale = []
    for websocket in list(connected):
        try:
            await websocket.send(message)
        except websockets.exceptions.ConnectionClosed:
            stale.append(websocket)
    connected.difference_update(stale)


async def handler(websocket, path=None):
    """Register ``websocket`` for broadcasts until the connection closes.

    ``path`` is optional because newer websockets releases call handlers with
    the connection only; older ones also pass the request path.
    """
    connected.add(websocket)
    try:
        await websocket.wait_closed()
    finally:
        # discard, not remove: broadcast() may already have dropped this socket.
        connected.discard(websocket)


def _depth_to_uint8(depth_map):
    """Convert a raw depth prediction to a uint8 image min-max scaled to 0..255.

    Non-finite values become 0. An empty prediction yields a 256x256 zero
    image; a constant prediction yields all zeros.
    """
    depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
    depth_map = depth_map.astype(np.float32)
    if depth_map.size == 0:
        return np.zeros((256, 256), dtype=np.uint8)
    if np.min(depth_map) == np.max(depth_map):
        return np.zeros_like(depth_map, dtype=np.uint8)
    return cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)


def _process_frame(frame):
    """Estimate depth for one BGR camera frame and return the JSON message.

    Blocking (model inference + JSON encoding); intended to run in an executor
    thread so the asyncio event loop stays responsive.
    """
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    inputs = processor(images=resize_image(rgb_frame), return_tensors="pt").to(device)
    # Match the model's dtype; leave any non-floating tensors untouched.
    inputs = {
        k: v.to(model_dtype) if v.is_floating_point() else v
        for k, v in inputs.items()
    }
    with torch.no_grad():
        outputs = model(**inputs)
    depth_map = _depth_to_uint8(outputs.predicted_depth.squeeze().cpu().numpy())
    data = {
        'depthMap': depth_map.tolist(),
        'rgbFrame': rgb_frame.tolist(),
    }
    return json.dumps(data)


async def process_frames():
    """Capture camera frames and broadcast depth results until the camera stops.

    Camera reads and inference are blocking, so they run via
    ``run_in_executor``. This makes every iteration yield to the event loop,
    letting the WebSocket server accept connections even while no client is
    connected. The camera is released on every exit path.
    """
    loop = asyncio.get_running_loop()
    frame_count = 0
    try:
        while True:
            ret, frame = await loop.run_in_executor(None, cap.read)
            if not ret:
                break
            frame_count += 1
            if frame_count % frame_skip != 0:
                continue
            message = await loop.run_in_executor(None, _process_frame, frame)
            await broadcast(message)
            # NOTE(review): no OpenCV window is created in this file, so this
            # key check likely never fires; kept for compatibility.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


async def main():
    """Serve WebSocket clients on localhost:8765 while frames are processed.

    The server is closed once frame processing ends, so the program exits
    instead of waiting forever on a server that will never close itself.
    """
    server = await websockets.serve(handler, "localhost", 8765)
    try:
        await process_frames()
    finally:
        server.close()
        await server.wait_closed()


if __name__ == "__main__":
    asyncio.run(main())