File size: 2,570 Bytes
4468cfe
45123df
5102dda
4468cfe
5102dda
e05b36f
 
 
4468cfe
 
 
 
5102dda
 
4468cfe
 
 
 
 
 
 
 
 
5102dda
 
 
4468cfe
45123df
5102dda
 
 
74b564f
 
 
 
5102dda
74b564f
 
 
5102dda
74b564f
 
 
 
 
5102dda
74b564f
 
 
 
 
5102dda
74b564f
 
5102dda
74b564f
 
 
5102dda
74b564f
 
 
5102dda
45123df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# Choose the accelerator up front so the model can be moved right after loading.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face checkpoint to serve; it is loaded with AutoModelForCausalLM, so it
# is expected to be a causal language model.
model_name = "EleutherAI/gpt-neo-1.3B"  # Replace with your desired model
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Module.to() returns the module itself, so load-and-move can be one expression.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Generation settings used by /predict.
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.7
TOP_P = 0.9


def _sample_top_p(logits, temperature=TEMPERATURE, top_p=TOP_P):
    """Sample one token id per row using temperature plus nucleus (top-p) sampling.

    Args:
        logits: 2-D tensor of shape (batch, vocab_size) holding unnormalised scores.
        temperature: Positive divisor applied to the logits before the softmax.
        top_p: Cumulative-probability threshold in (0, 1].

    Returns:
        Integer tensor of shape (batch, 1) with the sampled token ids.
    """
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Remove a token only if the mass *before* it already reaches top_p. The token
    # that crosses the threshold is kept, so at least one candidate always survives.
    remove = (cumulative - sorted_probs) >= top_p
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    # multinomial accepts unnormalised non-negative weights; the result indexes the
    # sorted order, so map it back to vocabulary ids row by row.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_indices, -1, choice)


@app.post("/predict")
async def predict(request: Request):
    """Stream a sampled continuation of the request's prompt as plain text.

    Expects a JSON object body with a non-empty string ``"prompt"``. When the
    prompt is missing, empty or not a string, or the body is not a JSON object,
    returns ``{"error": "Prompt is required"}``.

    Otherwise returns a StreamingResponse. Its chunks are consecutive pieces of
    the decoded continuation: at most MAX_NEW_TOKENS tokens, stopping early on
    the tokenizer's EOS token.
    """
    data = await request.json()
    prompt = data.get("prompt", "") if isinstance(data, dict) else ""
    if not isinstance(prompt, str) or not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input (a single prompt, so batch size is 1).
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    def token_generator():
        # Use new local names. Assigning to input_ids/attention_mask here would make
        # them locals of this function and break the closure reads.
        step_ids = input_ids        # full prompt first, then only the newest token
        mask = attention_mask       # grows by one column per generated token
        past = None                 # KV cache returned by the model
        generated = []
        emitted_len = 0

        for _ in range(MAX_NEW_TOKENS):
            # The no-grad scope deliberately contains no `yield`. Grad mode is
            # thread-local, and StreamingResponse may resume a sync generator on a
            # different worker thread.
            with torch.no_grad():
                outputs = model(
                    input_ids=step_ids,
                    attention_mask=mask,
                    past_key_values=past,
                    use_cache=True,
                )
                next_token = _sample_top_p(outputs.logits[:, -1, :], TEMPERATURE, TOP_P)
            past = outputs.past_key_values

            token_id = next_token.item()
            if token_id == tokenizer.eos_token_id:
                break  # stop without emitting the EOS token
            generated.append(token_id)
            step_ids = next_token  # with a KV cache only the new token is fed next
            mask = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)

            # Decode the whole continuation and emit only the new suffix. BPE tokens
            # carry their own whitespace, and one character may span several tokens.
            # Cleanup is disabled so earlier text is never rewritten.
            text = tokenizer.decode(
                generated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            if text.endswith("\ufffd"):
                continue  # likely an incomplete multi-byte character; wait for more
            if len(text) > emitted_len:
                yield text[emitted_len:]
                emitted_len = len(text)

        # Flush anything held back by the replacement-character check above.
        text = tokenizer.decode(
            generated,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        if len(text) > emitted_len:
            yield text[emitted_len:]

    return StreamingResponse(token_generator(), media_type="text/plain")