import gradio as gr
from PIL import Image
import os
import time
import numpy as np
import torch
import warnings
import stat
import subprocess
import sys

# Set environment variables
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Print system information
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available via PyTorch: {torch.cuda.is_available()}")
print(f"CUDA version via PyTorch: {torch.version.cuda if torch.cuda.is_available() else 'Not available'}")

# Try to run nvidia-smi
def run_nvidia_smi():
    try:
        result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if result.returncode == 0:
            print("nvidia-smi output:")
            print(result.stdout)
            return True
        else:
            print("nvidia-smi error:")
            print(result.stderr)
            return False
    except Exception as e:
        print(f"Error running nvidia-smi: {str(e)}")
        return False

# Run nvidia-smi
nvidia_smi_available = run_nvidia_smi()
print(f"nvidia-smi available: {nvidia_smi_available}")

# Show CUDA devices
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")

# Ensure all cache directories exist with proper permissions
def setup_cache_directories():
    # Gradio cache directory
    cache_dir = os.path.join(os.getcwd(), "gradio_cached_examples")
    os.makedirs(cache_dir, exist_ok=True)
    
    # HuggingFace cache directories
    hf_cache = os.path.join(os.getcwd(), ".cache", "huggingface")
    transformers_cache = os.path.join(hf_cache, "transformers")
    os.makedirs(hf_cache, exist_ok=True)
    os.makedirs(transformers_cache, exist_ok=True)
    
    # Point Hugging Face libraries at these directories (only if not already configured)
    os.environ.setdefault("HF_HOME", hf_cache)
    os.environ.setdefault("TRANSFORMERS_CACHE", transformers_cache)
    
    # Set permissions
    try:
        for directory in [cache_dir, hf_cache, transformers_cache]:
            if os.path.exists(directory):
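                # World-writable so the cache stays usable even if the runtime user
                # differs from the directory owner (an assumption about container setups)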
                os.chmod(directory, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)  # 0o777
                print(f"Set permissions for {directory}")
    except Exception as e:
        print(f"Warning: Could not set permissions: {str(e)}")
    
    return cache_dir

# Set up cache directories
cache_dir = setup_cache_directories()

# Suppress specific warnings that might be caused by package version mismatches
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
warnings.filterwarnings("ignore", message=".*Torch is not compiled with CUDA enabled.*")
warnings.filterwarnings("ignore", category=UserWarning)

# Check for actual GPU availability
def check_gpu_availability():
    """Check if GPU is actually available and working"""
    print("Checking GPU availability...")
    
    if not torch.cuda.is_available():
        print("CUDA is not available in PyTorch")
        return False
    
    try:
        # Try to initialize CUDA and run a simple operation
        print("Attempting to create a tensor on CUDA...")
        x = torch.rand(10, device="cuda")
        y = x + x
        torch.cuda.synchronize()  # CUDA ops are asynchronous; force execution so errors surface here
        print("Successfully created and operated on CUDA tensor")
        return True
    except Exception as e:
        print(f"GPU initialization failed: {str(e)}")
        return False

# Global variables (model handles default to None so later checks never hit a NameError)
internvl2_model = None
blip_processor = None
blip_model = None
USE_GPU = check_gpu_availability()

if USE_GPU:
    print("GPU is available and working properly")
else:
    print("WARNING: GPU is not available or not working properly. This application requires GPU acceleration.")

# ALTERNATIVE MODEL: Let's try a simpler vision model as backup
try:
    from transformers import BlipProcessor, BlipForConditionalGeneration
    HAS_BLIP = True
    print("Successfully imported BLIP model")
except ImportError:
    HAS_BLIP = False
    print("BLIP model not available, will try InternVL2")

# Try importing lmdeploy for InternVL2
try:
    from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
    HAS_LMDEPLOY = True
    print("Successfully imported lmdeploy")
except ImportError as e:
    HAS_LMDEPLOY = False
    print(f"lmdeploy import failed: {str(e)}. Will try backup model.")

# Try to load the appropriate model
def load_model():
    global internvl2_model, blip_processor, blip_model
    
    if not USE_GPU:
        print("Cannot load models without GPU acceleration.")
        return False
    
    # Try to load BLIP first since it's more reliable
    if HAS_BLIP:
        try:
            print("Loading BLIP model...")
            blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda")
            print("BLIP model loaded successfully!")
        except Exception as e:
            print(f"Failed to load BLIP: {str(e)}")
            blip_processor = None
            blip_model = None
    
    # Then try InternVL2 if lmdeploy is available
    if HAS_LMDEPLOY:
        try:
            print("Attempting to load InternVL2 model...")
            # Configure for AWQ quantized model with larger context
            backend_config = TurbomindEngineConfig(
                model_format='awq',
                session_len=4096,  # Maximum context length in tokens
                max_batch_size=1,  # Limit batch size to reduce memory usage
                cache_max_entry_count=0.3,  # Fraction of free GPU memory reserved for the k/v cache
                tp=1  # Tensor parallelism degree (single GPU)
            )
            
            # Generation parameters (token limits, sampling) are supplied per request
            # via GenerationConfig when the pipeline is called; pipeline() itself only
            # needs the model id and the backend configuration
            internvl2_model = pipeline(
                "OpenGVLab/InternVL2-40B-AWQ",
                backend_config=backend_config
            )
            
            print("InternVL2 model loaded successfully!")
        except Exception as e:
            print(f"Failed to load InternVL2: {str(e)}")
            internvl2_model = None
    
    # Return True if at least one model is loaded
    return (blip_model is not None and blip_processor is not None) or (internvl2_model is not None)

# Try to load a model at startup
MODEL_LOADED = load_model()
WHICH_MODEL = "InternVL2" if internvl2_model is not None else "BLIP" if blip_model is not None else "None"

def analyze_image(image, prompt):
    """Analyze the image using available model"""
    if not MODEL_LOADED:
        return "No model could be loaded. Please check the logs for details."
    
    if not USE_GPU:
        return "ERROR: This application requires GPU acceleration. No GPU detected."
    
    try:
        # Convert image to right format if needed
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image).convert('RGB')
        else:
            pil_image = image.convert('RGB')
        
        # Try BLIP first since it's more reliable
        if blip_model is not None and blip_processor is not None:
            try:
                print("Running inference with BLIP...")
                # BLIP is used for plain captioning here, so the prompt is not passed to it
                inputs = blip_processor(pil_image, return_tensors="pt").to("cuda")
                out = blip_model.generate(**inputs, max_length=80, min_length=10, num_beams=5)
                result = blip_processor.decode(out[0], skip_special_tokens=True)
                
                # Check if BLIP result is empty
                if not result or result.strip() == "":
                    print("BLIP model returned an empty response")
                    # Only fall through to InternVL2 if BLIP fails
                    raise ValueError("Empty response from BLIP")
                
                return f"[BLIP] {result}"
            except Exception as e:
                print(f"Error with BLIP: {str(e)}")
                # If BLIP fails, try InternVL2 if available
        
        # Try InternVL2 if available
        if internvl2_model is not None:
            try:
                print("Running inference with InternVL2...")
                print(f"Using prompt: '{prompt}'")
                
                # Create a specifically formatted prompt for InternVL2
                formatted_prompt = f"<image>\n{prompt}"
                print(f"Formatted prompt: '{formatted_prompt}'")
                
                # Run the model; generation parameters are passed through
                # lmdeploy's GenerationConfig rather than as loose kwargs
                gen_config = GenerationConfig(
                    max_new_tokens=512,  # Upper bound on generated tokens
                    temperature=0.7,     # Sampling temperature
                    top_p=0.9            # Nucleus sampling
                )
                response = internvl2_model(
                    (formatted_prompt, pil_image),
                    gen_config=gen_config
                )
                
                # Print debug info about the response
                print(f"Response type: {type(response)}")
                print(f"Response attributes: {dir(response) if hasattr(response, '__dir__') else 'No dir available'}")
                
                # Try different ways to extract the text
                if hasattr(response, "text"):
                    result = response.text
                    print(f"Found 'text' attribute: '{result}'")
                elif hasattr(response, "response"):
                    result = response.response
                    print(f"Found 'response' attribute: '{result}'")
                elif hasattr(response, "generated_text"):
                    result = response.generated_text
                    print(f"Found 'generated_text' attribute: '{result}'")
                else:
                    # If no attribute worked, convert the whole response to string
                    result = str(response)
                    print(f"Using string conversion: '{result}'")
                
                # Check if we got an empty result
                if not result or result.strip() == "":
                    print("WARNING: Received empty response from InternVL2")
                    return "InternVL2 failed to analyze the image (empty response). This may be due to token limits."
                
                return f"[InternVL2] {result}"
                    
            except Exception as e:
                print(f"Error with InternVL2: {str(e)}")
                return f"Error with InternVL2: {str(e)}"
        
        return "No model was able to analyze the image. See logs for details."
        
    except Exception as e:
        print(f"Error in image analysis: {str(e)}")
        # Try to clean up memory in case of error
        if USE_GPU:
            torch.cuda.empty_cache()
        return f"Error in image analysis: {str(e)}"

def process_image(image, analysis_type="general"):
    """Process the image and return the analysis"""
    if image is None:
        return "Please upload an image."
    
    # Define prompt based on analysis type
    if analysis_type == "general":
        prompt = "Describe this image in detail."
    elif analysis_type == "text":
        prompt = "What text can you see in this image? Please transcribe it accurately."
    elif analysis_type == "chart":
        prompt = "Analyze any charts, graphs or diagrams in this image in detail, including trends, data points, and conclusions."
    elif analysis_type == "people":
        prompt = "Describe the people in this image - their appearance, actions, and expressions."
    elif analysis_type == "technical":
        prompt = "Provide a technical analysis of this image, including object identification, spatial relationships, and any technical elements present."
    else:
        prompt = "Describe this image in detail."
    
    start_time = time.time()
    
    # Get analysis from the model
    analysis = analyze_image(image, prompt)
    
    elapsed_time = time.time() - start_time
    return f"{analysis}\n\nAnalysis completed in {elapsed_time:.2f} seconds."

# Define the Gradio interface
def create_interface():
    with gr.Blocks(title=f"Image Analysis with {WHICH_MODEL}") as demo:
        gr.Markdown(f"# Image Analysis with {WHICH_MODEL}")
        
        # System diagnostics
        system_info = f"""
        ## System Diagnostics:
        - Model Used: {WHICH_MODEL}
        - Model Loaded: {MODEL_LOADED}
        - PyTorch Version: {torch.__version__}
        - CUDA Available: {torch.cuda.is_available()}
        - GPU Working: {USE_GPU}
        - nvidia-smi Available: {nvidia_smi_available}
        """
        
        gr.Markdown(system_info)
        gr.Markdown(f"Upload an image to analyze it using the {WHICH_MODEL} model.")
        
        # Show warnings based on system status
        if not MODEL_LOADED:
            gr.Markdown("⚠️ **WARNING**: No model could be loaded. This demo will not function correctly.", elem_classes=["warning-message"])
        
        if not USE_GPU:
            gr.Markdown("🚫 **ERROR**: NVIDIA GPU not detected. This application requires GPU acceleration.", elem_classes=["error-message"])
        
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                analysis_type = gr.Radio(
                    ["general", "text", "chart", "people", "technical"],
                    label="Analysis Type",
                    value="general"
                )
                # Disable the button up front if no GPU or no model is available;
                # interactive is set at construction so the state is rendered correctly
                submit_btn = gr.Button(
                    "Analyze Image",
                    interactive=USE_GPU and MODEL_LOADED
                )
            
            with gr.Column(scale=2):
                # Compute the initial textbox value up front; setting .value after
                # construction is not reliably reflected in the rendered component
                if not USE_GPU:
                    initial_message = f"""ERROR: NVIDIA GPU driver not detected. This application requires GPU acceleration.

Diagnostics:
- Model Used: {WHICH_MODEL}
- PyTorch Version: {torch.__version__}
- CUDA Available via PyTorch: {torch.cuda.is_available()}
- nvidia-smi Available: {nvidia_smi_available}
- GPU Working: {USE_GPU}

Please ensure this Space is using a GPU-enabled instance and that the GPU is correctly initialized."""
                elif not MODEL_LOADED:
                    initial_message = f"""ERROR: No model could be loaded.

Diagnostics:
- Model Used: {WHICH_MODEL}
- PyTorch Version: {torch.__version__}
- CUDA Available via PyTorch: {torch.cuda.is_available()}
- nvidia-smi Available: {nvidia_smi_available}
- GPU Working: {USE_GPU}

Please check the logs for more details."""
                else:
                    initial_message = ""
                
                output_text = gr.Textbox(label="Analysis Result", lines=20, value=initial_message)
        
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, analysis_type],
            outputs=output_text
        )
        
        gr.Markdown("""
        ## Analysis Types
        - **General**: General description of the image
        - **Text**: Focus on identifying and transcribing text in the image
        - **Chart**: Detailed analysis of charts, graphs, and diagrams
        - **People**: Description of people, their appearance and actions
        - **Technical**: Technical analysis identifying objects and spatial relationships
        """)
        
        # Hardware requirements notice
        gr.Markdown("""
        ## System Requirements
        This application requires:
        - NVIDIA GPU with CUDA support
        - At least 16GB of GPU memory recommended
        - GPU drivers properly installed and configured
        
        If you're running this on Hugging Face Spaces, make sure to select a GPU-enabled hardware type.
        """)
    
    return demo

# Main function
if __name__ == "__main__":
    # Create the Gradio interface
    demo = create_interface()
    
    # Launch the interface
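    # server_name="0.0.0.0" exposes the UI on all interfaces (needed in containers /
    # Hugging Face Spaces); with no server_port given, Gradio's default port 7860 is used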
    demo.launch(share=False, server_name="0.0.0.0")