ReallyFloppyPenguin committed
Commit e7f0eb4 · verified · 1 Parent(s): 899488b

Create app.py

Files changed (1):
  1. app.py +341 -0
app.py ADDED
@@ -0,0 +1,341 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os
from typing import List, Tuple
import re


class PolarisModel:
    """
    POLARIS-4B-Preview: a post-training recipe for scaling RL on advanced reasoning models.
    Specialized for mathematical reasoning and problem-solving tasks.
    """

    def __init__(self):
        self.model_name = "POLARIS-Project/Polaris-4B-Preview"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.load_model()

    def load_model(self):
        """Load the POLARIS model with optimized settings for reasoning tasks."""
        try:
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                padding_side="left"
            )

            # Set pad token if it does not exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Configure for efficient inference
            if self.device == "cuda":
                # Use 4-bit quantization on GPU so the 4B model fits in memory
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    trust_remote_code=True,
                    torch_dtype=torch.float16
                )
            else:
                # CPU inference
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    device_map="cpu",
                    trust_remote_code=True,
                    torch_dtype=torch.float32
                )

            print(f"✅ POLARIS-4B-Preview loaded successfully on {self.device}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fall back to a smaller model if POLARIS fails to load
            try:
                print("🔄 Attempting to load fallback model...")
                self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
                self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("✅ Fallback model loaded")
            except Exception as fallback_error:
                print(f"❌ Fallback model also failed: {fallback_error}")
                self.model = None
                self.tokenizer = None

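    # Small diagnostic helper (an addition, not in the original file): reports the
    # loaded model's approximate size via transformers' get_memory_footprint(),
    # handy for confirming that 4-bit quantization took effect.
    def report_memory_footprint(self) -> str:
        if self.model is None:
            return "Model not loaded"
        gb = self.model.get_memory_footprint() / 1e9
        return f"~{gb:.2f} GB on {self.device}"
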
    def generate_reasoning_response(
        self,
        prompt: str,
        max_length: int = 2048,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        num_return_sequences: int = 1
    ) -> str:
        """
        Generate a response with chain-of-thought reasoning, optimized for POLARIS.
        """
        if not self.model or not self.tokenizer:
            return "❌ Model not loaded. Please check the model loading status."

        try:
            # Format the prompt for mathematical reasoning
            formatted_prompt = self.format_reasoning_prompt(prompt)

            # Tokenize input
            inputs = self.tokenizer.encode(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024
            ).to(self.model.device)

            # Generate with parameters tuned for reasoning; keep the new-token
            # budget positive even when the prompt is long
            # (length_penalty / early_stopping omitted: they only affect beam search)
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_new_tokens=max(max_length - inputs.shape[1], 64),
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    num_return_sequences=num_return_sequences,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )

            # Decode only the newly generated tokens; slicing the decoded string by
            # prompt length is unreliable once special tokens are stripped
            response = self.tokenizer.decode(
                outputs[0][inputs.shape[1]:], skip_special_tokens=True
            ).strip()

            return self.format_response(response)

        except Exception as e:
            return f"❌ Error generating response: {str(e)}"

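    # Hedged sketch (an addition): a token-streaming variant of the method above,
    # built on transformers' TextIteratorStreamer. The wiring is illustrative; the
    # original Space returns only complete responses.
    def generate_streaming_response(self, prompt: str, max_new_tokens: int = 512):
        """Yield the response incrementally instead of waiting for a full decode."""
        from threading import Thread
        from transformers import TextIteratorStreamer

        if not self.model or not self.tokenizer:
            yield "❌ Model not loaded."
            return

        inputs = self.tokenizer(
            self.format_reasoning_prompt(prompt), return_tensors="pt"
        ).to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        # generate() blocks, so run it in a worker thread while we consume tokens
        thread = Thread(
            target=self.model.generate,
            kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)
        )
        thread.start()
        partial = ""
        for new_text in streamer:
            partial += new_text
            yield partial
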
    def format_reasoning_prompt(self, user_input: str) -> str:
        """Format the prompt to encourage step-by-step reasoning."""
        if any(keyword in user_input.lower() for keyword in ['solve', 'calculate', 'find', 'prove', 'show']):
            return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model specialized in mathematical problem-solving.
Approach each problem step-by-step with clear reasoning. Show your work and explain each step.
<|im_end|>
<|im_start|>user
{user_input}

Please solve this step-by-step:
<|im_end|>
<|im_start|>assistant
I'll solve this step-by-step:

"""
        else:
            return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model. Provide thoughtful, well-reasoned responses.
<|im_end|>
<|im_start|>user
{user_input}
<|im_end|>
<|im_start|>assistant
"""

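    # Alternative prompt construction (a sketch, assuming the checkpoint ships a
    # chat template): tokenizer.apply_chat_template would build the same
    # <|im_start|> framing without hand-maintained markers.
    def format_with_chat_template(self, user_input: str) -> str:
        messages = [
            {"role": "system", "content": "You are POLARIS, an advanced reasoning model."},
            {"role": "user", "content": user_input},
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
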
    def format_response(self, response: str) -> str:
        """Clean and format the model response."""
        # Remove any leftover chat-template artifacts
        response = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', response, flags=re.DOTALL)
        response = response.strip()

        # Flag responses that contain LaTeX-style math markup
        if '$$' in response or '\\(' in response:
            response = "🧮 **Mathematical Solution:**\n\n" + response

        return response


# Initialize the model
polaris_model = PolarisModel()

def chat_with_polaris(
    message: str,
    history: List[Tuple[str, str]] = None,
    temperature: float = 0.7,
    max_length: int = 1024
) -> Tuple[str, List[Tuple[str, str]]]:
    """Main chat function for the Gradio interface."""
    if history is None:
        history = []

    if not message.strip():
        return "", history

    # Generate response
    response = polaris_model.generate_reasoning_response(
        message,
        temperature=temperature,
        max_length=max_length
    )

    # Update history
    history.append((message, response))

    return "", history

def clear_chat():
    """Clear the chat history and the input box."""
    return [], ""

def get_model_info():
    """Return information about the POLARIS model."""
    return """
## 🌠 POLARIS-4B-Preview

**POLARIS** is a post-training recipe for scaling reinforcement learning on advanced reasoning models.

### Key Features:
- **4B parameters** optimized for mathematical reasoning
- **Advanced chain-of-thought** reasoning capabilities
- **Superior performance** on mathematical benchmarks (AIME, AMC, Olympiad)
- **Outperforms larger models** through specialized RL training

### Benchmark Results:
- **AIME24**: 81.2% (avg@32)
- **AIME25**: 79.4% (avg@32)
- **AMC23**: 94.8% (avg@8)
- **Minerva Math**: 44.0% (avg@4)
- **Olympiad Bench**: 69.1% (avg@4)

### Best Use Cases:
- Mathematical problem solving
- Step-by-step reasoning tasks
- Competition math problems
- Logical reasoning challenges

Try asking mathematical questions or reasoning problems!
"""

# Example problems for the interface
example_problems = [
    "Solve: If x + y = 10 and x - y = 4, find the values of x and y.",
    "Find the derivative of f(x) = 3x² + 2x - 1",
    "Prove that the square root of 2 is irrational.",
    "A rectangle has a perimeter of 24 cm and an area of 35 cm². Find its dimensions.",
    "What is the sum of the first 100 positive integers?",
    "Solve the quadratic equation: 2x² - 7x + 3 = 0"
]

# Create the Gradio interface
with gr.Blocks(
    title="🌠 POLARIS-4B-Preview - Advanced Reasoning Model",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .chat-message {
        font-size: 16px !important;
    }
    """
) as demo:

    gr.Markdown("""
    # 🌠 POLARIS-4B-Preview
    ## Advanced Reasoning Model for Mathematical Problem Solving

    POLARIS uses reinforcement learning to achieve state-of-the-art performance on mathematical reasoning tasks.
    Try asking mathematical questions, logic problems, or step-by-step reasoning challenges!
    """)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=600,
                show_label=False,
                container=True,
                bubble_full_width=False
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Enter your mathematical problem or reasoning question...",
                    show_label=False,
                    scale=5,
                    container=False
                )
                submit_btn = gr.Button("🚀 Solve", scale=1, variant="primary")
                clear_btn = gr.Button("🗑️ Clear", scale=1)

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")

            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Higher = more creative"
            )

            max_length = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=128,
                label="Max Response Length",
                info="Maximum tokens to generate"
            )

            gr.Markdown("### 📚 Example Problems")

            for i, example in enumerate(example_problems):
                gr.Button(
                    f"Example {i+1}",
                    size="sm"
                ).click(
                    # Bind the current example as a default argument so each
                    # button fills the textbox with its own problem
                    lambda x=example: x,
                    outputs=[msg]
                )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📊 Model Information")
            model_info = gr.Markdown(get_model_info())

    # Event handlers
    submit_btn.click(
        chat_with_polaris,
        inputs=[msg, chatbot, temperature, max_length],
        outputs=[msg, chatbot]
    )

    msg.submit(
        chat_with_polaris,
        inputs=[msg, chatbot, temperature, max_length],
        outputs=[msg, chatbot]
    )

    # clear_chat returns (history, textbox), so wire both outputs
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg]
    )

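# A minimal smoke test (an addition, not part of the original file). Call it from
# a REPL to exercise the full pipeline without launching the UI; it assumes the
# model loaded successfully.
def _smoke_test(question: str = "What is the sum of the first 100 positive integers?") -> str:
    _, history = chat_with_polaris(question, [])
    return history[-1][1]
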
# Launch configuration
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,  # ignored on Hugging Face Spaces; useful for local runs
        show_error=True
    )