ibrahim256 committed
Commit 9665dce · verified · 1 Parent(s): 969581d

Upload deployment/api_server.py with huggingface_hub
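
The commit message points at the huggingface_hub upload flow. As a minimal sketch of how such a single-file upload is typically done (the repo id and token handling here are assumptions for illustration, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from HF_TOKEN or a cached `huggingface-cli login`
api.upload_file(
    path_or_fileobj="deployment/api_server.py",
    path_in_repo="deployment/api_server.py",
    repo_id="zail-ai/Auramind",  # assumed target repo; the server code below loads this id
    commit_message="Upload deployment/api_server.py with huggingface_hub",
)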

Files changed (1)
  1. deployment/api_server.py +196 -0
deployment/api_server.py ADDED
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+"""
+AuraMind REST API Server
+Production-ready API for AuraMind smartphone deployment
+"""
+
+from fastapi import FastAPI, HTTPException, BackgroundTasks
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from typing import Optional, List, Dict
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import uvicorn
+import logging
+import time
+from datetime import datetime
+import os
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Request/Response models
+class ChatRequest(BaseModel):
+    message: str
+    mode: str = "Assistant"  # "Therapist" or "Assistant"
+    max_tokens: int = 200
+    temperature: float = 0.7
+
+class ChatResponse(BaseModel):
+    response: str
+    mode: str
+    inference_time_ms: float
+    timestamp: str
+
+class ModelInfo(BaseModel):
+    variant: str
+    memory_usage: str
+    inference_speed: str
+    status: str
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="AuraMind API",
+    description="Smartphone-optimized dual-mode AI companion API",
+    version="1.0.0"
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Configure appropriately for production
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Global model variables
+tokenizer = None
+model = None
+model_variant = None
+
+def load_model(variant: str = "270m"):
+    """Load AuraMind model"""
+    global tokenizer, model, model_variant
+
+    try:
+        logger.info(f"Loading AuraMind {variant}...")
+
+        model_name = "zail-ai/Auramind"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            low_cpu_mem_usage=True
+        )
+
+        model.eval()
+        model_variant = variant
+
+        logger.info(f"✅ AuraMind {variant} loaded successfully")
+
+    except Exception as e:
+        logger.error(f"Failed to load model: {e}")
+        raise
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize model on startup"""
+    variant = os.getenv("MODEL_VARIANT", "270m")
+    load_model(variant)
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    return {
+        "status": "healthy",
+        "model_loaded": model is not None,
+        "variant": model_variant,
+        "timestamp": datetime.now().isoformat()
+    }
+
+@app.get("/model/info", response_model=ModelInfo)
+async def get_model_info():
+    """Get model information"""
+    if model is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    variant_configs = {
+        "270m": {"memory": "~680MB RAM", "speed": "100-300ms"},
+        "180m": {"memory": "~450MB RAM", "speed": "80-200ms"},
+        "90m": {"memory": "~225MB RAM", "speed": "50-150ms"}
+    }
+
+    config = variant_configs.get(model_variant, {"memory": "Unknown", "speed": "Unknown"})
+
+    return ModelInfo(
+        variant=model_variant,
+        memory_usage=config["memory"],
+        inference_speed=config["speed"],
+        status="ready"
+    )
+
+@app.post("/chat", response_model=ChatResponse)
+async def chat(request: ChatRequest):
+    """Generate chat response"""
+    if model is None or tokenizer is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+    if request.mode not in ["Therapist", "Assistant"]:
+        raise HTTPException(status_code=400, detail="Mode must be 'Therapist' or 'Assistant'")
+
+    try:
+        start_time = time.time()
+
+        # Format prompt
+        prompt = f"<|start_of_turn|>user\n[{request.mode} Mode] {request.message}<|end_of_turn|>\n<|start_of_turn|>model\n"
+
+        # Tokenize and keep inputs on the same device as the model
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        ).to(model.device)
+
+        # Generate
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=request.max_tokens,
+                temperature=request.temperature,
+                do_sample=True,
+                top_p=0.9,
+                repetition_penalty=1.1,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        # Decode only the newly generated tokens; splitting the full text on the
+        # turn marker breaks if skip_special_tokens strips it from the output
+        response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
+
+        inference_time = (time.time() - start_time) * 1000
+
+        return ChatResponse(
+            response=response,
+            mode=request.mode,
+            inference_time_ms=round(inference_time, 2),
+            timestamp=datetime.now().isoformat()
+        )
+
+    except Exception as e:
+        logger.error(f"Error generating response: {e}")
+        raise HTTPException(status_code=500, detail="Failed to generate response")
+
+@app.post("/chat/batch")
+async def chat_batch(requests: List[ChatRequest]):
+    """Process multiple chat requests"""
+    if len(requests) > 10:  # Limit batch size
+        raise HTTPException(status_code=400, detail="Batch size limited to 10 requests")
+
+    responses = []
+    for req in requests:
+        response = await chat(req)
+        responses.append(response)
+
+    return {"responses": responses}
+
+if __name__ == "__main__":
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=int(os.getenv("PORT", 8000)),
+        workers=1  # Single worker for model consistency
+    )
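
With the server running (python deployment/api_server.py; MODEL_VARIANT and PORT are read from the environment, defaulting to "270m" and 8000), the endpoints above can be exercised with a short client script. A minimal sketch against a local instance, not part of the commit:

import requests

BASE = "http://localhost:8000"  # assumes the default PORT above

# Health check: reports whether the model finished loading on startup
print(requests.get(f"{BASE}/health").json())

# Single chat request; mode must be "Therapist" or "Assistant"
r = requests.post(f"{BASE}/chat", json={
    "message": "Plan a short evening routine for me.",
    "mode": "Assistant",
    "max_tokens": 120,
    "temperature": 0.7,
})
body = r.json()
print(body["response"], f'({body["inference_time_ms"]} ms)')

# Batch endpoint: a JSON array of requests, capped at 10 by the server
batch = [
    {"message": "I feel overwhelmed today.", "mode": "Therapist"},
    {"message": "Draft a grocery list.", "mode": "Assistant"},
]
print(requests.post(f"{BASE}/chat/batch", json=batch).json())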