import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image

# Print system information
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Load a smaller model that should work even with limited resources
model_id = "Salesforce/blip-image-captioning-base"  # ~1 GB model, very reliable
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Globals for the lazily loaded model and processor
processor = None
model = None

def load_model():
    global processor, model
    try:
        print("Loading model and processor...")
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForVision2Seq.from_pretrained(model_id).to(device)
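        # Hedged option for CUDA deployments: loading in float16 roughly
        # halves GPU memory during inference (BLIP generally tolerates half
        # precision for captioning):
        #   model = AutoModelForVision2Seq.from_pretrained(
        #       model_id, torch_dtype=torch.float16
        #   ).to(device)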
        print("Model loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

def analyze_image(image):
    global processor, model
    # Guard against an empty submission (Gradio passes None when no image
    # has been uploaded)
    if image is None:
        return "Please upload an image first."
    # If the model is not loaded yet, try to load it
    if model is None:
        success = load_model()
        if not success:
            return "Failed to load model. Check logs for details."
    
    try:
        if isinstance(image, str):
            # Filepath input
            image = Image.open(image)
        elif not isinstance(image, Image.Image):
            # Numpy array (e.g., from a Gradio component with type="numpy")
            image = Image.fromarray(image)
        # Normalize every input path to 3-channel RGB
        image = image.convert('RGB')
            
        # Preprocess: the processor resizes and normalizes the image into a
        # BatchEncoding of tensors; .to(device) moves them to the chosen device
        inputs = processor(images=image, return_tensors="pt").to(device)
        
        # Generate the caption (greedy decoding by default); max_new_tokens
        # bounds the length of the generated text in tokens
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=100)
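            # Hedged variant: beam search often yields slightly better
            # captions at the cost of extra latency:
            #   output = model.generate(**inputs, max_new_tokens=100, num_beams=3)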
        
        # Decode caption
        caption = processor.decode(output[0], skip_special_tokens=True)
        
        # Report the device and current GPU memory usage alongside the caption
        if device == "cuda":
            memory_mb = torch.cuda.memory_allocated() / 1024**2
            return f"Caption: {caption}\n\nUsing device: {device} ({torch.cuda.get_device_name(0)})\nGPU memory used: {memory_mb:.2f} MB"
        else:
            return f"Caption: {caption}\n\nUsing device: {device}"
            
    except Exception as e:
        print(f"Error during inference: {e}")
        return f"Error during inference: {str(e)}"
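
# Quick smoke test outside Gradio; "test.jpg" is an illustrative filename,
# assumed to exist locally:
#   print(analyze_image("test.jpg"))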

# Create Gradio interface
with gr.Blocks(title="Simple GPU Test") as demo:
    gr.Markdown("# Simple GPU Test with BLIP Image Captioning")
    
    with gr.Row():
        with gr.Column():
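            # type="pil" delivers a PIL.Image to the callback (or None when
            # the user submits without uploading)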
            image_input = gr.Image(type="pil", label="Upload an image")
            submit_btn = gr.Button("Generate Caption")
            
            # Show if GPU is available
            if torch.cuda.is_available():
                gr.Markdown(f"✅ **GPU detected**: {torch.cuda.get_device_name(0)}")
            else:
                gr.Markdown("❌ **No GPU detected**. Running on CPU.")
                
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=5)
    
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input],
        outputs=[output_text]
    )
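
# Optional: preload the model at startup so the first request does not pay
# the load cost; analyze_image still lazy-loads if this call is omitted:
# load_model()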

# Launch the app
if __name__ == "__main__":
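    # server_name="0.0.0.0" binds all interfaces so the app is reachable from
    # outside a container (e.g., Docker); Gradio's default binds localhost only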
    demo.launch(server_name="0.0.0.0")