import gradio as gr
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
import tempfile
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from packaging import version
import os
from torchao.quantization import (
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Float8WeightOnlyConfig,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
)

# Dropdown label -> short suffix used in default repo names
# ("<model>-ao-<suffix>[-gs<group_size>]").
MAP_QUANT_TYPE_TO_NAME = dict(
    Int4WeightOnly="int4wo",
    GemliteUIntXWeightOnly="intxwo-gemlite",
    Int8WeightOnly="int8wo",
    Int8DynamicActivationInt8Weight="int8da8w8",
    Float8WeightOnly="float8wo",
    Float8DynamicActivationFloat8Weight="float8da8w8",
    autoquant="autoquant",
)
# Dropdown label -> torchao config class. "autoquant" has no entry here: it is
# handed to TorchAoConfig as a plain string instead of a config object.
MAP_QUANT_TYPE_TO_CONFIG = dict(
    Int4WeightOnly=Int4WeightOnlyConfig,
    GemliteUIntXWeightOnly=GemliteUIntXWeightOnlyConfig,
    Int8WeightOnly=Int8WeightOnlyConfig,
    Int8DynamicActivationInt8Weight=Int8DynamicActivationInt8WeightConfig,
    Float8WeightOnly=Float8WeightOnlyConfig,
    Float8DynamicActivationFloat8Weight=Float8DynamicActivationFloat8WeightConfig,
)


def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    """Return the greeting shown at the top of the page.

    Gradio fills both arguments from the signed-in session (``demo.load`` lists
    no inputs); ``profile`` is None when nobody is logged in. The token is not
    used here.
    """
    display_name = "" if profile is None else f"{profile.name} "
    return f"Hello {display_name}!"


def check_model_exists(
    oauth_token: gr.OAuthToken | None,
    username,
    quantization_type,
    group_size,
    model_name,
    quantized_model_name,
):
    """Check if a model exists in the user's Hugging Face repository.

    The repo id checked is ``{username}/{quantized_model_name}`` when a custom
    name is given, otherwise ``{username}/{base}-ao-{suffix}`` where ``base`` is
    the last path segment of ``model_name`` and ``suffix`` comes from
    MAP_QUANT_TYPE_TO_NAME; ``-gs{group_size}`` is appended for
    Int4WeightOnly / GemliteUIntXWeightOnly when a group size is set.

    Returns:
        A message string if the repo already exists, ``None`` if it does not,
        or an ``"Error checking model existence: ..."`` string if anything in
        the lookup failed (including an unknown ``quantization_type``).
    """
    try:
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            suffix = MAP_QUANT_TYPE_TO_NAME[quantization_type]
            repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{suffix}"
            if (
                quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"]
                and group_size is not None
            ):
                repo_name += f"-gs{group_size}"

        token = oauth_token.token if oauth_token is not None else None
        # Let the Hub filter by the repo's short name instead of paging through
        # every model the user owns; the exact id is still compared below.
        candidates = list_models(
            author=username, search=repo_name.split("/")[-1], token=token
        )
        if any(model.id == repo_name for model in candidates):
            return f"Model '{repo_name}' already exists in your repository."
        return None  # Model does not exist
    except Exception as e:
        return f"Error checking model existence: {e}"


# Tag added to every generated model card so repos made by this space can be found.
_MODEL_CARD_TAG = "torchao-my-repo"


def _read_original_readme(model_name):
    """Return ``(yaml_header, body)`` of the original README; empty strings if unavailable."""
    try:
        # Download only the README.md file from the original model
        model_path = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
        )
        readme_path = os.path.join(model_path, "README.md")
        if not os.path.exists(readme_path):
            return "", ""
        with open(readme_path, "r", encoding="utf-8") as f:
            content = f.read()
    except Exception as e:
        # Best effort: a missing/unreadable README just yields a card without it.
        print(f"Error reading original README: {str(e)}")
        return "", ""

    if content.startswith("---"):
        # "---\n<yaml>\n---<body>" splits into ["", yaml, body].
        parts = content.split("---", 2)
        if len(parts) >= 3:
            return parts[1], parts[2]
    return "", content


