#!/usr/bin/env python3
"""
GPU Diagnostics Tool for Hugging Face Spaces
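This script checks GPU availability and functionality (environment variables,
nvidia-smi, PyTorch/CUDA) and, if Gradio is installed, launches a web interface
that re-runs the check on demand.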
"""

import os
import sys
import subprocess
import time
import json

# Tool banner (a titled '=' rule), then the running interpreter's version.
print("=" * 80, "GPU DIAGNOSTICS TOOL", "=" * 80, sep="\n")

print("Python version:", sys.version)
print("-" * 80)

# Report environment variables that influence GPU visibility/allocation,
# plus the Hugging Face cache location. `gpu_related_vars` stays a
# module-level list because the web-interface check reuses it.
print("ENVIRONMENT VARIABLES:")
gpu_related_vars = [
    "CUDA_VISIBLE_DEVICES",
    "NVIDIA_VISIBLE_DEVICES",
    "PYTORCH_CUDA_ALLOC_CONF",
    "HF_HOME",
]
print("\n".join(f"{name}: {os.environ.get(name, 'Not set')}" for name in gpu_related_vars))
print("-" * 80)

# Check for nvidia-smi.
# The call is bounded by a timeout so an unresponsive driver cannot stall the
# whole diagnostics run (and with it the web interface started later).
print("CHECKING FOR NVIDIA-SMI:")
NVIDIA_SMI_TIMEOUT_S = 30
try:
    result = subprocess.run(
        ["nvidia-smi"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        timeout=NVIDIA_SMI_TIMEOUT_S,
    )
except FileNotFoundError:
    print("nvidia-smi not found on PATH")
except subprocess.TimeoutExpired:
    print(f"nvidia-smi did not respond within {NVIDIA_SMI_TIMEOUT_S} seconds")
except Exception as e:
    print(f"Error running nvidia-smi: {str(e)}")
else:
    if result.returncode == 0:
        print("nvidia-smi is available and working!")
        print(result.stdout)
    else:
        print(f"nvidia-smi error (exit code {result.returncode}):")
        # nvidia-smi may explain a failure on stdout rather than stderr,
        # so show whichever streams actually carry text.
        for stream in (result.stdout, result.stderr):
            if stream.strip():
                print(stream)
print("-" * 80)

def run_tensor_smoke_test(device):
    """Time creating a 1000x1000 random matrix and multiplying it by itself.

    Args:
        device: A torch device spec such as "cpu" or "cuda".

    Returns:
        Elapsed wall-clock seconds (float). This is a single first-call
        measurement, not a throughput benchmark.

    Raises:
        ImportError: if PyTorch is not installed.
        Any torch error raised while using *device* is propagated.
    """
    import torch

    start = time.perf_counter()
    x = torch.rand(1000, 1000, device=device)
    _ = x @ x  # matrix multiplication exercises real compute
    # CUDA kernels launch asynchronously, so wait for them before stopping the
    # clock. Only synchronize for CUDA devices: torch.cuda.synchronize() raises
    # when CUDA is unusable, which would turn the CPU fallback into an error.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start


# Check PyTorch and CUDA
print("CHECKING PYTORCH AND CUDA:")
try:
    import torch
except ImportError:
    print("PyTorch is not installed")
except Exception as e:
    # A broken install (e.g. native libraries that fail to load) can raise
    # something other than ImportError; report it and keep going so the
    # remaining diagnostics still run.
    print(f"PyTorch failed to import: {str(e)}")
else:
    try:
        cuda_available = torch.cuda.is_available()
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {cuda_available}")
        # torch.version.cuda is the CUDA version PyTorch was built against
        # (None for builds without CUDA). Showing it even when CUDA is
        # unavailable helps tell a non-CUDA build apart from a driver issue.
        print(f"CUDA version: {torch.version.cuda or 'None (PyTorch built without CUDA)'}")

        if cuda_available:
            device_count = torch.cuda.device_count()
            print(f"CUDA device count: {device_count}")
            for i in range(device_count):
                print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")

        # Try to create and operate on a tensor (GPU if possible, else CPU).
        print("\nTesting CUDA tensor creation:")
        try:
            elapsed = run_tensor_smoke_test("cuda" if cuda_available else "cpu")
        except Exception as e:
            print(f"Error in tensor creation/operation: {str(e)}")
        else:
            if cuda_available:
                print(f"Successfully created and operated on a CUDA tensor in {elapsed:.4f} seconds")
            else:
                print(f"Created and operated on a CPU tensor in {elapsed:.4f} seconds (CUDA not available)")

        # More detailed info for device 0.
        if cuda_available:
            print("\nDetailed CUDA information:")
            print(f"CUDA capability: {torch.cuda.get_device_capability(0)}")
            total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"Total GPU memory: {total_gb:.2f} GB")
            arch_list = torch.cuda.get_arch_list() if hasattr(torch.cuda, "get_arch_list") else "Not available"
            print(f"CUDA arch list: {arch_list}")
    except Exception as e:
        # A failing CUDA query must not abort the script before the web
        # interface section runs.
        print(f"Error while querying PyTorch/CUDA: {str(e)}")
print("-" * 80)

# Create a simple GPU test with a web interface
print("CREATING SIMPLE GPU TEST WEB INTERFACE...")


def check_gpu():
    """Run a GPU check and return the results as a JSON string.

    Collects the Python version and the environment variables listed in the
    module-level ``gpu_related_vars``. If PyTorch imports, it also records the
    PyTorch version and CUDA details. When CUDA is available, it times a
    1000x1000 matrix multiplication on the GPU.

    Returns:
        A JSON object pretty-printed with indent=2. ``gpu_test_passed`` is
        True only when the CUDA tensor test completed. The message of any
        exception raised while probing PyTorch/CUDA (including a failed
        ``import torch``) is stored under ``error``.
    """
    results = {
        "python_version": sys.version,
        "environment_vars": {var: os.environ.get(var, "Not set") for var in gpu_related_vars},
        "torch_available": False,
        "cuda_available": False,
    }

    try:
        import torch

        results["torch_available"] = True
        results["torch_version"] = torch.__version__
        cuda_ok = torch.cuda.is_available()
        results["cuda_available"] = cuda_ok

        if cuda_ok:
            device_count = torch.cuda.device_count()
            results["cuda_version"] = torch.version.cuda
            results["cuda_device_count"] = device_count
            results["cuda_device_name"] = torch.cuda.get_device_name(0)
            results["cuda_device_names"] = [torch.cuda.get_device_name(i) for i in range(device_count)]

            # Test tensor creation. Synchronize before reading the clock,
            # because CUDA kernels launch asynchronously.
            start_time = time.perf_counter()
            x = torch.rand(1000, 1000, device="cuda")
            _ = x @ x
            torch.cuda.synchronize()
            results["tensor_test_time"] = f"{time.perf_counter() - start_time:.4f} seconds"
            results["gpu_test_passed"] = True
        else:
            results["gpu_test_passed"] = False
    except Exception as e:
        results["error"] = str(e)
        results["gpu_test_passed"] = False

    return json.dumps(results, indent=2)


try:
    import gradio as gr
except ImportError:
    print("Gradio not installed, skipping web interface")
    print("Raw GPU diagnostics complete.")
else:
    # Only the import is guarded. An ImportError raised later from inside
    # Gradio (e.g. while building or launching the app) now surfaces as a
    # real error instead of being misreported as "Gradio not installed".
    demo = gr.Interface(
        fn=check_gpu,
        inputs=[],
        outputs="text",
        title="GPU Diagnostics",
        description="Click the button to run GPU diagnostics",
    )

    print("Starting Gradio web interface on port 7860...")
    # 0.0.0.0 listens on all network interfaces, so the server is reachable
    # from outside the local host rather than only via localhost.
    demo.launch(server_name="0.0.0.0")
print("-" * 80)