from fastapi import APIRouter, HTTPException
from typing import Dict, Any
import threading
import time
from tasks.create_bench_config_file import CreateBenchConfigTask
from tasks.create_bench import CreateBenchTask

router = APIRouter(tags=["benchmark"])

# Active benchmark tasks keyed by session_id (imported by main.py)
active_tasks = {}

# Registry of uploaded files keyed by session_id.
# The actual mapping is expected to be attached to the router by main.py and is
# accessed below as router.session_files; this module-level dict is only a placeholder.
session_files = {}
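
# Illustrative wiring sketch for main.py (module and variable names below are
# assumptions, not taken from the actual main.py):
#
#   from fastapi import FastAPI
#   import benchmark  # this module
#
#   app = FastAPI()
#   benchmark.router.session_files = {}   # shared upload registry
#   app.include_router(benchmark.router)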

@router.post("/generate-benchmark")
async def generate_benchmark(data: Dict[str, Any]):
    """
    Generate a benchmark configuration and run the ingestion process
    
    Args:
        data: Dictionary containing session_id
        
    Returns:
        Dictionary with logs and status
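
    Example (illustrative; the session id and log values are placeholders):
        Request body:  {"session_id": "abc123"}
        Response body: {"status": "running", "logs": ["[INFO] ..."]}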
    """
    session_id = data.get("session_id")
    
    # Debug output: verify the received session_id against the known session files
    print(f"DEBUG: Session ID received: {session_id}")
    print(f"DEBUG: Available session files: {list(router.session_files.keys())}")
    
    if not session_id or session_id not in router.session_files:
        return {"error": "Invalid or missing session ID"}
    
    # Check if a benchmark is already in progress or completed for this session
    if session_id in active_tasks:
        task = active_tasks[session_id]
        # If the benchmark is already completed, return existing logs
        if task.is_task_completed():
            return {
                "status": "already_completed",
                "logs": task.get_logs(),
                "is_completed": True
            }
        # If the benchmark is running, return current logs
        else:
            return {
                "status": "already_running",
                "logs": task.get_logs(),
                "is_completed": False
            }
    
    file_path = router.session_files[session_id]
    all_logs = []
    
    try:
        # Initialize the task that will handle the entire process
        task = UnifiedBenchmarkTask(session_uid=session_id)
        
        # Keep the task so its logs can be retrieved later
        active_tasks[session_id] = task
        
        # Start the benchmark process
        task.run(file_path)
        
        # Get initial logs
        all_logs = task.get_logs()
        
        return {
            "status": "running",
            "logs": all_logs
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
            "logs": all_logs
        }

@router.get("/benchmark-progress/{session_id}")
async def get_benchmark_progress(session_id: str):
    """
    Get the logs and status for a running benchmark task
    
    Args:
        session_id: Session ID for the task
        
    Returns:
        Dictionary with logs and completion status
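
    Example (illustrative; the session id and log values are placeholders):
        GET /benchmark-progress/abc123
        -> {"logs": ["[INFO] Initializing benchmark task"], "is_completed": false}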
    """
    if session_id not in active_tasks:
        raise HTTPException(status_code=404, detail="Benchmark task not found")
    
    task = active_tasks[session_id]
    logs = task.get_logs()
    is_completed = task.is_task_completed()
    
    return {
        "logs": logs,
        "is_completed": is_completed
    }

# Task class that unifies the benchmark process
class UnifiedBenchmarkTask:
    """
    Task that handles the entire benchmark process from configuration to completion
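
    Example (illustrative; the session id and file path are placeholders):
        task = UnifiedBenchmarkTask(session_uid="abc123")
        task.run("/path/to/uploaded_file")  # returns immediately, runs in a background thread
        logs = task.get_logs()              # poll for progress
        done = task.is_task_completed()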
    """
    
    def __init__(self, session_uid: str):
        """
        Initialize the unified benchmark task
        
        Args:
            session_uid: Session ID for this task
        """
        self.session_uid = session_uid
        self.logs = []
        self.is_completed = False
        self.config_task = None
        self.bench_task = None
        
        self._add_log("[INFO] Initializing benchmark task")
    
    def _add_log(self, message: str):
        """
        Add a log message
        
        Args:
            message: Log message to add
        """
        if message not in self.logs:  # Avoid duplicates
            self.logs.append(message)
            # Force a copy to avoid reference problems
            self.logs = self.logs.copy()
            print(f"[{self.session_uid}] {message}")
    
    def get_logs(self):
        """
        Get all logs
        
        Returns:
            List of log messages
        """
        return self.logs.copy()
    
    def is_task_completed(self):
        """
        Check if the task is completed
        
        Returns:
            True if completed, False otherwise
        """
        return self.is_completed
    
    def run(self, file_path: str):
        """
        Run the benchmark process
        
        Args:
            file_path: Path to the uploaded file
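
        Note: this call returns immediately; the work runs in a daemon thread
        and progress is exposed through get_logs() and is_task_completed().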
        """
        # Start in a separate thread to avoid blocking
        thread = threading.Thread(target=self._run_process, args=(file_path,))
        thread.daemon = True
        thread.start()
    
    def _run_process(self, file_path: str):
        """
        Internal method to run the process
        
        Args:
            file_path: Path to the uploaded file
        """
        try:
            # Step 1: Configuration
            self._add_log("[INFO] Starting configuration process")
            # Import and use DEFAULT_BENCHMARK_TIMEOUT
            from config.models_config import DEFAULT_BENCHMARK_TIMEOUT
            self.config_task = CreateBenchConfigTask(session_uid=self.session_uid, timeout=DEFAULT_BENCHMARK_TIMEOUT)
            
            # Execute the configuration task; configuration failures are handled here
            # so they are reported distinctly from benchmark failures
            try:
                config_path = self.config_task.run(file_path=file_path)
            except Exception as config_error:
                error_msg = str(config_error)
                # Log detailed error
                self._add_log(f"[ERROR] Configuration failed: {error_msg}")

                # Check if it's a provider error and provide a more user-friendly message
                if "Required models not available" in error_msg:
                    self._add_log("[ERROR] Some required models are not available at the moment. Please try again later.")

                # Mark as completed with error and stop here
                self.is_completed = True
                return

            # Forward the configuration logs
            config_logs = self.config_task.get_logs()
            for log in config_logs:
                self._add_log(log)

            # Mark the configuration step as completed unless the config task already did so
            if "[SUCCESS] Stage completed: config_generation" not in self.logs:
                self._add_log("[SUCCESS] Stage completed: configuration")

            # Step 2: Benchmark
            self._add_log("[INFO] Starting benchmark process")
            self.bench_task = CreateBenchTask(session_uid=self.session_uid, config_path=config_path)

            # Run the benchmark task
            self.bench_task.run()

            # Poll until the benchmark task completes, forwarding its logs as they appear
            while not self.bench_task.is_task_completed():
                bench_logs = self.bench_task.get_logs()
                for log in bench_logs:
                    self._add_log(log)
                time.sleep(1)

            # Forward the final logs
            final_logs = self.bench_task.get_logs()
            for log in final_logs:
                self._add_log(log)

            # Mark as completed
            self.is_completed = True

            # Check whether a serious error was reported in the benchmark logs,
            # ignoring JSON parsing errors that should not block the process
            json_error_markers = ("JSONDecodeError", "Error processing QA pair", "'str' object has no attribute 'get'")
            has_error = any(
                "[ERROR]" in log and not any(marker in log for marker in json_error_markers)
                for log in final_logs
            )
            benchmark_terminated_with_error = any("Benchmark process terminated with error code" in log for log in final_logs)
            benchmark_already_marked_success = any("Benchmark process completed successfully" in log for log in final_logs)

            # JSON parsing errors alone do not make the benchmark fail
            json_errors_only = any(
                any(marker in log for marker in json_error_markers) for log in final_logs
            ) and not has_error

            if json_errors_only:
                self._add_log("[INFO] Benchmark completed with minor JSON parsing warnings, considered successful")

            # Only add a success message if no serious errors were detected
            # (and the benchmark task has not already logged one)
            if not benchmark_terminated_with_error and (
                json_errors_only or (not has_error and not benchmark_already_marked_success)
            ):
                self._add_log("[SUCCESS] Benchmark process completed successfully")
        
        except Exception as e:
            self._add_log(f"[ERROR] Benchmark process failed: {str(e)}")
            self.is_completed = True