def _tag_lines(tags_line, following_lines):
    """Return the YAML lines replacing a top-level ``tags:`` line, with our tag added first."""
    value = tags_line.split(":", 1)[1].strip()
    item = f"- {_MODEL_CARD_TAG}"
    if not value or value.startswith("#"):
        # Block list (or empty value): indent our item like the existing items,
        # otherwise a column-0 item followed by indented items is not one list.
        indent = ""
        for nxt in following_lines:
            if nxt.strip():
                if nxt.lstrip().startswith("-"):
                    indent = nxt[: len(nxt) - len(nxt.lstrip())]
                break
        return [tags_line, indent + item]
    if value.startswith("["):
        # Flow list such as "[a, b]" or "[]": prepend inside the brackets.
        rest = value[1:].lstrip()
        separator = "" if rest.startswith("]") else ", "
        return [f"tags: [{_MODEL_CARD_TAG}{separator}{rest}"]
    # Single scalar value ("tags: foo"): rewrite as a block list.
    return ["tags:", item, f"- {value}"]


def _build_yaml_header(model_name, original_yaml_header):
    """Return the card's YAML front matter, including both ``---`` fences.

    ``base_model`` is set to ``model_name``; the original's ``base_model`` and
    ``base_model_relation`` entries are dropped because they describe the
    original model's own lineage. All other original lines are kept and the
    space's tag is added (creating ``tags:`` if absent).
    """
    out = ["---", "base_model:", f"- {model_name}"]
    lines = original_yaml_header.strip().split("\n") if original_yaml_header else []
    in_base_model_section = False
    found_tags = False
    for i, line in enumerate(lines):
        stripped = line.strip()
        # Skip continuation lines of a multi-line base_model value.
        if in_base_model_section:
            if stripped.startswith("-") or not stripped or line.startswith(" "):
                continue
            in_base_model_section = False

        if stripped.startswith(("base_model:", "base_model_relation:")):
            # Only a key without an inline value continues on the next lines.
            in_base_model_section = not line.split(":", 1)[1].strip()
            continue

        # Only the top-level tags key (column 0) gets our tag.
        if line.startswith("tags:") and not found_tags:
            found_tags = True
            out.extend(_tag_lines(line, lines[i + 1 :]))
            continue

        out.append(line)

    if not found_tags:
        out += ["tags:", f"- {_MODEL_CARD_TAG}"]
    out.append("---")
    return "\n".join(out)


def create_model_card(model_name, quantization_type, group_size):
    """Build the README.md (model card) for the quantized repo.

    Args:
        model_name: Hub id of the original model, e.g. ``"org/model"``.
        quantization_type: Quantization method label shown in the card.
        group_size: Group size; listed only for Int4WeightOnly and
            GemliteUIntXWeightOnly when not None (other methods ignore it).

    Returns:
        Markdown text: YAML front matter, a quantization summary, and the
        original README body (if it could be downloaded).
    """
    original_yaml_header, original_readme = _read_original_readme(model_name)
    yaml_header = _build_yaml_header(model_name, original_yaml_header)

    detail_lines = [f"- **Quantization Type**: {quantization_type}"]
    if (
        quantization_type in ("Int4WeightOnly", "GemliteUIntXWeightOnly")
        and group_size is not None
    ):
        detail_lines.append(f"- **Group Size**: {group_size}")
    details = "\n".join(detail_lines)

    # Create the quantization info section
    quant_info = f"""
# {model_name} (Quantized)

## Description
This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}). 

It's quantized using the TorchAO library using the [torchao-my-repo](https://huggingface.co/spaces/pytorch/torchao-my-repo) space.

## Quantization Details
{details}

"""

    model_card = yaml_header + quant_info

    # Append original README content if available
    if original_readme and not original_readme.isspace():
        model_card += "\n\n# 📄 Original Model Information\n\n" + original_readme
    return model_card


