# Polaris-4B-Chat / app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Optional, Tuple
import re
class PolarisModel:
"""
POLARIS-4B-Preview: A Post-training recipe for scaling RL on Advanced Reasoning models
Specialized for mathematical reasoning and problem-solving tasks.
"""
def __init__(self):
self.model_name = "POLARIS-Project/Polaris-4B-Preview"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = None
self.tokenizer = None
self.load_model()
def load_model(self):
"""Load the POLARIS model with optimized settings for reasoning tasks"""
try:
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
padding_side="left"
)
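        # Left padding keeps the prompt flush against the generated tokens,
        # which decoder-only models require for correct batched generation.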
            # Set the pad token if one doesn't exist
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Configure for efficient inference
if self.device == "cuda":
# Use 4-bit quantization for GPU to fit 4B model
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
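                # Rough estimate (not measured): NF4 stores weights at ~0.5 byte
                # per parameter, so a 4B-parameter model needs on the order of
                # 2-3 GB of VRAM for weights, plus activation/KV-cache overhead.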
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16
)
else:
# CPU inference
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map="cpu",
trust_remote_code=True,
torch_dtype=torch.float32
)
print(f"✅ POLARIS-4B-Preview loaded successfully on {self.device}")
except Exception as e:
print(f"❌ Error loading model: {e}")
# Fallback to a smaller model if POLARIS fails to load
try:
print("🔄 Attempting to load fallback model...")
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
self.tokenizer.pad_token = self.tokenizer.eos_token
print("✅ Fallback model loaded")
except Exception as fallback_error:
print(f"❌ Fallback model also failed: {fallback_error}")
self.model = None
self.tokenizer = None
def generate_reasoning_response(
self,
prompt: str,
max_length: int = 2048,
temperature: float = 0.7,
top_p: float = 0.9,
do_sample: bool = True,
num_return_sequences: int = 1
) -> str:
"""
Generate response with chain-of-thought reasoning optimized for POLARIS
"""
if not self.model or not self.tokenizer:
return "❌ Model not loaded. Please check the model loading status."
try:
# Format prompt for mathematical reasoning
formatted_prompt = self.format_reasoning_prompt(prompt)
            # Tokenize input; calling the tokenizer directly also returns an
            # attention mask, which generate() needs for reliable results
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024
            ).to(self.device)
            input_length = inputs["input_ids"].shape[1]
            # Generate with sampling parameters tuned for reasoning.
            # length_penalty and early_stopping are omitted: they only apply
            # to beam search and merely trigger warnings under sampling.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max(max_length - input_length, 64),
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    num_return_sequences=num_return_sequences,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )
            # Decode only the newly generated tokens; slicing the decoded text
            # by len(formatted_prompt) is fragile because decoding does not
            # round-trip the prompt string exactly
            response = self.tokenizer.decode(
                outputs[0][input_length:], skip_special_tokens=True
            ).strip()
            return self.format_response(response)
except Exception as e:
return f"❌ Error generating response: {str(e)}"
def format_reasoning_prompt(self, user_input: str) -> str:
"""Format prompt to encourage step-by-step reasoning"""
if any(keyword in user_input.lower() for keyword in ['solve', 'calculate', 'find', 'prove', 'show']):
return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model specialized in mathematical problem-solving.
Approach each problem step-by-step with clear reasoning. Show your work and explain each step.
<|im_end|>
<|im_start|>user
{user_input}
Please solve this step-by-step:
<|im_end|>
<|im_start|>assistant
I'll solve this step-by-step:
"""
else:
return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model. Provide thoughtful, well-reasoned responses.
<|im_end|>
<|im_start|>user
{user_input}
<|im_end|>
<|im_start|>assistant
"""
def format_response(self, response: str) -> str:
"""Clean and format the model response"""
# Remove potential artifacts
response = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', response, flags=re.DOTALL)
response = response.strip()
# Ensure proper formatting for mathematical expressions
if '$$' in response or '\\(' in response:
response = "🧮 **Mathematical Solution:**\n\n" + response
return response
# Initialize the model
polaris_model = PolarisModel()
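# Illustrative usage (the Gradio handlers below call it the same way):
#   polaris_model.generate_reasoning_response("What is 17 * 24?")  # 17 * 24 = 408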
def chat_with_polaris(
message: str,
    history: Optional[List[Tuple[str, str]]] = None,
temperature: float = 0.7,
max_length: int = 1024
) -> Tuple[str, List[Tuple[str, str]]]:
"""Main chat function for Gradio interface"""
if history is None:
history = []
if not message.strip():
return "", history
# Generate response
response = polaris_model.generate_reasoning_response(
message,
temperature=temperature,
max_length=max_length
)
# Update history
history.append((message, response))
return "", history
def clear_chat():
    """Clear the chat history and the input box"""
    return [], ""
def get_model_info():
"""Return information about the POLARIS model"""
return """
## 🌠 POLARIS-4B-Preview
**POLARIS** is a post-training recipe for scaling Reinforcement Learning on Advanced Reasoning models.
### Key Features:
- **4B parameters** optimized for mathematical reasoning
- **Advanced Chain-of-Thought** reasoning capabilities
- **Superior performance** on mathematical benchmarks (AIME, AMC, Olympiad)
- **Outperforms larger models** through specialized RL training
### Benchmark Results:
- **AIME24**: 81.2% (avg@32)
- **AIME25**: 79.4% (avg@32)
- **AMC23**: 94.8% (avg@8)
- **Minerva Math**: 44.0% (avg@4)
- **Olympiad Bench**: 69.1% (avg@4)
### Best Use Cases:
- Mathematical problem solving
- Step-by-step reasoning tasks
- Competition math problems
- Logical reasoning challenges
Try asking mathematical questions or reasoning problems!
"""
# Create example problems for the interface
example_problems = [
"Solve: If x + y = 10 and x - y = 4, find the values of x and y.",
"Find the derivative of f(x) = 3x² + 2x - 1",
"Prove that the square root of 2 is irrational.",
"A rectangle has a perimeter of 24 cm and an area of 35 cm². Find its dimensions.",
"What is the sum of the first 100 positive integers?",
"Solve the quadratic equation: 2x² - 7x + 3 = 0"
]
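# (Sanity check on the first example: adding the two equations gives 2x = 14,
# so x = 7 and y = 3.)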
# Create the Gradio interface
with gr.Blocks(
title="🌠 POLARIS-4B-Preview - Advanced Reasoning Model",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 1200px !important;
}
.chat-message {
font-size: 16px !important;
}
"""
) as demo:
gr.Markdown("""
# 🌠 POLARIS-4B-Preview
## Advanced Reasoning Model for Mathematical Problem Solving
POLARIS uses reinforcement learning to achieve state-of-the-art performance on mathematical reasoning tasks.
Try asking mathematical questions, logic problems, or step-by-step reasoning challenges!
""")
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(
height=600,
show_label=False,
container=True,
bubble_full_width=False
)
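            # (bubble_full_width may be deprecated depending on the installed
            # Gradio version; it can be dropped on newer releases.)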
with gr.Row():
msg = gr.Textbox(
placeholder="Enter your mathematical problem or reasoning question...",
show_label=False,
scale=5,
container=False
)
submit_btn = gr.Button("🚀 Solve", scale=1, variant="primary")
clear_btn = gr.Button("🗑️ Clear", scale=1)
with gr.Column(scale=1):
gr.Markdown("### ⚙️ Settings")
temperature = gr.Slider(
minimum=0.1,
maximum=1.5,
value=0.7,
step=0.1,
label="Temperature",
info="Higher = more creative"
)
max_length = gr.Slider(
minimum=256,
maximum=2048,
value=1024,
step=128,
label="Max Response Length",
info="Maximum tokens to generate"
)
gr.Markdown("### 📚 Example Problems")
for i, example in enumerate(example_problems):
gr.Button(
f"Example {i+1}",
size="sm"
).click(
lambda x=example: x,
outputs=[msg]
)
with gr.Row():
with gr.Column():
gr.Markdown("### 📊 Model Information")
model_info = gr.Markdown(get_model_info())
# Event handlers
submit_btn.click(
chat_with_polaris,
inputs=[msg, chatbot, temperature, max_length],
outputs=[msg, chatbot]
)
msg.submit(
chat_with_polaris,
inputs=[msg, chatbot, temperature, max_length],
outputs=[msg, chatbot]
)
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg]
    )
# Launch configuration
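# NOTE: share=True only creates a public tunnel for local runs; on a
# Hugging Face Space it is ignored (Spaces are already publicly served).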
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_error=True
)