"""
Task to run evaluation using lighteval
"""
import os
import time
import subprocess
import tempfile
from pathlib import Path
import concurrent.futures
from dotenv import load_dotenv
from datetime import datetime
import json
import shutil
from typing import List, Dict
from tasks.get_available_model_provider import get_available_model_provider
from huggingface_hub import HfApi
import asyncio
from datasets import load_dataset
from config.models_config import DEFAULT_EVALUATION_MODELS, DEFAULT_EVALUATION_TIMEOUT

class EvaluationTask:
    """
    Task to run evaluation using lighteval

    For every model in DEFAULT_EVALUATION_MODELS that has an available
    inference provider, runs the ``lighteval`` CLI on the
    ``single_shot_questions`` subset of ``dataset_name``, keeps the successful
    runs and uploads them to the dataset repository on the Hub as
    ``lighteval_results.json``. Progress is exposed through ``get_progress``
    and ``is_task_completed``.
    """

    def __init__(self, session_uid: str, dataset_name: str, clean_old_results: bool = False, timeout: float = None):
        """
        Initialize the evaluation task

        Args:
            session_uid: Session ID for this task; also names the
                ``uploaded_files/<session_uid>`` folder that receives results
            dataset_name: Hub repo id of the dataset to evaluate (results are
                uploaded to the same repo)
            clean_old_results: If True, clean old results before evaluation
            timeout: Timeout in seconds for each model evaluation (if None,
                DEFAULT_EVALUATION_TIMEOUT is used)
        """
        self.session_uid = session_uid
        self.dataset_name = dataset_name
        self.is_completed = False
        # Result dicts of the successful evaluations, filled in by run()
        self.results = []
        self.hf_api = HfApi()
        self.timeout = timeout if timeout is not None else DEFAULT_EVALUATION_TIMEOUT
        self.current_step = "initializing"
        self.completed_steps = []
        # Start time of the current step; update_step uses it to enforce a 1s minimum per step
        self.step_start_time = time.time()

        if clean_old_results:
            self.clean_old_results()

    @staticmethod
    def _log_message(message: str) -> None:
        """Print ``message`` prefixed with the current wall-clock time as ``[HH:MM:SS]``."""
        print(f"[{datetime.now().strftime('%H:%M:%S')}] {message}")

    async def update_step(self, step: str) -> None:
        """
        Update the current step and completed steps with a minimum delay of 1 second

        If the previous step started less than one second ago, sleeps for the
        remainder before switching. A step is appended to ``completed_steps``
        only the first time it is seen.

        Args:
            step: Name of the step to update
        """
        elapsed_since_step_start = time.time() - self.step_start_time
        if elapsed_since_step_start < 1.0:
            await asyncio.sleep(1.0 - elapsed_since_step_start)

        self.current_step = step
        self.step_start_time = time.time()

        if step not in self.completed_steps:
            self.completed_steps.append(step)

        self._log_message(f"Step changed to: {step}")

    def get_progress(self) -> Dict:
        """
        Get the current progress of the task

        Returns:
            Dictionary containing current step and completed steps
        """
        return {
            "current_step": self.current_step,
            "completed_steps": self.completed_steps
        }

    def clean_old_results(self) -> None:
        """
        Clean old evaluation results to avoid confusion

        Deletes ``uploaded_files/<session_uid>/lighteval_results`` (errors
        propagate) and, best effort, ``data/lighteval_results``. Both paths
        are relative to the current working directory.
        """
        self._log_message("Checking and cleaning old results...")

        results_dir = Path(f"uploaded_files/{self.session_uid}/lighteval_results")
        if results_dir.exists():
            self._log_message("Deleting old LightEval results")
            shutil.rmtree(results_dir)
            self._log_message("Cleaning complete")
        else:
            self._log_message("No old results found")

        # Intermediate lighteval results: ignore_errors=True already makes this best effort,
        # so no try/except is needed around it.
        intermediate_dir = "data/lighteval_results"
        if os.path.exists(intermediate_dir):
            self._log_message("Cleaning intermediate results")
            shutil.rmtree(intermediate_dir, ignore_errors=True)

    @classmethod
    def _remove_file(cls, path: str) -> None:
        """Delete ``path``; a missing file is ignored and other OS errors are logged (best-effort cleanup)."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            cls._log_message(f"Could not remove temporary file {path}: {str(e)}")

    def _save_results_to_hub(self) -> None:
        """
        Save evaluation results directly to the dataset on the Hub without persisting locally

        Results are sorted by accuracy (highest first), wrapped with metadata,
        written to a temporary JSON file and uploaded as
        ``lighteval_results.json``. Failures are logged, not raised. The
        temporary file is removed in every case.
        """
        temp_file_path = None
        try:
            sorted_results = sorted(self.results, key=lambda x: x.get('accuracy', 0), reverse=True)
            final_results = {
                "metadata": {
                    "evaluation_date": datetime.now().isoformat(),
                    "session_id": self.session_uid,
                    "dataset_name": self.dataset_name
                },
                "results": sorted_results
            }

            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as temp_file:
                # Remember the path first so the finally clause can clean up even if dump fails
                temp_file_path = temp_file.name
                json.dump(final_results, temp_file, indent=2)

            self.hf_api.upload_file(
                path_or_fileobj=temp_file_path,
                path_in_repo="lighteval_results.json",
                repo_id=self.dataset_name,
                repo_type="dataset",
                commit_message="Add lighteval evaluation results"
            )

            self._log_message(f"Results saved to Hub at {self.dataset_name}/lighteval_results.json")
        except Exception as e:
            self._log_message(f"Failed to save results to Hub: {str(e)}")
        finally:
            if temp_file_path is not None:
                self._remove_file(temp_file_path)

    def _write_task_file(self) -> str:
        """
        Write the lighteval custom-task module for this dataset to a new temporary file.

        Returns:
            Path of the generated ``.py`` file; the caller must delete it.
        """
        # NamedTemporaryFile(delete=False) creates the file atomically, unlike the deprecated,
        # race-prone tempfile.mktemp. repr() keeps the dataset name a valid Python string
        # literal in the generated module even if it contains quotes or backslashes.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as task_file:
            task_file.write(f"""
