"""Custom transformer implementation for fallback."""
import torch
import torch.nn as nn
import math
import logging

# Set up logging
logger = logging.getLogger(__name__)

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""
    
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Calculate RMS
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * rms * x


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding."""
    
    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        
        # Generate frequency tensor
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        
        # Generate cos and sin cache
        self._update_cos_sin_cache(max_seq_len)
        
    def _update_cos_sin_cache(self, max_seq_len):
        """Update the cache of cos and sin values."""
        self.max_seq_len = max_seq_len
        t = torch.arange(max_seq_len, device=self.inv_freq.device)
        
        # Compute cos and sin at each position
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        
        self.register_buffer("cos_cache", cos, persistent=False)
        self.register_buffer("sin_cache", sin, persistent=False)
        
    def forward(self, x, seq_len=None, pos=None):
        # Get appropriate parts of the cache
        if pos is not None:
            # Handle arbitrary positions
            cos = self.cos_cache[pos]
            sin = self.sin_cache[pos]
        else:
            # Handle sequential positions
            seq_len = x.shape[1] if seq_len is None else seq_len
            cos = self.cos_cache[:seq_len]
            sin = self.sin_cache[:seq_len]
            
        return cos, sin


def rotate_half(x):
    """Rotate half the dimensions of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Apply rotary position embedding to q and k."""
    if position_ids is not None:
        # Handle arbitrary positions
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    else:
        # Handle sequential positions
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
    
    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    
    return q_embed, k_embed


class CustomAttention(nn.Module):
    """Multi-head attention with support for KV caching."""
    
    def __init__(self, dim, num_heads, num_kv_heads=None, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        # Attention projections
        self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False)
        
        # Rotary embedding
        self.rope = RotaryEmbedding(self.head_dim)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
    def _repeat_kv(self, x):
        """Repeat KV heads to match the number of query heads."""
        if self.num_kv_heads == self.num_heads:
            return x
        
        b, s, n_kv_head, head_dim = x.shape
        
        # Repeat the KV heads to match the number of query heads
        repeats = self.num_heads // self.num_kv_heads
        x = x.repeat_interleave(repeats, dim=2)
        
        return x
        
    def forward(self, x, mask=None, input_pos=None, kv_cache=None):
        batch_size, seq_len, _ = x.shape
        
        # Project to q, k, v
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [b, nh, s, hd]
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [b, nkh, s, hd]
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [b, nkh, s, hd]
        
        # Apply rotary embeddings
        cos, sin = self.rope.forward(x, seq_len=seq_len, pos=input_pos)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=input_pos)
        
        # Handle KV cache
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            
            if input_pos is not None:
                # Update cache at specific positions
                k_cache.index_copy_(2, input_pos, k)
                v_cache.index_copy_(2, input_pos, v)
                
                # Use the entire cache
                k, v = k_cache, v_cache
        
        # Repeat KV if needed
        k = self._repeat_kv(k)
        v = self._repeat_kv(v)
        
        # Calculate attention scores
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        # Apply mask if provided
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -10000.0)
            
        # Apply softmax and dropout
        attention_probs = self.dropout(torch.softmax(attention_scores, dim=-1))
        
        # Get context vector
        context = torch.matmul(attention_probs, v)
        
        # Reshape and project back
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(context)
        
        return output


class FeedForward(nn.Module):
    """Feed-forward network with GELU activation."""
    
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()
        
    def forward(self, x):
        x = self.w1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.w2(x)
        return x


class TransformerLayer(nn.Module):
    """A single transformer layer."""
    
    def __init__(
        self, 
        dim, 
        num_heads, 
        num_kv_heads=None,
        ffn_dim=None, 
        dropout=0.0, 
        norm_eps=1e-5
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.attn = CustomAttention(dim, num_heads, num_kv_heads, dropout)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn = FeedForward(
            dim, 
            ffn_dim or 4 * dim, 
            dropout
        )
        
    def forward(self, x, mask=None, input_pos=None, kv_cache=None):
        # Self-attention with residual
        h = self.norm1(x)
        h = self.attn(h, mask=mask, input_pos=input_pos, kv_cache=kv_cache)
        x = x + h
        
        # FFN with residual
        h = self.norm2(x)
        h = self.ffn(h)
        x = x + h
        
        return x


class CustomTransformerDecoder(nn.Module):
    """Custom transformer decoder that mimics Llama architecture."""
    
    def __init__(
        self,
        vocab_size,
        num_layers,
        num_heads,
        num_kv_heads,
        embed_dim,
        max_seq_len,
        intermediate_dim,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=10000,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        
        # Token embeddings
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        
        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(
                embed_dim,
                num_heads,
                num_kv_heads,
                intermediate_dim,
                attn_dropout,
                norm_eps
            )
            for _ in range(num_layers)
        ])
        
        # Final normalization and output projection
        self.norm = RMSNorm(embed_dim, eps=norm_eps)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)
        
        # Initialize the KV cache
        self._kv_cache = None
        self._has_cache = False
        
        logger.info(f"Initialized CustomTransformerDecoder with {num_layers} layers, {num_heads} heads, {embed_dim} dim")
        
    def setup_caches(self, batch_size, dtype, decoder_max_seq_len=None):
        """Set up KV caches for inference."""
        max_seq_len = decoder_max_seq_len or self.max_seq_len
        device = next(self.parameters()).device
        
        self._kv_cache = []
        for i, layer in enumerate(self.layers):
            # Create a KV cache for each layer
            k_cache = torch.zeros(
                batch_size,
                layer.attn.num_kv_heads,
                max_seq_len,
                layer.attn.head_dim,
                device=device,
                dtype=dtype
            )
            v_cache = torch.zeros(
                batch_size,
                layer.attn.num_kv_heads,
                max_seq_len,
                layer.attn.head_dim,
                device=device,
                dtype=dtype
            )
            self._kv_cache.append((k_cache, v_cache))
        
        self._has_cache = True
        logger.info(f"KV caches set up for {batch_size} batches, {max_seq_len} seq length")
        
    def caches_are_enabled(self):
        """Check if caches are enabled."""
        return self._has_cache
        
    def reset_caches(self):
        """Reset the KV cache to zeros."""
        if self._has_cache and self._kv_cache:
            for k_cache, v_cache in self._kv_cache:
                k_cache.zero_()
                v_cache.zero_()
    
    def forward(self, x, mask=None, input_pos=None):
        batch_size, seq_len = x.shape[:2]
        
        # Apply embedding if input is token IDs
        if x.dim() == 2:
            x = self.tok_embeddings(x)
        
        # Apply transformer layers
        for i, layer in enumerate(self.layers):
            layer_cache = self._kv_cache[i] if self._has_cache else None
            x = layer(x, mask=mask, input_pos=input_pos, kv_cache=layer_cache)
        
        # Apply final norm
        x = self.norm(x)
        
        # Skip output projection if using Identity
        if isinstance(self.output, nn.Identity):
            return x
        
        # Apply output projection
        logits = self.output(x)
        
        return logits