import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces
import logging

# Logging: route INFO-and-above records through the root handler that
# basicConfig installs (only if the root logger has no handlers yet), and
# give this module its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Image geometry and batch size shared by token sampling and decoding.
DEFAULT_WIDTH = DEFAULT_HEIGHT = 384  # output image size in pixels
PARALLEL_SIZE = 5  # images generated per request
PATCH_SIZE = 16  # pixels per side of the area one image token covers

# Load model and processor with error handling
def load_model(model_path="deepseek-ai/Janus-Pro-7B"):
    """Load the Janus multimodal model and its chat processor.

    Args:
        model_path: Checkpoint identifier passed to every ``from_pretrained``
            call below. Defaults to ``"deepseek-ai/Janus-Pro-7B"``.

    Returns:
        tuple ``(vl_gpt, vl_chat_processor, device)``:
        - ``vl_gpt`` is the model, loaded as bfloat16 on CUDA or float32 on
          CPU and moved to ``device``.
        - ``vl_chat_processor`` is the ``VLChatProcessor``.
        - ``device`` is the ``torch.device`` that was selected.

    Raises:
        RuntimeError: if any loading step fails. The original exception is
            chained as ``__cause__`` and its traceback is logged.
    """
    try:
        config = AutoConfig.from_pretrained(model_path)
        language_config = config.language_config
        # Force the eager attention implementation for the language model.
        # NOTE(review): presumably to avoid depending on flash-attention/SDPA
        # kernels being available — confirm.
        language_config._attn_implementation = 'eager'

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading model on device: {device}")

        vl_gpt = AutoModelForCausalLM.from_pretrained(
            model_path,
            language_config=language_config,
            trust_remote_code=True,
            # bfloat16 halves GPU memory; keep full precision on CPU.
            torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32
        ).to(device)

        vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
        return vl_gpt, vl_chat_processor, device

    except Exception as e:
        logger.error("Model loading failed: %s", e, exc_info=True)
        # Chain the cause so the real failure (missing package, bad path,
        # network error, ...) stays visible in the traceback.
        raise RuntimeError("Failed to load model. Please check the model path and dependencies.") from e

# Load the model once at import time. load_model() already logs failures and
# raises RuntimeError, so the error is left to propagate and abort startup;
# the previous `except RuntimeError as e: raise e` wrapper was a no-op.
vl_gpt, vl_chat_processor, device = load_model()
tokenizer = vl_chat_processor.tokenizer

# Helper functions with improved memory management
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, progress=None):
    """Sample image tokens autoregressively with classifier-free guidance (CFG).

    Builds ``2 * parallel_size`` prompt rows:
    - Even rows hold the full prompt (the conditional branch).
    - Odd rows have every token except the first and last replaced by
      ``vl_chat_processor.pad_id`` (the unconditional branch).

    At each step the two logit streams are mixed as
    ``uncond + cfg_weight * (cond - uncond)``. One token per image is then
    drawn with ``torch.multinomial``.

    Args:
        input_ids: Prompt token ids, copied into every row of a
            ``(2 * parallel_size, len(input_ids))`` buffer.
        width, height: Target image size in pixels. Together with
            ``PATCH_SIZE`` they set the number of sampled tokens,
            ``(height // PATCH_SIZE) * (width // PATCH_SIZE)``; this is 576
            for the 384x384 defaults.
        cfg_weight: Classifier-free guidance weight.
        temperature: Softmax temperature applied to the guided logits.
        parallel_size: Number of images sampled in one batch.
        progress: Optional callable, invoked as ``progress(fraction, desc=...)``
            after each sampling step.

    Returns:
        ``torch.int`` tensor of shape ``(parallel_size, num_image_tokens)``.

    Raises:
        RuntimeError: a new one with a user-facing hint when the failure is an
            out-of-memory error. Any other ``RuntimeError`` is re-raised
            unchanged.
    """
    # One token per PATCH_SIZE x PATCH_SIZE patch; equals the previously
    # hard-coded 576 for the 384x384 defaults.
    total_steps = (height // PATCH_SIZE) * (width // PATCH_SIZE)
    try:
        torch.cuda.empty_cache()
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=device)

        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                # Odd rows form the unconditional CFG branch: only the first
                # and last prompt tokens survive, the rest become padding.
                tokens[i, 1:-1] = vl_chat_processor.pad_id

        with torch.no_grad():
            inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
            generated_tokens = torch.zeros((parallel_size, total_steps), dtype=torch.int, device=device)

            pkv = None
            for i in range(total_steps):
                if progress is not None:
                    progress((i + 1) / total_steps, desc="Generating image tokens")

                # From the second step on, the KV cache carries the context,
                # so only the newest token's embedding is fed in.
                outputs = vl_gpt.language_model.model(
                    inputs_embeds=inputs_embeds,
                    use_cache=True,
                    past_key_values=pkv
                )
                pkv = outputs.past_key_values
                hidden_states = outputs.last_hidden_state
                logits = vl_gpt.gen_head(hidden_states[:, -1, :])

                # Rows alternate conditional / unconditional (see above).
                logit_cond = logits[0::2, :]
                logit_uncond = logits[1::2, :]
                logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_tokens[:, i] = next_token.squeeze(dim=-1)

                # Duplicate each sampled token ([t0, t0, t1, t1, ...]) so the
                # conditional and unconditional row of an image both get it.
                next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
                img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
                inputs_embeds = img_embeds.unsqueeze(dim=1)

        return generated_tokens

    except RuntimeError as e:
        logger.error("Generation error: %s", e, exc_info=True)
        # torch.cuda.OutOfMemoryError subclasses RuntimeError in recent
        # PyTorch; the message check also covers older versions.
        oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
        is_oom = (oom_type is not None and isinstance(e, oom_type)) or "out of memory" in str(e).lower()
        if is_oom:
            raise RuntimeError("Generation failed due to memory constraints. Try reducing the parallel size.") from e
        # Not a memory problem: re-raise unchanged so callers see (and can
        # classify) the real error instead of a misleading OOM hint.
        raise
    finally:
        torch.cuda.empty_cache()

