import re
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


class PolarisModel:
    """
    POLARIS-4B-Preview: a post-training recipe for scaling RL on advanced
    reasoning models, specialized for mathematical reasoning and
    problem-solving tasks.
    """

    def __init__(self):
        self.model_name = "POLARIS-Project/Polaris-4B-Preview"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.load_model()

    def load_model(self):
        """Load the POLARIS model with settings suited to reasoning tasks."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                padding_side="left",
            )

            # Some checkpoints ship without an explicit pad token; reuse EOS.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            if self.device == "cuda":
                # 4-bit NF4 quantization with double quantization keeps the
                # 4B-parameter weights small enough for a single consumer GPU.
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                )
            else:
                # CPU fallback: full float32 weights, no quantization.
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    device_map="cpu",
                    trust_remote_code=True,
                    torch_dtype=torch.float32,
                )

            print(f"✅ POLARIS-4B-Preview loaded successfully on {self.device}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fall back to a small generic chat model so the UI stays usable.
            try:
                print("🔄 Attempting to load fallback model...")
                self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
                self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("✅ Fallback model loaded")
            except Exception as fallback_error:
                print(f"❌ Fallback model also failed: {fallback_error}")
                self.model = None
                self.tokenizer = None
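
    # Optional diagnostic helper, a sketch not present in the original app:
    # transformers models expose get_memory_footprint(), which is handy for
    # confirming that the 4-bit quantized weights actually fit the target GPU.
    def report_memory_footprint(self) -> str:
        if self.model is None:
            return "Model not loaded"
        gb = self.model.get_memory_footprint() / 1024**3
        return f"Approximate model memory footprint: {gb:.2f} GB"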

    def generate_reasoning_response(
        self,
        prompt: str,
        max_length: int = 2048,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        num_return_sequences: int = 1,
    ) -> str:
        """
        Generate a response with chain-of-thought reasoning, using sampling
        settings tuned for POLARIS.
        """
        if not self.model or not self.tokenizer:
            return "❌ Model not loaded. Please check the model loading status."

        try:
            formatted_prompt = self.format_reasoning_prompt(prompt)

            # Tokenize with an attention mask; generate() warns (and can
            # misbehave with left padding) when the mask is omitted.
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            ).to(self.device)

            prompt_length = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    # Guard against a prompt longer than the token budget,
                    # which would otherwise make max_new_tokens negative.
                    max_new_tokens=max(max_length - prompt_length, 64),
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    num_return_sequences=num_return_sequences,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )

            # Decode only the newly generated tokens; slicing the decoded
            # string by len(formatted_prompt) is fragile because detokenized
            # text need not match the prompt character-for-character.
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

            return self.format_response(response)

        except Exception as e:
            return f"❌ Error generating response: {str(e)}"
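
    # A minimal streaming variant, a sketch not present in the original app:
    # transformers' TextIteratorStreamer yields decoded text chunks while
    # generate() runs in a background thread, so a Gradio callback can
    # surface partial output as it is produced.
    def generate_streaming_response(self, prompt: str, max_length: int = 2048):
        from threading import Thread

        from transformers import TextIteratorStreamer

        formatted_prompt = self.format_reasoning_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max(max_length - inputs["input_ids"].shape[1], 64),
        )
        Thread(target=self.model.generate, kwargs=generation_kwargs).start()

        partial = ""
        for chunk in streamer:
            partial += chunk
            yield self.format_response(partial)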

    def format_reasoning_prompt(self, user_input: str) -> str:
        """Format the prompt to encourage step-by-step reasoning."""
        if any(keyword in user_input.lower() for keyword in ['solve', 'calculate', 'find', 'prove', 'show']):
            return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model specialized in mathematical problem-solving.
Approach each problem step-by-step with clear reasoning. Show your work and explain each step.
<|im_end|>
<|im_start|>user
{user_input}

Please solve this step-by-step:
<|im_end|>
<|im_start|>assistant
I'll solve this step-by-step:

"""
        else:
            return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model. Provide thoughtful, well-reasoned responses.
<|im_end|>
<|im_start|>user
{user_input}
<|im_end|>
<|im_start|>assistant
"""
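
    # Alternative sketch, an assumption rather than the project's documented
    # interface: if the checkpoint ships a chat template, the tokenizer's
    # built-in apply_chat_template() builds the same <|im_start|>-style
    # prompt without hand-maintained f-strings.
    def format_prompt_with_template(self, user_input: str) -> str:
        messages = [
            {"role": "system", "content": "You are POLARIS, an advanced reasoning model. Reason step-by-step."},
            {"role": "user", "content": user_input},
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )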

    def format_response(self, response: str) -> str:
        """Clean and format the model response."""
        # Strip any chat-template markers the model may have echoed back.
        response = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', response, flags=re.DOTALL)
        response = response.strip()

        # Flag responses containing LaTeX math delimiters.
        if '$$' in response or '\\(' in response:
            response = "🧮 **Mathematical Solution:**\n\n" + response

        return response


# Load the model once at startup so all Gradio callbacks share the instance.
polaris_model = PolarisModel()


def chat_with_polaris(
    message: str,
    history: List[Tuple[str, str]] = None,
    temperature: float = 0.7,
    max_length: int = 1024,
) -> Tuple[str, List[Tuple[str, str]]]:
    """Main chat function for the Gradio interface."""
    if history is None:
        history = []

    if not message.strip():
        return "", history

    response = polaris_model.generate_reasoning_response(
        message,
        temperature=temperature,
        max_length=max_length,
    )

    history.append((message, response))

    # Return an empty string to clear the input box, plus the updated history.
    return "", history
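

# Hypothetical streaming counterpart (assumes the generate_streaming_response
# sketch above). Gradio treats generator callbacks as streaming functions and
# re-renders the chatbot on every yield.
def chat_with_polaris_streaming(
    message: str,
    history: List[Tuple[str, str]] = None,
    max_length: int = 1024,
):
    history = history or []
    if not message.strip():
        yield "", history
        return

    history.append((message, ""))
    for partial in polaris_model.generate_streaming_response(message, max_length=max_length):
        history[-1] = (message, partial)
        yield "", history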


def clear_chat():
    """Clear the chat history and the input box."""
    return [], ""


def get_model_info():
    """Return information about the POLARIS model."""
    return """
## 🌠 POLARIS-4B-Preview

**POLARIS** is a post-training recipe for scaling reinforcement learning on advanced reasoning models.

### Key Features:
- **4B parameters** optimized for mathematical reasoning
- **Advanced chain-of-thought** reasoning capabilities
- **Superior performance** on mathematical benchmarks (AIME, AMC, Olympiad)
- **Outperforms larger models** through specialized RL training

### Benchmark Results:
- **AIME24**: 81.2% (avg@32)
- **AIME25**: 79.4% (avg@32)
- **AMC23**: 94.8% (avg@8)
- **Minerva Math**: 44.0% (avg@4)
- **Olympiad Bench**: 69.1% (avg@4)

### Best Use Cases:
- Mathematical problem solving
- Step-by-step reasoning tasks
- Competition math problems
- Logical reasoning challenges

Try asking mathematical questions or reasoning problems!
"""


example_problems = [
    "Solve: If x + y = 10 and x - y = 4, find the values of x and y.",
    "Find the derivative of f(x) = 3x² + 2x - 1",
    "Prove that the square root of 2 is irrational.",
    "A rectangle has a perimeter of 24 cm and an area of 35 cm². Find its dimensions.",
    "What is the sum of the first 100 positive integers?",
    "Solve the quadratic equation: 2x² - 7x + 3 = 0",
]


with gr.Blocks(
    title="🌠 POLARIS-4B-Preview - Advanced Reasoning Model",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .chat-message {
        font-size: 16px !important;
    }
    """,
) as demo:

    gr.Markdown("""
    # 🌠 POLARIS-4B-Preview
    ## Advanced Reasoning Model for Mathematical Problem Solving

    POLARIS uses reinforcement learning to achieve state-of-the-art performance on mathematical reasoning tasks.
    Try asking mathematical questions, logic problems, or step-by-step reasoning challenges!
    """)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=600,
                show_label=False,
                container=True,
                bubble_full_width=False,
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Enter your mathematical problem or reasoning question...",
                    show_label=False,
                    scale=5,
                    container=False,
                )
                submit_btn = gr.Button("🚀 Solve", scale=1, variant="primary")
                clear_btn = gr.Button("🗑️ Clear", scale=1)

        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")

            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Higher = more creative",
            )

            max_length = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=128,
                label="Max Response Length",
                info="Maximum tokens to generate",
            )

            gr.Markdown("### 📚 Example Problems")

            # Bind each example via a lambda default argument; a plain closure
            # would late-bind and leave every button sending the last item.
            for i, example in enumerate(example_problems):
                gr.Button(
                    f"Example {i+1}",
                    size="sm",
                ).click(
                    lambda x=example: x,
                    outputs=[msg],
                )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📊 Model Information")
            gr.Markdown(get_model_info())

    submit_btn.click(
        chat_with_polaris,
        inputs=[msg, chatbot, temperature, max_length],
        outputs=[msg, chatbot],
    )

    msg.submit(
        chat_with_polaris,
        inputs=[msg, chatbot, temperature, max_length],
        outputs=[msg, chatbot],
    )

    # clear_chat returns two values: an empty history for the chatbot and an
    # empty string for the input textbox, so both components are wired here.
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg],
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )