from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# ---------------------------------------------------------------------------
# Update this to the Llama 2 Chat model you prefer. This example uses the
# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
# ---------------------------------------------------------------------------
model_name = "meta-llama/Llama-2-7b-chat-hf"

# ---------------------------------------------------------------------------
# If the repo is gated, you may need:
#   use_auth_token="YOUR_HF_TOKEN",
#   trust_remote_code=True,
# or you can set environment variables in your HF Space to authenticate.
# ---------------------------------------------------------------------------
print(f"Loading model/tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    # use_auth_token="YOUR_HF_TOKEN",  # If needed for a private/gated model
)

# ---------------------------------------------------------------------------
# If you have a GPU available, you might do:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       torch_dtype=torch.float16,
#       device_map="auto",
#       trust_remote_code=True,
#   )
# But for CPU, we do a simpler load:
# ---------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # use_auth_token="YOUR_HF_TOKEN",  # If needed
)

# Choose device based on availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)


@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint for streaming responses from the Llama 2 chat model.

    Expects JSON: {"prompt": "<your prompt here>"}
    Returns a streamed plain-text response of generated tokens.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape

    def token_generator():
        """
        A generator that yields tokens one by one for streaming.
""" nonlocal input_ids, attention_mask # Basic generation hyperparameters temperature = 0.7 top_p = 0.9 max_new_tokens = 30 # Increase for longer outputs for _ in range(max_new_tokens): with torch.no_grad(): # 1) Forward pass: compute logits for next token outputs = model(input_ids=input_ids, attention_mask=attention_mask) next_token_logits = outputs.logits[:, -1, :] # 2) Apply temperature scaling next_token_logits = next_token_logits / temperature # 3) Convert logits -> probabilities next_token_probs = F.softmax(next_token_logits, dim=-1) # 4) Nucleus (top-p) sampling sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) valid_indices = cumulative_probs <= top_p filtered_probs = sorted_probs[valid_indices] filtered_indices = sorted_indices[valid_indices] # 5) If no tokens are valid under top_p, fallback to greedy if len(filtered_probs) == 0: next_token_id = torch.argmax(next_token_probs) else: sampled_id = torch.multinomial(filtered_probs, 1) next_token_id = filtered_indices[sampled_id] # 6) Ensure next_token_id has shape [batch_size, 1] if next_token_id.dim() == 0: # shape [] => [1] next_token_id = next_token_id.unsqueeze(0) # shape [1] => [1,1] next_token_id = next_token_id.unsqueeze(-1) # 7) Append token to input_ids input_ids = torch.cat([input_ids, next_token_id], dim=-1) # 8) Update attention_mask for the new token new_mask = attention_mask.new_ones((attention_mask.size(0), 1)) attention_mask = torch.cat([attention_mask, new_mask], dim=-1) # 9) Decode and yield token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True) yield token + " " # 10) Stop if we encounter EOS if tokenizer.eos_token_id is not None: if next_token_id.squeeze().item() == tokenizer.eos_token_id: break # Return a StreamingResponse for SSE return StreamingResponse(token_generator(), media_type="text/plain")