def unpack(patches, width, height, parallel_size=5):
    """Convert a batch of decoded image tensors into PIL images.

    Args:
        patches: Rank-4 tensor of decoded images.
            NOTE(review): assumed channel-first ``(N, C, H, W)`` with values
            roughly in ``[-1, 1]`` — confirm against ``decode_code``.
        width, height, parallel_size: Unused; kept for call-site compatibility.

    Returns:
        list of ``PIL.Image``, one per entry along the first axis.

    Raises:
        RuntimeError: if conversion fails. The original error is chained as
            ``__cause__`` and its traceback is logged.
    """
    try:
        arrays = patches.detach().to(device='cpu', dtype=torch.float32).numpy()
        # Move the channel axis last: (N, C, H, W) -> (N, H, W, C).
        arrays = arrays.transpose(0, 2, 3, 1)
        # Map [-1, 1] to [0, 255]; clipping guards against out-of-range values.
        arrays = np.clip((arrays + 1) / 2 * 255, 0, 255)
        return [Image.fromarray(arr.astype(np.uint8)) for arr in arrays]
    except Exception as e:
        logger.error("Unpacking error: %s", e, exc_info=True)
        raise RuntimeError("Failed to process generated image data.") from e

@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress()):
    """Gradio click handler: generate ``PARALLEL_SIZE`` images for a prompt.

    Args:
        prompt: Text prompt; must contain at least one non-whitespace character.
        seed: Optional seed. ``None`` draws a fresh one via ``torch.seed()``;
            otherwise the value is converted with ``int()``.
        guidance: CFG weight, forwarded to ``generate`` as ``cfg_weight``.
        t2i_temperature: Sampling temperature forwarded to ``generate``.
        progress: Gradio progress tracker, or ``None`` to disable reporting.

    Returns:
        list of PIL images produced by ``unpack``.

    Raises:
        gr.Error: for an empty prompt, or wrapping any failure during
            generation (the original exception is chained as ``__cause__``).
    """
    # Validate before entering the broad handler below. Otherwise this message
    # would be caught, logged as a failure and re-wrapped as
    # "Image generation failed: ...". A None prompt is also rejected here
    # instead of surfacing as an AttributeError.
    if prompt is None or not prompt.strip():
        raise gr.Error("Please enter a valid prompt.")

    try:
        if progress is not None:
            progress(0, desc="Initializing...")
        torch.cuda.empty_cache()

        # Seed management: honour a user-supplied seed, otherwise draw one.
        if seed is None:
            seed = torch.seed()
        else:
            seed = int(seed)

        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed(seed)

        messages = [{'role': '<|User|>', 'content': prompt}, {'role': '<|Assistant|>', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=''
        ) + vl_chat_processor.image_start_tag

        input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)

        if progress is not None:
            progress(0.1, desc="Generating image tokens...")

        def _token_progress(fraction, desc=None):
            # generate() reports its own 0..1 fraction. Squeeze it into the
            # 0.1..0.9 band so the overall bar never jumps backwards.
            progress(0.1 + 0.8 * fraction, desc=desc)

        generated_tokens = generate(
            input_ids,
            DEFAULT_WIDTH,
            DEFAULT_HEIGHT,
            cfg_weight=guidance,
            temperature=t2i_temperature,
            parallel_size=PARALLEL_SIZE,
            progress=_token_progress if progress is not None else None
        )

        if progress is not None:
            progress(0.9, desc="Processing images...")
        # NOTE(review): shape assumed to be [batch, latent_channels, rows, cols]
        # with tokens in raster order, so rows come from the height — confirm
        # against gen_vision_model.decode_code. Identical for square images.
        patches = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[PARALLEL_SIZE, 8, DEFAULT_HEIGHT // PATCH_SIZE, DEFAULT_WIDTH // PATCH_SIZE]
        )

        return unpack(patches, DEFAULT_WIDTH, DEFAULT_HEIGHT, PARALLEL_SIZE)

    except Exception as e:
        logger.error(f"Generation failed: {str(e)}", exc_info=True)
        if "index out of range" in str(e).lower():
            raise gr.Error("Image generation failed due to internal error. Please try again with different parameters.") from e
        raise gr.Error(f"Image generation failed: {str(e)}") from e