def quantize_model(
    model_name, quantization_type, group_size=128, auth_token=None, username=None, progress=gr.Progress()
):
    """Load ``model_name`` from the Hub on CPU, quantized with TorchAO.

    Args:
        model_name: Hub id of the model to quantize.
        quantization_type: A MAP_QUANT_TYPE_TO_CONFIG key, or ``"autoquant"``.
        group_size: Group size for Int4WeightOnly / GemliteUIntXWeightOnly. When
            None it is not passed, so the torchao config's own default applies.
            Ignored by the other types.
        auth_token: OAuth token object (``.token`` is read) or None.
        username: Unused; kept for call compatibility.
        progress: Gradio progress tracker.

    Returns:
        The quantized ``transformers`` model. Checkpoints whose config lists a
        causal-LM architecture are loaded with AutoModelForCausalLM so the LM
        head is kept in the pushed checkpoint; others are loaded with AutoModel.
    """
    from transformers import AutoConfig

    print(f"Quantizing model: {quantization_type}")
    progress(0, desc="Preparing Quantization")
    token = auth_token.token if auth_token is not None else None
    # Only forward group_size when it was actually provided.
    group_kwargs = {} if group_size is None else {"group_size": group_size}

    if quantization_type == "GemliteUIntXWeightOnly":
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](**group_kwargs)
        quantization_config = TorchAoConfig(quant_config)
    elif quantization_type == "Int4WeightOnly":
        # CPU layout because the model is loaded with device_map="cpu" below.
        from torchao.dtypes import Int4CPULayout

        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
            layout=Int4CPULayout(), **group_kwargs
        )
        quantization_config = TorchAoConfig(quant_config)
    elif quantization_type == "autoquant":
        quantization_config = TorchAoConfig(quantization_type)
    else:
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]()
        quantization_config = TorchAoConfig(quant_config)

    # AutoModel returns the bare backbone (no LM head) for causal LMs, which
    # would push a checkpoint missing its output layer; pick the class that
    # matches the checkpoint's declared architecture instead.
    config = AutoConfig.from_pretrained(model_name, token=token)
    architectures = getattr(config, "architectures", None) or []
    is_causal_lm = any(
        arch.endswith(("ForCausalLM", "LMHeadModel")) for arch in architectures
    )
    model_cls = AutoModelForCausalLM if is_causal_lm else AutoModel

    progress(0.10, desc="Quantizing model")
    model = model_cls.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        token=token,
    )
    progress(0.45, desc="Quantization completed")
    return model


