from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# -------------------------------------------------------------------------
# Sao10K/L3-8B-Stheno-v3.2 is publicly available (not gated) on the Hugging
# Face Hub, so no access token is required and 'use_auth_token' is omitted.
# -------------------------------------------------------------------------
model_name = "Sao10K/L3-8B-Stheno-v3.2"

print(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)

print(f"Loading model from: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
)
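
# NOTE: on a CUDA GPU, passing torch_dtype=torch.float16 (or torch.bfloat16) to
# from_pretrained above would roughly halve memory use versus the default
# float32 load; the default is kept here for simplicity.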

# Choose device based on availability (GPU if present, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
model.eval()  # inference only: disable dropout and other training-time behavior
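
# NOTE: on multi-GPU machines, from_pretrained(..., device_map="auto") (which
# requires the 'accelerate' package) could shard the model across devices
# instead of the manual .to(device) above; the simpler single-device path is
# used here.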

@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint for streaming responses from Falcon-7B-Instruct.
    Expects JSON: { "prompt": "<Your prompt>" }
    Returns a text/event-stream of tokens (SSE).
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return JSONResponse({"error": "Prompt is required"}, status_code=400)

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids             # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask   # same shape
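
    # NOTE: the sampling loop below assumes a single prompt per request
    # (batch size 1); the boolean top-p filtering flattens the batch dimension.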

    def token_generator():
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase if you want longer outputs
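
        # NOTE: for clarity this loop re-runs the full forward pass over the
        # whole sequence at every step; reusing past_key_values (the KV cache)
        # or calling model.generate() with a TextIteratorStreamer would be
        # considerably faster.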

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]
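                #    (This mask keeps only tokens whose cumulative probability
                #    stays at or below top_p, so it can come back empty when the
                #    single most likely token already exceeds top_p -- handled
                #    by the fallback below.)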

                # 5) If no tokens remain after filtering, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1,1]
                next_token_id = next_token_id.unsqueeze(-1)

                # 7) Append the new token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)

                # 8) Update the attention mask
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

                # 9) Stop before yielding if the EOS token was generated
                #    (when the tokenizer defines one)
                if (
                    tokenizer.eos_token_id is not None
                    and next_token_id.item() == tokenizer.eos_token_id
                ):
                    break

                # 10) Decode and yield the new token; with this byte-level BPE
                #     tokenizer the decoded piece already carries its own
                #     leading whitespace, so nothing extra is appended
                token = tokenizer.decode(next_token_id.item(), skip_special_tokens=True)
                yield token

    # Stream the generated tokens back to the client as plain text
    return StreamingResponse(token_generator(), media_type="text/plain")
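

# ---------------------------------------------------------------------------
# Example of running and calling the server -- a minimal sketch, assuming this
# file is saved as app.py and uvicorn is installed (port 8000 is arbitrary):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -N -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about the sea."}'
#
# The -N flag stops curl from buffering, so tokens appear as they stream in.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)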