def create_interface():
    """Build the Gradio Blocks UI for text-to-image generation.

    Behaviour of the UI:
    - The "Generate Images" button calls ``generate_image`` with the prompt,
      seed, guidance and temperature inputs and shows the result in the
      gallery.
    - The status box is filled on page load with the device chosen at startup.

    The "Model Information" panel is rendered from ``DEFAULT_WIDTH``,
    ``DEFAULT_HEIGHT`` and ``PARALLEL_SIZE``, so it cannot drift from what
    ``generate_image`` actually produces.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Text-to-Image Generation with Janus-Pro-7B
        **Generate high-quality images from text prompts using DeepSeek's advanced multimodal AI model.**
        """)

        with gr.Row():
            # Left column: prompt, action button and optional settings.
            with gr.Column(scale=3):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3
                )
                generate_btn = gr.Button("Generate Images", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        seed_input = gr.Number(
                            label="Seed",
                            value=None,
                            precision=0,
                            info="Leave empty for random seed"
                        )
                        guidance_slider = gr.Slider(
                            label="CFG Guidance Weight",
                            minimum=3,
                            maximum=10,
                            value=5,
                            step=0.5,
                            info="Higher values = more prompt adherence, lower values = more creativity"
                        )
                        temp_slider = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                            info="Higher values = more randomness, lower values = more deterministic"
                        )

            # Right column: results and device status.
            with gr.Column(scale=2):
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=2,
                    height=600,
                    preview=True
                )
                status = gr.Textbox(
                    label="Status",
                    interactive=False
                )

        gr.Examples(
            examples=[
                ["A futuristic cityscape at sunset with flying cars and holographic advertisements"],
                ["An astronaut riding a horse in photorealistic style"],
                ["A cute robotic cat sitting on a stack of ancient books, digital art"]
            ],
            inputs=prompt_input
        )

        # Derived from the module constants so the description stays accurate.
        gr.Markdown(f"""
        ## Model Information
        - **Model:** [Janus-Pro-7B](https://huggingface.co/deepseek-ai/Janus-Pro-7B)
        - **Output Resolution:** {DEFAULT_WIDTH}x{DEFAULT_HEIGHT} pixels
        - **Parallel Generation:** {PARALLEL_SIZE} images per request
        """)

        # Footer Section
        gr.Markdown("""
        <hr style="margin-top: 2em; margin-bottom: 1em;">
        <div style="text-align: center; color: #666; font-size: 0.9em;">
            Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #2563eb; text-decoration: none;">bilsimaging.com</a>
        </div>
        """)

        # Visitor Badge
        gr.HTML("""
        <div style="text-align: center; margin-top: 1em;">
            <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F">
                <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759" 
                     alt="Visitor Badge"
                     style="display: inline-block; margin: 0 auto;">
            </a>
        </div>
        """)

        generate_btn.click(
            generate_image,
            inputs=[prompt_input, seed_input, guidance_slider, temp_slider],
            outputs=output_gallery,
            api_name="generate"
        )

        # Report the device chosen at startup; queue=False so it is shown
        # without waiting behind generation jobs.
        demo.load(
            fn=lambda: f"Device Status: {'GPU ✅' if device.type == 'cuda' else 'CPU ⚠️'}",
            outputs=status,
            queue=False
        )

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)