aiWeb / main.py
benkada's picture
Update main.py
e1933c4 verified
raw
history blame
4.57 kB
import os
import io
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse
from huggingface_hub import InferenceClient
from PyPDF2 import PdfReader
from docx import Document
from PIL import Image
from io import BytesIO
# Hugging Face API token, read from the environment so it is never stored in
# source. If HF_TOKEN is unset this is None and InferenceClient uses its own
# default credential behaviour (TODO confirm anonymous access suffices for
# the models below).
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI()

# Permissive CORS so the API can be called from any origin. Credentials are
# deliberately NOT enabled: browsers reject a wildcard Allow-Origin on
# credentialed requests, and Starlette works around that by echoing the
# caller's Origin header, which would let any website send cookie-bearing
# requests to this API. The bundled front end is served from "/" on the same
# origin, so it does not depend on cross-origin credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# One hosted-inference client per task, each pinned to a specific model.
summary_client = InferenceClient(model="facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
qa_client = InferenceClient(model="deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
image_caption_client = InferenceClient(model="nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
def extract_text_from_pdf(content: bytes) -> str:
    """Return the text of every page of an in-memory PDF.

    Args:
        content: Raw bytes of the PDF file.

    Returns:
        The per-page texts joined with newlines, with surrounding whitespace
        stripped. A page whose ``extract_text()`` yields nothing falsy
        contributes an empty string.
    """
    pdf = PdfReader(io.BytesIO(content))
    page_texts = []
    for page in pdf.pages:
        page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts).strip()
def extract_text_from_docx(content: bytes) -> str:
    """Return the paragraph text of an in-memory .docx document.

    Args:
        content: Raw bytes of the .docx file.

    Returns:
        The text of each entry in ``Document.paragraphs`` joined with
        newlines, with surrounding whitespace stripped.
        NOTE(review): only ``paragraphs`` is read; text inside tables or
        headers is presumably not included — confirm if that matters.
    """
    document = Document(io.BytesIO(content))
    paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
    return "\n".join(paragraph_texts).strip()
def process_uploaded_file(file: UploadFile) -> str:
    """Read an uploaded document and return its plain-text content.

    The format is chosen from the text after the last ``.`` in the filename
    (case-insensitive): ``pdf`` is parsed with PyPDF2, ``docx`` with
    python-docx, and ``txt`` is decoded as UTF-8.

    Args:
        file: The uploaded file; its underlying file object is read from the
            current position to the end.

    Returns:
        The extracted text with leading/trailing whitespace removed.

    Raises:
        ValueError: If the extension is not ``pdf``, ``docx`` or ``txt``
            (including uploads that carry no filename at all).
        UnicodeDecodeError: If a ``txt`` upload is not valid UTF-8.
    """
    # UploadFile.filename is optional; treat a missing name as "no extension"
    # so callers get the intended ValueError instead of an AttributeError
    # from calling ``.split`` on None.
    extension = (file.filename or "").split(".")[-1].lower()
    # Reject unsupported types before spending time reading the body.
    if extension not in ("pdf", "docx", "txt"):
        raise ValueError("Unsupported file type.")

    content = file.file.read()
    if extension == "pdf":
        return extract_text_from_pdf(content)
    if extension == "docx":
        return extract_text_from_docx(content)
    # "utf-8-sig" is plain UTF-8 that also drops a leading byte-order mark;
    # str.strip() does not remove U+FEFF, so a BOM would otherwise leak into
    # the extracted text.
    return content.decode("utf-8-sig").strip()
@app.get("/", response_class=HTMLResponse)
async def serve_homepage():
    """Serve the front-end page stored in ``index.html``.

    The file is resolved relative to this module's directory rather than the
    process's current working directory, so the page is found no matter
    where the server was launched from.
    """
    index_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(index_path, "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read(), status_code=200)
@app.post("/api/summarize")
def summarize_document(file: UploadFile = File(...)):
    """Summarize an uploaded PDF, DOCX or TXT document.

    Declared with plain ``def`` on purpose: text extraction and
    ``summary_client.summarization`` are synchronous, blocking calls. FastAPI
    runs ``def`` path operations in a worker thread, whereas inside
    ``async def`` they would block the event loop and stall every other
    request for the duration of the remote inference call.

    Returns:
        ``{"result": <value returned by summary_client.summarization>}`` on
        success; ``{"result": "Document too short to summarize."}`` when
        fewer than 20 characters were extracted; or a 500 JSON response
        ``{"error": <message>}`` if extraction or inference raises.
    """
    try:
        text = process_uploaded_file(file)
        if len(text) < 20:
            return {"result": "Document too short to summarize."}
        # Only the first 3000 characters are sent, presumably to stay within
        # the model's input-length limit (TODO confirm the right cap).
        summary = summary_client.summarization(text[:3000])
        return {"result": summary}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/api/caption")
def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    The image is converted to RGB, downscaled in place to fit within
    1024x1024 (aspect ratio kept), re-encoded as JPEG and sent to the
    captioning model.

    Declared with plain ``def`` so the blocking image processing and the
    synchronous ``image_to_text`` HTTP call run in FastAPI's threadpool
    instead of on the event loop.

    Returns:
        ``{"result": <caption string>}``, with ``"No caption found."`` when
        the model output contains no usable text, or a 500 JSON response
        ``{"error": <message>}`` if decoding or inference raises.
    """
    try:
        image_bytes = file.file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image.thumbnail((1024, 1024))
        buffer = BytesIO()
        image.save(buffer, format="JPEG")
        result = image_caption_client.image_to_text(buffer.getvalue())

        # The client's output shape varies (dict-like object, list of
        # dicts, or a bare string); normalize all of them to one string.
        caption = None
        if isinstance(result, dict):
            caption = result.get("generated_text") or result.get("caption")
        elif isinstance(result, list) and result:
            first = result[0]
            # Guard the element type: calling .get on a non-dict element
            # used to raise AttributeError and surface as a 500.
            if isinstance(first, dict):
                caption = first.get("generated_text")
            elif isinstance(first, str):
                caption = first
        elif isinstance(result, str):
            caption = result
        return {"result": caption or "No caption found."}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/api/qa")
def question_answering(file: UploadFile = File(...), question: str = Form(...)):
    """Answer ``question`` using an uploaded document or image as context.

    For image uploads (``image/*`` content type) the context is the caption
    produced by the image-captioning model. For other uploads, the context is
    the first 3000 characters of text extracted by ``process_uploaded_file``.

    Declared with plain ``def`` so the blocking file parsing and the
    synchronous inference HTTP calls run in FastAPI's threadpool instead of
    on the event loop.

    Returns:
        ``{"result": <answer string>}``. The result is a fixed message instead
        when the document is too short, when no context could be derived, or
        when the model returns no answer. A 500 JSON response
        ``{"error": <message>}`` is returned if any step raises.
    """
    try:
        # UploadFile.content_type is optional; without the fallback an upload
        # part sent with no Content-Type raised AttributeError here.
        if (file.content_type or "").startswith("image/"):
            image_bytes = file.file.read()
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            image.thumbnail((1024, 1024))
            buffer = BytesIO()
            image.save(buffer, format="JPEG")
            result = image_caption_client.image_to_text(buffer.getvalue())
            # The captioner may return a list of outputs; use the first one
            # rather than passing the whole list to the QA model as context.
            if isinstance(result, list):
                result = result[0] if result else None
            if isinstance(result, dict):
                context = result.get("generated_text") or result.get("caption")
            else:
                context = result
        else:
            text = process_uploaded_file(file)
            if len(text) < 20:
                return {"result": "Document too short to answer questions."}
            context = text[:3000]

        if not context:
            return {"result": "No context available to answer."}

        answer = qa_client.question_answering(question=question, context=context)
        # NOTE(review): the client may return a list of candidate answers
        # rather than a single one; take the top candidate in that case.
        if isinstance(answer, list):
            answer = answer[0] if answer else {}
        return {"result": answer.get("answer", "No answer found.")}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})
if __name__ == "__main__":
    # Development entry point: serve this module's ``app`` on all interfaces
    # with auto-reload. uvicorn is imported lazily so that importing this
    # module (e.g. from another ASGI server) does not require it.
    import uvicorn

    _server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **_server_options)