from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel, Extra import torch from transformers import AutoModelForCausalLM, AutoTokenizer import time import uuid import json from typing import Optional, List, Union, Dict, Any # --- Configuration --- MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct" DEVICE = "cpu" # Qwen/Qwen3-1.7B # deepseek-ai/deepseek-coder-1.3b-instruct # Qwen/Qwen2.5-Coder-0.5B-Instruct # --- Chargement du modèle --- print(f"Début du chargement du modèle : {MODEL_ID}") model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, device_map=DEVICE ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) print("Modèle et tokenizer chargés avec succès sur le CPU.") # --- Création de l'application API --- app = FastAPI() # --- Modèles de données pour accepter la structure complexe de l'extension --- class ContentPart(BaseModel): type: str text: str class ChatMessage(BaseModel): role: str content: Union[str, List[ContentPart]] class ChatCompletionRequest(BaseModel): model: Optional[str] = None messages: List[ChatMessage] stream: Optional[bool] = False class Config: extra = Extra.ignore class ModelData(BaseModel): id: str object: str = "model" owned_by: str = "user" class ModelList(BaseModel): object: str = "list" data: List[ModelData] # --- Définition des API --- @app.get("/models", response_model=ModelList) async def list_models(): """Répond à la requête GET /models pour satisfaire l'extension.""" return ModelList(data=[ModelData(id=MODEL_ID)]) @app.post("/chat/completions") async def create_chat_completion(request: ChatCompletionRequest): """Endpoint principal qui gère la génération de texte en streaming.""" # On extrait le prompt de l'utilisateur de la structure complexe user_prompt = "" last_message = request.messages[-1] if isinstance(last_message.content, list): for part in last_message.content: if part.type == 'text': user_prompt += part.text + "\n" elif isinstance(last_message.content, str): user_prompt = last_message.content if not user_prompt: return {"error": "Prompt non trouvé."} # Préparation pour le modèle DeepSeek messages_for_model = [{'role': 'user', 'content': user_prompt}] inputs = tokenizer.apply_chat_template(messages_for_model, add_generation_prompt=True, return_tensors="pt").to(DEVICE) # Génération de la réponse complète outputs = model.generate(inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id) response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True) # Fonction génératrice pour le streaming async def stream_generator(): response_id = f"chatcmpl-{uuid.uuid4()}" # On envoie la réponse caractère par caractère, au format attendu for char in response_text: chunk = { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID, "choices": [{ "index": 0, "delta": {"content": char}, "finish_reason": None }] } yield f"data: {json.dumps(chunk)}\n\n" await asyncio.sleep(0.01) # Petite pause pour simuler un flux # On envoie le chunk final de fin final_chunk = { "id": response_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID, "choices": [{ "index": 0, "delta": {}, "finish_reason": "stop" }] } yield f"data: {json.dumps(final_chunk)}\n\n" # On envoie le signal [DONE] yield "data: [DONE]\n\n" # Si l'extension demande un stream, on renvoie le générateur if request.stream: return StreamingResponse(stream_generator(), media_type="text/event-stream") else: # Code de secours si le stream n'est pas demandé (peu probable) return {"choices": [{"message": {"role": "assistant", "content": response_text}}]} @app.get("/") def root(): return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID} # On a besoin de asyncio pour la pause dans le stream import asyncio @app.get("/") def root(): return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID} # On a besoin de asyncio pour la pause dans le stream import asyncio