# Copyright (c) 2025 All rights reserved.

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont

# Import the text-to-image pipeline from diffusers
from diffusers import StableDiffusionPipeline

# Define default parameters
DEFAULT_SEED = 42
DEFAULT_STEPS = 30
DEFAULT_GUIDANCE_SCALE = 7.5
RED_BG_COLOR = "#ffcccc"  # Light red background

# Download the model weights
def download_model():
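    """Download the Stable Diffusion weights into a local directory."""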
    # Fetch the full repository (Stable Diffusion v1.5 is used here as an example)
    snapshot_download(
        repo_id='runwayml/stable-diffusion-v1-5',
        local_dir='./models/stable-diffusion'
    )

def init_pipeline():
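    """Load the text-to-image pipeline from the local snapshot and move it to the GPU."""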
    # Initialize a simple text-to-image pipeline
    pipeline = StableDiffusionPipeline.from_pretrained(
        "./models/stable-diffusion",
        torch_dtype=torch.float16,
        safety_checker=None
    )
    pipeline = pipeline.to("cuda")
    return pipeline

# Generate image function
def generate_image(prompt, seed, num_steps, guidance_scale):
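    """Run the pipeline on the prompt and return the watermarked result as a Gradio update."""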
    try:
        # Gradio's Number component returns a float; manual_seed expects an int
        seed = int(seed)

        # Use a random seed when the user passes 0
        if seed == 0:
            seed = torch.seed() & 0xFFFFFFFF

        # Set up a generator for reproducibility
        generator = torch.Generator("cuda").manual_seed(seed)
        
        # Generate the image
        image = pipeline(
            prompt=prompt,
            num_inference_steps=int(num_steps),  # slider values may arrive as floats
            guidance_scale=guidance_scale,
            generator=generator
        ).images[0]
        
        # Add watermark
        image = add_safety_watermark(image)
        
    except Exception as e:
        print(f"Error generating image: {e}")
        return gr.update()
    
    return gr.update(value=image, label=f"Generated Image, seed = {seed}")

# Add watermark to image
def add_safety_watermark(image, text='AI Generated'):
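    """Stamp a small 'AI Generated' label in the bottom-right corner of the image."""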
    width, height = image.size
    draw = ImageDraw.Draw(image)

    # Scale the font with image height; fall back to the default bitmap font
    # on Pillow versions that do not accept a size argument
    font_size = int(height * 0.028)
    try:
        font = ImageFont.load_default(size=font_size)
    except TypeError:
        font = ImageFont.load_default()

    # Measure the text and anchor it in the bottom-right corner
    text_width = draw.textlength(text, font=font)
    x = width - text_width - 10
    y = height - font_size - 20

    # Draw a dark shadow first, then the white text on top of it
    draw.text((x + 2, y + 2), text, font=font, fill="black")
    draw.text((x, y), text, font=font, fill="white")

    return image

# Create example function
def generate_example(prompt, seed):
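    """Wrapper used by gr.Examples: generate with the default step count and guidance scale."""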
    return generate_image(prompt, seed, DEFAULT_STEPS, DEFAULT_GUIDANCE_SCALE)

# Sample examples
sample_list = [
    ['A majestic mountain landscape with snow peaks and pine trees', 123],
    ['A futuristic city with flying cars and tall skyscrapers', 456],
    ['A serene beach scene with clear blue waters', 789],
]

# Create the Gradio interface
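# Custom CSS gives the whole app a light red background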
with gr.Blocks(css=f".gradio-container {{ background-color: {RED_BG_COLOR} !important; }}") as demo:
    gr.HTML("""
    <div style="text-align: center; max-width: 800px; margin: 0 auto;">
        <h1 style="font-size: 2rem; font-weight: 700;">Simple Text to Image Generator</h1>
        <h2 style="font-size: 1.2rem; font-weight: 300; margin-bottom: 1rem;">Convert your text descriptions into images</h2>
    </div>
    """)
    
    with gr.Row():
        with gr.Column(scale=2):
            # Input components
            ui_prompt_text = gr.Textbox(label="Text Prompt", value="A beautiful landscape with mountains and trees")
            ui_seed = gr.Number(label="Seed (0 for random)", value=DEFAULT_SEED)
            ui_steps = gr.Slider(minimum=10, maximum=50, value=DEFAULT_STEPS, step=1, label="Number of Steps")
            ui_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, value=DEFAULT_GUIDANCE_SCALE, step=0.5, label="Guidance Scale")
            
            ui_btn_generate = gr.Button("Generate Image")
        
        with gr.Column(scale=3):
            # Output components
            image_output = gr.Image(label="Generated Image", interactive=False, height=512)
    
    gr.Examples(
        sample_list,
        inputs=[ui_prompt_text, ui_seed],
        outputs=[image_output],
        fn=generate_example,
        # The pipeline is only initialized after the UI is built, so the
        # examples cannot be pre-rendered at startup; run them on demand
        cache_examples=False
    )
    
    ui_btn_generate.click(
        generate_image,
        inputs=[ui_prompt_text, ui_seed, ui_steps, ui_guidance_scale],
        outputs=[image_output]
    )
    
    gr.Markdown(
        """
        ### How to Use:
        1. Enter a detailed text description of the image you want to create
        2. Adjust the parameters if needed (or leave as default)
        3. Click "Generate Image" and wait for the result
        
        ### Tips:
        - Detailed prompts work better than short ones
        - Try different seeds for different variations
        - Higher guidance scale values make the image follow the prompt more closely
        """
    )

# Initialize and launch
print("Downloading models...")
download_model()

print("Initializing pipeline...")
pipeline = init_pipeline()

print("Launching Gradio interface...")
demo.launch()