import os
import torch
import dill
import pickle
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms

# Define the image transformation.
# Resize to a fixed 224x224 square: the code below assumes exactly this size
# (a 28x28 feature grid and 224x224 heat maps), and an int argument to
# Resize would only fix the shorter side, leaving non-square inputs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
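
# Note: if the saved backbone was trained on ImageNet-normalized inputs, a
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# step would belong here too; it is left out to match the original pipeline.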

def decision_function(segm_map):
    """
    Calculate one anomaly score per image from a batch of segmentation maps,
    using the mean of each map's top 10 values.
    """
    mean_top_10_values = []
    for single_map in segm_map:  # avoid shadowing the built-in `map`
        flattened_tensor = single_map.reshape(-1)
        sorted_tensor, _ = torch.sort(flattened_tensor, descending=True)
        mean_top_10_values.append(sorted_tensor[:10].mean())

    return torch.stack(mean_top_10_values)
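
# Illustrative example (not part of the pipeline): for a batch of two 4x4
# maps holding the values 0..15 and 16..31, each score is the mean of that
# map's ten largest values.
#
#   >>> decision_function(torch.arange(32.0).reshape(2, 4, 4))
#   tensor([10.5000, 26.5000])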

def run_inference_autoencoder(image_path, model, backbone, threshold):
    """
    Run inference on a single image using the autoencoder reconstruction model.
    """
    # Load and preprocess the image (assumes a CUDA device, matching the models)
    image = Image.open(image_path).convert('RGB')
    test_image = transform(image).cuda().unsqueeze(0)

    # Perform inference
    with torch.no_grad():
        features = backbone(test_image)
        recon = model(features)

        # The segmentation map is the per-pixel reconstruction error,
        # averaged over the channel dimension
        segm_map = ((features - recon) ** 2).mean(dim=1)
        y_score = decision_function(segm_map=segm_map)
        is_anomaly = (y_score >= threshold).item()

        # Create heatmap
        heat_map = cv2.resize(segm_map.squeeze().cpu().numpy(), (224, 224))

    return {
        'original_image': test_image.squeeze().permute(1, 2, 0).cpu().numpy(),
        'heat_map': heat_map,
        'anomaly_score': y_score.item(),
        'threshold': threshold,
        'is_anomaly': is_anomaly,
        'classification': 'NOK' if is_anomaly else 'OK'
    }

def run_inference_knn(image_path, model, memory_bank, threshold):
    """
    Run inference on a single image using nearest-neighbor distances
    to a memory bank of normal-patch features (KNN).
    """
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    test_image = transform(image).cuda().unsqueeze(0)

    # Move the memory bank to the GPU (a no-op if it is already there, but
    # ideally done once by the caller rather than on every call)
    memory_bank = memory_bank.cuda()

    # Extract features using the backbone
    with torch.no_grad():
        features = model(test_image)

    # Compute L2 distances from every patch feature to every memory-bank entry
    distances = torch.cdist(features, memory_bank, p=2.0)
    dist_score, _ = torch.min(distances, dim=1)  # nearest-neighbor distance per patch
    y_score = torch.max(dist_score)  # image-level score: the most anomalous patch

    is_anomaly = (y_score >= threshold).item()
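
    # Illustrative shapes (assuming a 224x224 input and a backbone that yields
    # the 28x28 patch grid used below): features (784, C), memory_bank (M, C),
    # distances (784, M), dist_score (784,) -- one distance per patch.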

    # Reshape the per-patch distances to the 28x28 feature grid, then
    # upsample to the 224x224 input resolution to form the heat map
    segm_map = dist_score.view(1, 1, 28, 28)
    heat_map = torch.nn.functional.interpolate(
        segm_map, size=(224, 224), mode='bilinear'
    ).cpu().squeeze().numpy()

    return {
        'original_image': test_image.squeeze().permute(1, 2, 0).cpu().numpy(),
        'heat_map': heat_map,
        'anomaly_score': y_score.item(),
        'threshold': threshold,
        'is_anomaly': is_anomaly,
        'classification': 'NOK' if is_anomaly else 'OK',
    }
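
# Optional visualization helper (a sketch, not part of the original pipeline):
# overlays a heat map on the image returned by either run_inference_* function.
# The min-max normalization and the JET colormap are illustrative choices.
def overlay_heatmap(original_image, heat_map, alpha=0.5):
    """Blend an anomaly heat map over the original image for visualization."""
    # The original image arrives as float32 HWC in [0, 1]; cv2 wants uint8 BGR
    image_uint8 = cv2.cvtColor((original_image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    # Scale the heat map to [0, 255] before applying the colormap
    heat_norm = cv2.normalize(heat_map, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    heat_color = cv2.applyColorMap(heat_norm, cv2.COLORMAP_JET)
    return cv2.addWeighted(heat_color, alpha, image_uint8, 1 - alpha, 0)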


def load_model_autoencoder(checkpoint_dir='models'):
    """
    Load the saved models and evaluation metrics for Autoencoder.
    """
    models_path = os.path.join(checkpoint_dir, 'models_autoencoder.dill')
    backbone_path = os.path.join(checkpoint_dir, 'backbone_autoencoder.dill')
    metrics_path = os.path.join(checkpoint_dir, 'evaluation_metrics_autoencoder.pkl')

    # Ensure files exist before loading
    if not os.path.exists(models_path):
        raise FileNotFoundError(f"Autoencoder model file not found: {models_path}")
    if not os.path.exists(backbone_path):
        raise FileNotFoundError(f"Autoencoder backbone file not found: {backbone_path}")
    if not os.path.exists(metrics_path):
        raise FileNotFoundError(f"Autoencoder metrics file not found: {metrics_path}")

    try:
        with open(models_path, 'rb') as f:
            models = dill.load(f)

        with open(backbone_path, 'rb') as f:
            backbone = dill.load(f)

        with open(metrics_path, 'rb') as f:
            evaluation_metrics = pickle.load(f)

        return models, backbone, evaluation_metrics

    except Exception as e:
        raise RuntimeError(f"Error loading Autoencoder models: {e}") from e

def load_model_knn(checkpoint_dir='models'):
    """
    Load the saved models and evaluation metrics for KNN.
    """
    memory_bank_path = os.path.join(checkpoint_dir, 'memory_bank_selected.pkl')
    backbone_path = os.path.join(checkpoint_dir, 'backbone_knn.dill')  # dill-serialized feature extractor
    metrics_path = os.path.join(checkpoint_dir, 'evaluation_metrics_knn.pkl')

    # Ensure files exist before loading
    if not os.path.exists(memory_bank_path):
        raise FileNotFoundError(f"KNN memory bank file not found: {memory_bank_path}")
    if not os.path.exists(backbone_path):
        raise FileNotFoundError(f"KNN backbone file not found: {backbone_path}")
    if not os.path.exists(metrics_path):
        raise FileNotFoundError(f"KNN metrics file not found: {metrics_path}")

    try:
        with open(memory_bank_path, 'rb') as f:
            memory_bank = pickle.load(f)

        with open(backbone_path, 'rb') as f:
            backbone = dill.load(f)

        with open(metrics_path, 'rb') as f:
            evaluation_metrics = pickle.load(f)

        return backbone, memory_bank, evaluation_metrics

    except Exception as e:
        raise RuntimeError(f"Error loading KNN models: {e}") from e