import os
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import tempfile

class Config:
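    """Paths to bundled assets and mappings from UI names to TorchScript checkpoint files."""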
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_render_people_epoch_100_torchscript.pt2",
        "0.6b": "sapiens_0.6b_render_people_epoch_70_torchscript.pt2",
        "1b": "sapiens_1b_render_people_epoch_88_torchscript.pt2",
        "2b": "sapiens_2b_render_people_epoch_25_torchscript.pt2",
    }
    SEG_CHECKPOINTS = {
        "fg-bg-1b (recommended)": "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2",
        "no-bg-removal": None,
        "part-seg-1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }

class ModelManager:
    @staticmethod
    def load_model(checkpoint_name: str):
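        """Load a TorchScript checkpoint and move it to the GPU; returns None when no checkpoint is selected."""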
        if checkpoint_name is None:
            return None
        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to("cuda")
        return model

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
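        """Run the model and bilinearly resize its output back to the original image resolution."""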
        output = model(input_tensor)
        return F.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)

class ImageProcessor:
    def __init__(self):
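        # Preprocessing: resize to 1024x768 (height x width) and normalize with the
        # mean/std expected by the checkpoints (values expressed in the [0, 1] range).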
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255], std=[58.5/255, 57.0/255, 57.5/255]),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, depth_model_name: str, seg_model_name: str):
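        """Estimate depth for a PIL image, optionally masking the background to NaN.

        Returns the colorized depth image and a path to the raw .npy depth map.
        """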
        depth_model = ModelManager.load_model(Config.CHECKPOINTS[depth_model_name])
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")
        depth_output = ModelManager.run_model(depth_model, input_tensor, image.height, image.width)
        depth_map = depth_output.squeeze().cpu().numpy()

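        # Optionally remove the background: run the segmentation model and set
        # background pixels in the depth map to NaN.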
        if seg_model_name != "no-bg-removal":
            seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
            seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
            seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
            depth_map[seg_mask == 0] = np.nan
        
        depth_colored = self.colorize_depth_map(depth_map)
        # tempfile.mktemp is deprecated and racy; create the output file safely instead.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp_file:
            npy_path = tmp_file.name
        np.save(npy_path, depth_map)

        return Image.fromarray(depth_colored), npy_path

    @staticmethod
    def colorize_depth_map(depth_map):
        depth_foreground = depth_map[~np.isnan(depth_map)]
        if len(depth_foreground) > 0:
            min_val, max_val = np.nanmin(depth_foreground), np.nanmax(depth_foreground)
            # Guard against a constant depth map, which would make max_val - min_val zero.
            depth_normalized = (depth_map - min_val) / max(max_val - min_val, 1e-8)
            # Invert so nearer surfaces map to brighter colors, then send NaN background to 0.
            depth_normalized = 1 - depth_normalized
            depth_normalized = np.nan_to_num(depth_normalized, nan=0)
            cmap = plt.get_cmap('inferno')
            depth_colored = (cmap(depth_normalized) * 255).astype(np.uint8)[:, :, :3]
        else:
            depth_colored = np.zeros((depth_map.shape[0], depth_map.shape[1], 3), dtype=np.uint8)
        return depth_colored

class GradioInterface:
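    """Builds the Gradio Blocks UI that wires the Sapiens depth and segmentation models to the demo."""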
    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }

            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }

            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }

            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }

            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }

            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }

            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }

            .publication-link:hover {
                background-color: #555;
            }

            .publication-link i {
                margin-right: 8px;
            }

            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }

            /* Image Styles */
            .image-preview img {
                max-width: 512px;
                max-height: 512px;  
                margin: 0 auto;
                border-radius: 4px;
                display: block;
                object-fit: contain;  
            }

            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }

            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }

            .gr-button:hover {
                background-color: #5a5a5a;
            }

            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }

            .gr-form {
                background-color: transparent;
            }

            .gr-panel {
                border: none;
                background-color: transparent;
            }

            /* Override any conflicting styles from Bulma */
            .button.is-normal.is-rounded.is-dark {
                color: #fff !important;
                text-decoration: none !important;
            }
        </style>
        """

        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Sapiens: Depth Estimation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p class="app-description">
                Meta presents Sapiens, a family of foundation models for human-centric vision tasks, pretrained on 300 million human images.
                This demo showcases the fine-tuned depth estimation model.
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
                    <i class="fas fa-file-pdf"></i>arXiv
                </a>
                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
                    <i class="fab fa-github"></i>Code
                </a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
                    <i class="fas fa-globe"></i>Meta
                </a>
                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
                    <i class="fas fa-chart-bar"></i>Results
                </a>
            </div>
            <div class="publication-links">
                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
                    <i class="fas fa-user"></i>Demo-Pose
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
                    <i class="fas fa-cube"></i>Demo-Depth
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
                    <i class="fas fa-vector-square"></i>Demo-Normal
                </a>
            </div>
        </div>
        """

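        # Force Gradio's dark theme by appending ?__theme=dark to the URL and reloading once.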
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

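        # Thin wrapper so the Gradio click handler can call the GPU-decorated processor.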
        def process_image(image, depth_model_name, seg_model_name):
            result, npy_path = self.image_processor.process_image(image, depth_model_name, seg_model_name)
            return result, npy_path

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        depth_model_name = gr.Dropdown(
                            label="Depth Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    example_model = gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=[
                            os.path.join(Config.ASSETS_DIR, "images", img)
                            for img in sorted(os.listdir(os.path.join(Config.ASSETS_DIR, "images")))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Depth Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background depth is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            run_button.click(
                fn=process_image,
                inputs=[input_image, depth_model_name, seg_model_name],
                outputs=[result_image, npy_output],
            )

        return demo

def main():
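    # Allow TF32 matmuls and cuDNN kernels on Ampere (compute capability >= 8) GPUs for faster inference.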
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=False)

if __name__ == "__main__":
    main()