custom-api / app.py
DataChem's picture
Update app.py
4468cfe verified
raw
history blame
935 Bytes
from fastapi import FastAPI, Request
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# ASGI application object; the route decorators in this file register on it.
app = FastAPI()
# Load the model and tokenizer once, at import time, so the /predict handler
# reuses the same module-level objects on every request instead of reloading.
# model_name is the checkpoint identifier handed to both from_pretrained calls,
# so the tokenizer and the model always come from the same checkpoint.
model_name = "EleutherAI/gpt-neo-1.3B" # Replace with your desired model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
@app.get("/")
def read_root():
    """Answer GET / with a fixed greeting payload."""
    greeting = {"Hello": "World"}
    return greeting
@app.post("/predict")
async def predict(request: Request):
    """Generate a text continuation for the prompt in the JSON request body.

    The body must be a JSON object holding a non-empty string under the
    "prompt" key. On success the handler returns {"response": <text>}, where
    the text is the decoded first generated sequence (prompt included). For a
    body it cannot use, it returns {"error": <message>} rather than raising.

    The prompt and the completion share one budget of 40 tokens, passed to
    ``generate`` as ``max_length``.
    """
    # Prompt tokens plus generated tokens may not exceed this many in total.
    max_total_tokens = 40

    # request.json() raises a ValueError subclass when the body is not JSON.
    try:
        data = await request.json()
    except ValueError:
        return {"error": "Request body must be valid JSON"}

    # Valid JSON can still be a list, number or string, which has no .get().
    if not isinstance(data, dict):
        return {"error": "Request body must be a JSON object"}

    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}
    if not isinstance(prompt, str):
        return {"error": "Prompt must be a string"}

    # Tokenize the input and place the tensors on whatever device the model
    # lives on, so moving the model to a GPU needs no change here.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # max_length counts the prompt tokens too: a prompt that already fills
    # the budget leaves nothing to generate, so reject it with a clear message.
    prompt_length = inputs["input_ids"].shape[-1]
    if prompt_length >= max_total_tokens:
        return {
            "error": (
                f"Prompt is too long: {prompt_length} tokens, "
                f"at most {max_total_tokens - 1} are allowed"
            )
        }

    # Use the tokenizer's pad token when it has one, else fall back to EOS,
    # so generate() is told explicitly which id to pad with.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generate tokens. Unpacking `inputs` forwards the attention mask along
    # with the input ids.
    # NOTE(review): this is a synchronous call inside an async handler, so the
    # event loop is blocked while it runs; consider offloading it to a worker
    # thread if concurrent requests matter.
    outputs = model.generate(
        **inputs,
        max_length=max_total_tokens,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": response}