import math
from functools import reduce

import torch
import numpy as np
import os
from pathlib import Path


def factors(n):
    """Return all factors of n (the square root appears twice for perfect squares)."""
    return reduce(list.__add__,
                  ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))
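# Usage sketch (illustrative): factors(12) -> [1, 12, 2, 6, 3, 4];
# for perfect squares the square root is listed twice, e.g. factors(9) -> [1, 9, 3, 3].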


def file_line_count(filename: str) -> int:
    """Count the number of lines in a file"""
    with open(filename, 'rb') as f:
        return sum(1 for _ in f)
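# Usage sketch ('metadata.csv' is a hypothetical path):
#   n_samples = file_line_count('metadata.csv')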


def compute_attention(qkv, scale=None):
    """
    Compute attention matrix (same as in the pytorch scaled dot product attention)
    Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    :param qkv: Query, key and value tensors concatenated along the first dimension
    :param scale: Scale factor for the attention computation
    :return:
    """
    if isinstance(qkv, torch.Tensor):
        query, key, value = qkv.unbind(0)
    else:
        query, key, value = qkv
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    L, S = query.size(-2), key.size(-2)
    # No mask is applied; attn_bias is a zero placeholder kept for parity with the reference
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_out = attn_weight @ value
    return attn_weight, attn_out
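# Usage sketch (shapes are illustrative, not required by this module):
#   qkv = torch.randn(3, 8, 197, 64)                # (q/k/v, num_heads, num_tokens, head_dim)
#   attn_weight, attn_out = compute_attention(qkv)  # (8, 197, 197) and (8, 197, 64)
#   # attn_out should match torch.nn.functional.scaled_dot_product_attention(*qkv.unbind(0))
#   # up to floating point error, since no mask or dropout is applied here.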


def compute_dot_product_similarity(a, b):
    """Pairwise dot-product similarity scores between the rows of a and b."""
    scores = a @ b.transpose(-1, -2)
    return scores
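# Usage sketch (shapes are illustrative):
#   a, b = torch.randn(32, 128), torch.randn(64, 128)
#   scores = compute_dot_product_similarity(a, b)   # (32, 64) matrix of dot products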


def compute_cross_entropy(p, q):
    """Cross entropy between a target distribution p and unnormalized logits q."""
    q = torch.nn.functional.log_softmax(q, dim=-1)
    loss = torch.sum(p * q, dim=-1)
    return -loss.mean()
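# Usage sketch: p is a (possibly soft) target distribution, q holds raw logits.
# For a one-hot p this reduces to the standard cross entropy loss, e.g.:
#   p = torch.tensor([[0.0, 1.0, 0.0]])
#   q = torch.randn(1, 3)
#   compute_cross_entropy(p, q)   # equals torch.nn.functional.cross_entropy(q, torch.tensor([1]))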


def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
    """
    Perform attention rollout, 
    Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16
    Parameters
    ----------
    attentions : list
        List of attention matrices, one for each transformer layer
    discard_ratio : float
        Ratio of lowest attention values to discard
    head_fusion : str
        Type of fusion to use for attention heads. One of "mean", "max", "min"
    device : torch.device
        Device to use for computation
    Returns
    -------
    mask : np.ndarray
        Mask of shape (width, width), where width is the square root of the number of patches
    """
    result = torch.eye(attentions[0].size(-1), device=device)
    attentions = [attention.to(device) for attention in attentions]
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(dim=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(dim=1).values
            elif head_fusion == "min":
                attention_heads_fused = attention.min(dim=1).values
            else:
                raise ValueError(f"Attention head fusion type '{head_fusion}' is not supported")

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0

            # Add the identity to account for residual connections, then re-normalize rows
            I = torch.eye(attention_heads_fused.size(-1), device=device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)

            result = torch.matmul(a, result)

    # Normalize the result by max value in each row
    result = result / result.max(dim=-1, keepdim=True)[0]
    return result
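# Usage sketch (assumes a ViT-style model whose per-layer attention maps have been
# collected, e.g. via forward hooks; names and shapes are illustrative):
#   attentions = [torch.rand(1, 12, 197, 197) for _ in range(12)]   # (batch, heads, tokens, tokens)
#   result = rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cpu"))
#   cls_attention = result[0, 0, 1:]   # class-token attention over the 196 image patches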


def sync_bn_conversion(model: torch.nn.Module):
    """
    Convert BatchNorm to SyncBatchNorm (used for DDP)
    :param model: PyTorch model
    :return:
    model: PyTorch model with SyncBatchNorm layers
    """
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model
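# Usage sketch (typical DDP setup; `local_rank` is assumed to come from the launcher):
#   model = sync_bn_conversion(model)
#   model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])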


def check_snapshot(args):
    """
    Create directory to save training checkpoints, otherwise load the existing checkpoint.
    Additionally, if it is an array training job, create a new directory for each training job.
    :param args: Arguments from the argument parser
    :return:
    """
    # Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
    if args.array_training_job and not args.resume_training:
        args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
        if not os.path.exists(args.snapshot_dir):
            save_dir = Path(args.snapshot_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
    else:
        # Create the checkpoint directory if needed; a .pt/.pth checkpoint path must already exist
        if not os.path.exists(args.snapshot_dir):
            if not args.snapshot_dir.endswith((".pt", ".pth")):
                save_dir = Path(args.snapshot_dir)
                save_dir.mkdir(parents=True, exist_ok=True)
            else:
                raise ValueError('Snapshot checkpoint does not exist.')
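

# Usage sketch (argument names mirror the attributes accessed above; the argument
# parser itself is assumed to be defined elsewhere):
#   args = parser.parse_args()   # provides snapshot_dir, array_training_job, resume_training, seed
#   check_snapshot(args)         # creates/validates args.snapshot_dir before training starts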