def save_model(
    model,
    model_name,
    quantization_type,
    group_size=128,
    username=None,
    auth_token=None,
    quantized_model_name=None,
    public=True,
    progress=gr.Progress(),
):
    """Push ``model`` with its tokenizer and a model card to the Hub; return an HTML summary.

    The repo id is ``{username}/{quantized_model_name}`` when a custom name is
    given, otherwise ``{username}/{base}-ao-{suffix}`` (``suffix`` from
    MAP_QUANT_TYPE_TO_NAME), with ``-gs{group_size}`` appended for
    Int4WeightOnly / GemliteUIntXWeightOnly when a group size is set.

    Args:
        model: The quantized model returned by quantize_model.
        model_name: Hub id of the original model (tokenizer source).
        quantization_type: Quantization method label (a MAP_QUANT_TYPE_TO_NAME key).
        group_size: Group size used, or None.
        username: Namespace to push to.
        auth_token: OAuth token object (``.token`` is read) or None.
        quantized_model_name: Optional custom repo name.
        public: Create the repo as public when True, private otherwise.
        progress: Gradio progress tracker.

    Returns:
        HTML with a link to the repo and the printed model architecture.
    """
    import html

    progress(0.50, desc="Preparing to push")
    print("Saving quantized model")
    token = auth_token.token if auth_token is not None else None

    # Resolve the repo id before the expensive serialization. The mapping is
    # keyed by the exact dropdown labels (e.g. "Int4WeightOnly"), so the type
    # must not be lower-cased.
    if quantized_model_name:
        repo_name = f"{username}/{quantized_model_name}"
    else:
        suffix = MAP_QUANT_TYPE_TO_NAME[quantization_type]
        repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{suffix}"
        if (
            quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"]
            and group_size is not None
        ):
            repo_name += f"-gs{group_size}"

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Load and save the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        tokenizer.save_pretrained(tmpdirname)

        # Save the model. safe_serialization=False is kept from the original;
        # presumably the torchao-quantized weights are not safetensors-compatible
        # here — TODO confirm before switching.
        progress(0.60, desc="Saving model")
        model.save_pretrained(tmpdirname, safe_serialization=False)

        progress(0.70, desc="Creating model card")
        model_card = create_model_card(model_name, quantization_type, group_size)
        # The card can contain non-ASCII text (emoji heading, original README),
        # so do not depend on the platform's default encoding.
        with open(os.path.join(tmpdirname, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

        # Push to Hub
        api = HfApi(token=token)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.80, desc="Pushing to Hub")
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
        progress(1.00, desc="Pushing to Hub completed")

    # Same text print(model) would emit. Built directly rather than with
    # redirect_stdout, which swaps the process-wide sys.stdout and would also
    # capture output from other concurrently running handler threads.
    model_architecture_str = str(model) + "\n"

    # Escape HTML characters and format with line breaks
    model_architecture_str_html = html.escape(model_architecture_str).replace(
        "\n", "<br/>"
    )

    # Format it for display in markdown with proper styling
    model_architecture_info = f"""
    <div class="model-architecture-container" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
        <h3 style="margin-top: 0; color: #2E7D32;">📋 Model Architecture</h3>
        <div class="model-architecture" style="max-height: 500px; overflow-y: auto; overflow-x: auto; background-color: #f5f5f5; padding: 5px; border-radius: 8px; font-family: monospace; white-space: pre-wrap;">
        <div style="line-height: 1.2; font-size: 0.75em;">{model_architecture_str_html}</div>
        </div>
    </div>
    """

    # repo_name contains user-typed text; escape it before embedding in HTML.
    safe_repo_name = html.escape(repo_name)
    repo_link = f"""
    <div class="repo-link" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
        <h3 style="margin-top: 0; color: #2E7D32;">🔗 Repository Link</h3>
        <p>Find your repo here: <a href="https://huggingface.co/{safe_repo_name}" target="_blank" style="text-decoration:underline">{safe_repo_name}</a></p>
    </div>
    """
    return (
        f"<h1>🎉 Quantization Completed</h1><br/>{repo_link}{model_architecture_info}"
    )


def quantize_and_save(
    profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    model_name,
    quantization_type,
    group_size,
    quantized_model_name,
    public,
):
    """Click handler: validate inputs, quantize ``model_name`` and push it to the Hub.

    ``profile`` and ``oauth_token`` are supplied by Gradio from the signed-in
    session; they are not part of the button's declared inputs.

    Args:
        model_name: Hub id of the model to quantize.
        quantization_type: Selected quantization method label.
        group_size: Raw text from the group-size textbox; blank means "not set".
        quantized_model_name: Optional custom repo name.
        public: Whether the created repo should be public.

    Returns:
        An HTML string: an error/warning box, or the summary from save_model.
    """
    import html
    import traceback

    if oauth_token is None or not profile:
        return """
        <div class="error-box">
            <h3>❌ Authentication Error</h3>
            <p>Please sign in to your HuggingFace account to use the quantizer.</p>
        </div>
        """

    # Accept surrounding whitespace; require a positive base-10 integer.
    # isdecimal() (unlike isdigit()) only admits characters int() can parse.
    raw_group_size = "" if group_size is None else str(group_size).strip()
    if raw_group_size and (
        not raw_group_size.isdecimal() or int(raw_group_size) <= 0
    ):
        return """
        <div class="error-box">
            <h3>❌ Group Size Error</h3>
            <p>Group Size must be a positive integer (it is only used by Int4WeightOnly and GemliteUIntXWeightOnly)</p>
        </div>
        """
    group_size = int(raw_group_size) if raw_group_size else None

    exists_message = check_model_exists(
        oauth_token,
        profile.username,
        quantization_type,
        group_size,
        model_name,
        quantized_model_name,
    )
    if exists_message:
        # The message embeds user-typed repo names / error text: escape it.
        return f"""
        <div class="warning-box">
            <h3>⚠️ Model Already Exists</h3>
            <p>{html.escape(exists_message)}</p>
        </div>
        """

    try:
        quantized_model = quantize_model(
            model_name, quantization_type, group_size, oauth_token, profile.username
        )
        return save_model(
            quantized_model,
            model_name,
            quantization_type,
            group_size,
            profile.username,
            oauth_token,
            quantized_model_name,
            public,
        )
    except Exception as e:
        # Top-level UI boundary: log the full traceback for the Space logs and
        # show an escaped message (exception text often contains "<...>").
        traceback.print_exc()
        return f"""
        <div class="error-box">
            <h3>❌ Quantization Error</h3>
            <p>{html.escape(str(e))}</p>
        </div>
        """


def get_model_size(model):
    """
    Return the in-memory size of ``model``'s state dict in gigabytes.

    Every entry of ``model.state_dict()`` contributes
    ``nelement() * element_size()`` bytes; the total is divided by 1024**3
    and rounded to two decimal places.

    Args:
        model: PyTorch model (anything exposing ``state_dict()``)

    Returns:
        float: Size of the model in GB, rounded to 2 decimals
    """
    bytes_per_gb = 1024**3
    total_bytes = sum(
        tensor.nelement() * tensor.element_size()
        for tensor in model.state_dict().values()
    )
    return round(total_bytes / bytes_per_gb, 2)


# Custom CSS for the app, passed to gr.Blocks(css=css) when the layout is built.
# It styles the login button (#login-button), the quantize button
# (.quantize-button, which pulses via the pytorch-glow keyframes) and
# primary-variant buttons.
# NOTE(review): confirm the .gradio-* selectors still match the class names
# rendered by the installed Gradio version.
css = """
/* Custom CSS for enhanced UI */
.gradio-container {overflow-y: auto;}

/* Fix alignment for radio buttons and dropdowns */
.gradio-radio, .gradio-dropdown {
    display: flex !important;
    align-items: center !important;
    margin: 10px 0 !important;
}

/* Consistent spacing and alignment */
.gradio-dropdown, .gradio-textbox, .gradio-radio {
    margin-bottom: 12px !important;
    width: 100% !important;
}


button[variant="primary"]::before {
    content: "🔥 ";  /* PyTorch flame icon */
}

button[variant="primary"]:hover {
    transform: translateY(-5px) scale(1.05) !important;
    box-shadow: 0 10px 25px rgba(238, 76, 44, 0.7) !important;
}

@keyframes pytorch-glow {
    from {
        box-shadow: 0 0 10px rgba(238, 76, 44, 0.5);
    }
    to {
        box-shadow: 0 0 20px rgba(238, 76, 44, 0.8), 0 0 30px rgba(255, 156, 0, 0.5);
    }
}

/* Login button styling */
#login-button {
    background: linear-gradient(135deg, #EE4C2C, #FF9C00) !important;
    color: white !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 15px !important;
    box-shadow: 0 0 15px rgba(238, 76, 44, 0.5) !important;
    transition: all 0.3s ease !important;
    max-width: 300px !important;
    margin: 0 auto !important;
}

.quantize-button {
    background: linear-gradient(135deg, #EE4C2C, #FF9C00) !important;
    color: white !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 15px !important;
    box-shadow: 0 0 15px rgba(238, 76, 44, 0.5) !important;
    transition: all 0.3s ease !important;
    animation: pytorch-glow 1.5s infinite alternate !important;
    transform-origin: center !important;
    letter-spacing: 0.5px !important;
    text-shadow: 0 1px 2px rgba(0, 0, 0, 0.2) !important;
}

.quantize-button:hover {
    transform: translateY(-3px) scale(1.03) !important;
    box-shadow: 0 8px 20px rgba(238, 76, 44, 0.7) !important;
}
"""

# Quantization types whose torchao config takes a group_size argument; the
# group-size textbox is only editable while one of these is selected.
GROUP_SIZE_QUANT_TYPES = ("Int4WeightOnly", "GemliteUIntXWeightOnly")


def _toggle_group_size(selected_type):
    """Return a Gradio update enabling the group-size box only when it applies."""
    return gr.update(interactive=selected_type in GROUP_SIZE_QUANT_TYPES)


# Build the main app layout
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # 🤗 TorchAO Model Quantizer ✨
        
        Quantize your favorite Hugging Face models using TorchAO and save them to your profile!
        
        <br/>
        """
    )

    gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)

    m1 = gr.Markdown()
    demo.load(hello, inputs=None, outputs=m1)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_name = HuggingfaceHubSearch(
                    label="🔍 Hub Model ID",
                    placeholder="Search for model id on Huggingface",
                    search_type="model",
                )

            gr.Markdown("""### ⚙️ Quantization Settings""")
            with gr.Row():
                with gr.Column():
                    # Choices must match the keys of MAP_QUANT_TYPE_TO_NAME.
                    quantization_type = gr.Dropdown(
                        info="Select the Quantization method",
                        choices=[
                            "Int4WeightOnly",
                            "GemliteUIntXWeightOnly",
                            "Int8WeightOnly",
                            "Int8DynamicActivationInt8Weight",
                            "Float8WeightOnly",
                            "Float8DynamicActivationFloat8Weight",
                            "autoquant",
                        ],
                        value="Int8WeightOnly",
                        filterable=False,
                        show_label=False,
                    )

                    # Initial editability follows the default type; the change
                    # handler registered below keeps it in sync afterwards.
                    group_size = gr.Textbox(
                        info="Group Size (only for Int4WeightOnly and GemliteUIntXWeightOnly)",
                        value="128",
                        interactive=quantization_type.value in GROUP_SIZE_QUANT_TYPES,
                        show_label=False,
                    )

            gr.Markdown(
                        """
                        ### 💾 Saving Settings
                        """
                    )
            with gr.Row():
                quantized_model_name = gr.Textbox(
                    label="✏️ Model Name",
                    info="Model Name (optional : to override default)",
                    value="",
                    interactive=True,
                    elem_classes="model-name-textbox",
                    show_label=False,
                )
            with gr.Row():
                public = gr.Checkbox(
                    label="🌐 Make model public",
                    info="If checked, the model will be publicly accessible",
                    value=True,
                    interactive=True,
                    show_label=True,
                )

        with gr.Column():
            quantize_button = gr.Button(
                "🚀 Quantize and Push to Hub", elem_classes="quantize-button", elem_id="quantize-button"
            )
            output_link = gr.Markdown(
                label="🔗 Quantized Model Info", container=True, min_height=200
            )

    # Add information section
    with gr.Accordion("📚 About TorchAO Quantization", open=True):
        gr.Markdown(
            """
            ## 📝 Quantization Options
            
            ### Quantization Types
            - **Int4WeightOnly**: 4-bit weight-only quantization
            - **GemliteUIntXWeightOnly**: uintx gemlite quantization (default to 4 bit only for now)
            - **Int8WeightOnly**: 8-bit weight-only quantization
            - **Int8DynamicActivationInt8Weight**: 8-bit quantization for both weights and activations
            - **Float8WeightOnly**: float8 weight-only quantization
            - **Float8DynamicActivationFloat8Weight**: float8 quantization for both weights and activations
            - **autoquant**: automatic quantization (uses the best quantization method for the model)

            ### Group Size
            - Only applicable for Int4WeightOnly and GemliteUIntXWeightOnly quantization
            - Default value is 128
            - Affects the granularity of quantization
            
            ## 🔍 How It Works
            1. Downloads the original model
            2. Applies TorchAO quantization with your selected settings
            3. Uploads the quantized model to your HuggingFace account
            
            ## 📊 Memory Benefits
            - int4 quantization can reduce model size by up to 75%
            - int8 quantization typically reduces size by about 50%
            """
        )

    # Enable/disable the group-size box whenever the quantization type changes.
    quantization_type.change(
        fn=_toggle_group_size,
        inputs=quantization_type,
        outputs=group_size,
    )

    # profile/oauth_token are injected by Gradio, so only the visible
    # components are listed as inputs.
    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quantization_type, group_size, quantized_model_name, public],
        outputs=[output_link],
    )

# Launch the app
demo.launch(share=True)