import os
import io
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import tempfile


class Config:
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_render_people_epoch_100_torchscript.pt2",
        "0.6b": "sapiens_0.6b_render_people_epoch_70_torchscript.pt2",
        "1b": "sapiens_1b_render_people_epoch_88_torchscript.pt2",
        "2b": "sapiens_2b_render_people_epoch_25_torchscript.pt2",
    }
    SEG_CHECKPOINTS = {
        "fg-bg-1b (recommended)": "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2",
        "no-bg-removal": None,
        "part-seg-1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }


class ModelManager:
    @staticmethod
    def load_model(checkpoint_name: str):
        if checkpoint_name is None:
            return None
        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to("cuda")
        return model

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        output = model(input_tensor)
        # Resize the prediction back to the original image resolution.
        return F.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)


class ImageProcessor:
    def __init__(self):
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                                 std=[58.5 / 255, 57.0 / 255, 57.5 / 255]),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, depth_model_name: str, seg_model_name: str):
        image = image.convert("RGB")  # guard against RGBA/grayscale uploads
        depth_model = ModelManager.load_model(Config.CHECKPOINTS[depth_model_name])
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")
        depth_output = ModelManager.run_model(depth_model, input_tensor, image.height, image.width)
        depth_map = depth_output.squeeze().cpu().numpy()

        # Optionally mask out the background with a segmentation model.
        if seg_model_name != "no-bg-removal":
            seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
            seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
            seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
            depth_map[seg_mask == 0] = np.nan

        depth_colored = self.colorize_depth_map(depth_map)

        # Save the raw depth map (background = NaN) to a temporary .npy file for download.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp_file:
            npy_path = tmp_file.name
        np.save(npy_path, depth_map)

        return Image.fromarray(depth_colored), npy_path

    @staticmethod
    def colorize_depth_map(depth_map):
        depth_foreground = depth_map[~np.isnan(depth_map)]
        if len(depth_foreground) > 0:
            # Normalize over foreground pixels only, then invert so nearer pixels are brighter.
            min_val, max_val = np.nanmin(depth_foreground), np.nanmax(depth_foreground)
            depth_normalized = (depth_map - min_val) / (max_val - min_val)
            depth_normalized = 1 - depth_normalized
            depth_normalized = np.nan_to_num(depth_normalized, nan=0)
            cmap = plt.get_cmap('inferno')
            depth_colored = (cmap(depth_normalized) * 255).astype(np.uint8)[:, :, :3]
        else:
            depth_colored = np.zeros((depth_map.shape[0], depth_map.shape[1], 3), dtype=np.uint8)
        return depth_colored
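# A minimal standalone sketch (hypothetical usage, not part of the demo flow) of
# driving ImageProcessor directly, assuming the TorchScript checkpoints exist under
# assets/checkpoints and a CUDA device is available:
#
#   processor = ImageProcessor()
#   depth_img, npy_path = processor.process_image(
#       Image.open("person.png"),            # "person.png" is a placeholder path
#       depth_model_name="1b",
#       seg_model_name="fg-bg-1b (recommended)",
#   )
#   depth_img.save("depth_vis.png")          # colorized preview
#   # npy_path points to the raw depth array; background pixels are NaN.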
class GradioInterface:
    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }

            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }

            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }

            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }

            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }

            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }

            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }

            .publication-link:hover {
                background-color: #555;
            }

            .publication-link i {
                margin-right: 8px;
            }

            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }

            /* Image Styles */
            .image-preview img {
                max-width: 512px;
                max-height: 512px;
                margin: 0 auto;
                border-radius: 4px;
                display: block;
                object-fit: contain;
            }

            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }

            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }

            .gr-button:hover {
                background-color: #5a5a5a;
            }

            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }

            .gr-form {
                background-color: transparent;
            }

            .gr-panel {
                border: none;
                background-color: transparent;
            }

            /* Override any conflicting styles from Bulma */
            .button.is-normal.is-rounded.is-dark {
                color: #fff !important;
                text-decoration: none !important;
            }
        </style>
        """

        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Sapiens: Depth Estimation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p class="app-description">
                Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
                This demo showcases the finetuned depth model.
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
                    <i class="fas fa-file-pdf"></i>arXiv
                </a>
                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
                    <i class="fab fa-github"></i>Code
                </a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
                    <i class="fas fa-globe"></i>Meta
                </a>
                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
                    <i class="fas fa-chart-bar"></i>Results
                </a>
            </div>
            <div class="publication-links">
                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
                    <i class="fas fa-user"></i>Demo-Pose
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
                    <i class="fas fa-cube"></i>Demo-Depth
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
                    <i class="fas fa-vector-square"></i>Demo-Normal
                </a>
            </div>
        </div>
        """

        # Force the dark theme via URL parameter so the custom dark CSS renders consistently.
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

        def process_image(image, depth_model_name, seg_model_name):
            result, npy_path = self.image_processor.process_image(image, depth_model_name, seg_model_name)
            return result, npy_path

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        depth_model_name = gr.Dropdown(
                            label="Depth Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    example_model = gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=[
                            os.path.join(Config.ASSETS_DIR, "images", img)
                            for img in os.listdir(os.path.join(Config.ASSETS_DIR, "images"))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Depth Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background depth is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")
                    run_button.click(
                        fn=process_image,
                        inputs=[input_image, depth_model_name, seg_model_name],
                        outputs=[result_image, npy_output],
                    )

        return demo


def main():
    # Allow TF32 matmuls on Ampere (compute capability 8.0) or newer GPUs.
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=False)


if __name__ == "__main__":
    main()
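# Sketch of how a downstream consumer might load the exported depth file
# ("depth.npy" is a placeholder for the file downloaded from the demo):
#
#   import numpy as np
#   depth = np.load("depth.npy")      # (H, W) float array, NaN marks background
#   foreground = ~np.isnan(depth)     # boolean mask of valid (human) pixels
#   depth_range = depth[foreground].max() - depth[foreground].min()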