from lighteval_task.lighteval_task import create_yourbench_task

# Create yourbench task
yourbench = create_yourbench_task({self.dataset_name!r}, "single_shot_questions")

# Define TASKS_TABLE needed by lighteval
TASKS_TABLE = [yourbench]
""")
            return task_file.name

    @staticmethod
    def _build_command(model_name: str, provider: str, task_file_path: str, output_dir: str) -> List[str]:
        """Return the argv for ``lighteval endpoint inference-providers`` (30 samples, details saved, no Hub push)."""
        org_to_bill = os.getenv('HF_ORGANIZATION', 'yourbench')
        return [
            "lighteval",
            "endpoint",
            "inference-providers",
            f"model_name={model_name},provider={provider},org_to_bill={org_to_bill}",
            "custom|yourbench|0|0",
            "--custom-tasks",
            task_file_path,
            "--max-samples", "30",
            "--output-dir", output_dir,
            "--save-details",
            "--no-push-to-hub"
        ]

    @staticmethod
    def _make_result(model_name: str, provider: str, accuracy: float, execution_time: float, status: str) -> Dict:
        """Build one per-model result record in the shape stored in ``self.results``."""
        return {
            "model": model_name,
            "provider": provider,
            "accuracy": accuracy,
            "execution_time": execution_time,
            "status": status
        }

    @classmethod
    def _print_stream(cls, model_name: str, label: str, data: bytes) -> None:
        """Echo one captured lighteval output stream line by line, each line prefixed with ``[label]``."""
        if not data:
            return
        # errors="replace": stray non-UTF-8 bytes in the tool output must not turn a
        # finished run into an "error" result
        text = data.decode("utf-8", errors="replace")
        cls._log_message(f"LightEval {label} for {model_name}:")
        for line in text.splitlines():
            print(f"[{label}] {line}")

    @staticmethod
    async def _terminate_process(process) -> None:
        """Kill a timed-out lighteval process and wait for it, so it is reaped rather than left as a zombie."""
        try:
            process.kill()
        except ProcessLookupError:
            # The process exited on its own between the timeout and the kill
            pass
        await process.wait()

    async def _execute_lighteval(self, cmd_args: List[str], model_name: str, provider: str, start_time: float) -> "Dict | None":
        """
        Run the lighteval CLI with a timeout and echo its output.

        Returns:
            None when the process ran to completion, whatever its return code.
            Otherwise a result dict with status "timeout" (exceeded
            ``self.timeout``) or "error" (the process could not be run).
        """
        try:
            process = await asyncio.create_subprocess_exec(
                *cmd_args,
                env=os.environ,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            self._log_message(f"Running command: {' '.join(cmd_args)}")
            try:
                stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=self.timeout)
            except asyncio.TimeoutError:
                await self._terminate_process(process)
                self._log_message(f"Evaluation timed out for {model_name} after {time.time() - start_time:.2f}s")
                return self._make_result(model_name, provider, 0.0, self.timeout, "timeout")

            self._print_stream(model_name, "STDOUT", stdout)
            self._print_stream(model_name, "STDERR", stderr)

            # A non-zero exit is only logged; the caller still attempts to parse the results
            if process.returncode != 0:
                self._log_message(f"LightEval failed with return code {process.returncode}")
        except Exception as e:
            self._log_message(f"Error running evaluation for {model_name}: {str(e)}")
            return self._make_result(model_name, provider, 0.0, time.time() - start_time, "error")
        return None

    def _read_accuracy(self, output_dir: str, model_name: str):
        """
        Extract ``results["all"]["accuracy"]`` from lighteval's results file for ``model_name``.

        Raises:
            FileNotFoundError: the results directory or any ``results_*.json`` file is missing
            ValueError: the JSON does not have the expected structure
        """
        # The model id is used as a relative path, so "org/model" maps to results/org/model
        results_dir = Path(output_dir) / "results" / model_name
        self._log_message(f"Looking for results in {results_dir}")

        if not results_dir.exists():
            self._log_message(f"Results directory doesn't exist for {model_name}")
            raise FileNotFoundError(f"Results directory not found: {results_dir}")

        results_files = list(results_dir.glob("results_*.json"))
        if not results_files:
            self._log_message(f"No results files found in {results_dir}")
            raise FileNotFoundError(f"No results files found in {results_dir}")

        # glob() order is arbitrary; if several runs left files behind, use the newest one
        results_file = max(results_files, key=lambda p: p.stat().st_mtime)
        self._log_message(f"Using results file: {results_file}")

        with open(results_file, encoding="utf-8") as f:
            results = json.load(f)
        self._log_message(f"Results structure: {json.dumps(list(results.keys()))}")

        if "results" in results and "all" in results["results"] and "accuracy" in results["results"]["all"]:
            accuracy = results["results"]["all"]["accuracy"]
            self._log_message(f"Extracted accuracy: {accuracy}")
            return accuracy

        self._log_message(f"Unexpected results structure. Available keys: {list(results.keys())}")
        if "results" in results:
            inner = results["results"]
            inner_keys = list(inner.keys()) if isinstance(inner, dict) else 'not a dictionary'
            self._log_message(f"Keys in 'results': {inner_keys}")
        raise ValueError(f"Unexpected results structure for {model_name}")

    def _collect_result(self, output_dir: str, model_name: str, provider: str, execution_time: float) -> Dict:
        """Build the result record from lighteval's output; any parsing problem yields status "parse_error"."""
        try:
            accuracy = self._read_accuracy(output_dir, model_name)
        except Exception as e:
            self._log_message(f"Failed to parse results for {model_name} after {execution_time:.2f}s: {str(e)}")
            return self._make_result(model_name, provider, 0.0, execution_time, "parse_error")
        return self._make_result(model_name, provider, accuracy, execution_time, "success")

    async def _run_lighteval(self, model_name: str, provider: str) -> dict:
        """
        Evaluate one model with lighteval and return its result record.

        Args:
            model_name: Model id passed to lighteval's inference-providers endpoint
            provider: Inference provider to use for that model

        Returns:
            Dict with keys "model", "provider", "accuracy", "execution_time"
            (seconds) and "status". Status is one of "success", "timeout",
            "error" or "parse_error"; accuracy is 0.0 unless status is "success".
        """
        start_time = time.time()
        self._log_message(f"Starting evaluation with {provider} provider for {model_name}")

        task_file_path = self._write_task_file()
        try:
            output_dir = f"uploaded_files/{self.session_uid}/lighteval_results"
            os.makedirs(output_dir, exist_ok=True)
            cmd_args = self._build_command(model_name, provider, task_file_path, output_dir)

            failure = await self._execute_lighteval(cmd_args, model_name, provider, start_time)
            if failure is not None:
                return failure

            execution_time = time.time() - start_time
            self._log_message(f"Finished evaluation for {model_name} in {execution_time:.2f}s")
            return self._collect_result(output_dir, model_name, provider, execution_time)
        finally:
            # Single cleanup point for the generated task module, whatever the outcome
            self._remove_file(task_file_path)

    def _inspect_dataset(self) -> None:
        """Log the keys and citations of the dataset's first example (best effort; errors are only logged)."""
        try:
            self._log_message(f"Attempting to load dataset {self.dataset_name} for inspection")
            dataset = load_dataset(self.dataset_name, "single_shot_questions", split="train")

            if len(dataset) > 0:
                first_example = dataset[0]
                self._log_message("Structure of the first example:")
                self._log_message(f"Keys: {first_example.keys()}")
                self._log_message(f"Citations: {first_example.get('citations', 'not found')}")
        except Exception as e:
            self._log_message(f"Error inspecting the dataset: {str(e)}")

    def _find_model_providers(self) -> Dict:
        """Map each model in DEFAULT_EVALUATION_MODELS to its available provider, skipping models without one."""
        model_providers = {}
        for model in DEFAULT_EVALUATION_MODELS:
            provider = get_available_model_provider(model, verbose=True)
            if provider:
                model_providers[model] = provider
            else:
                self._log_message(f"No available provider found for {model}")
        return model_providers

    async def run(self, clean_first: bool = True) -> None:
        """
        Run the evaluation task asynchronously

        Steps:
            1. Find a provider for every default model.
            2. Evaluate all of them concurrently.
            3. Upload the successful results to the Hub.
            4. Mark the task as completed.

        If no model has an available provider, returns early without marking
        the task as completed.

        Args:
            clean_first: If True, clean old results before starting (default: True)
        """
        # Honour clean_first (it used to be ignored and cleaning always happened)
        if clean_first:
            self.clean_old_results()

        script_start_time = time.time()

        load_dotenv()
        self._inspect_dataset()

        # Step 1: Check available providers for each model
        await self.update_step("finding_available_model_providers")
        self._log_message("Checking available providers for models...")

        model_providers = self._find_model_providers()
        if not model_providers:
            self._log_message("No models with available providers found")
            return

        self._log_message(f"Found providers for {len(model_providers)} models")

        # Step 2: Run evaluations in parallel
        await self.update_step("starting_evaluation_process")
        self._log_message("Starting evaluation process...")

        # Step 3: Evaluate models
        await self.update_step("evaluating_models")
        self._log_message("Evaluating models...")

        # return_exceptions=True: one unexpected crash must not discard the other models' results
        outcomes = await asyncio.gather(
            *(self._run_lighteval(model, provider) for model, provider in model_providers.items()),
            return_exceptions=True
        )

        # Keep only successful evaluations
        self.results = []
        for model, outcome in zip(model_providers, outcomes):
            if isinstance(outcome, BaseException):
                self._log_message(f"Unexpected failure while evaluating {model}: {outcome}")
            elif outcome["status"] == "success":
                self.results.append(outcome)

        # Step 4: Save results
        await self.update_step("storing_evaluation_results")
        self._log_message("Storing evaluation results...")
        self._save_results_to_hub()

        self.is_completed = True
        await self.update_step("completed")

        total_time = time.time() - script_start_time
        self._log_message(f"Evaluation completed in {total_time:.2f}s")

    def get_logs(self) -> List[str]:
        """
        Get the logs of the task

        Returns:
            ``self.logs`` if some other code has set it, otherwise an empty
            list. Nothing in this class populates ``self.logs``.
        """
        return self.logs if hasattr(self, "logs") else []

    def is_task_completed(self) -> bool:
        """
        Check if the task is completed

        Returns:
            True if the task is completed, False otherwise
        """
        return self.is_completed