import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Tuple
import re


class PolarisModel:
    """
    POLARIS-4B-Preview: a post-training recipe for scaling RL on advanced reasoning models.
    Specialized for mathematical reasoning and problem-solving tasks.
    """

    def __init__(self):
        self.model_name = "POLARIS-Project/Polaris-4B-Preview"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.load_model()

    def load_model(self):
        """Load the POLARIS model with settings optimized for reasoning tasks."""
        try:
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                padding_side="left"
            )

            # Set pad token if one does not exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Configure for efficient inference
            if self.device == "cuda":
                # Use 4-bit NF4 quantization so the 4B model fits on smaller GPUs
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    trust_remote_code=True,
                    torch_dtype=torch.float16
                )
            else:
                # CPU inference
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    device_map="cpu",
                    trust_remote_code=True,
                    torch_dtype=torch.float32
                )

            print(f"✅ POLARIS-4B-Preview loaded successfully on {self.device}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fall back to a smaller model if POLARIS fails to load
            try:
                print("🔄 Attempting to load fallback model...")
                self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
                self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("✅ Fallback model loaded")
            except Exception as fallback_error:
                print(f"❌ Fallback model also failed: {fallback_error}")
                self.model = None
                self.tokenizer = None

    def generate_reasoning_response(
        self,
        prompt: str,
        max_length: int = 2048,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        num_return_sequences: int = 1
    ) -> str:
        """Generate a response with chain-of-thought reasoning optimized for POLARIS."""
        if not self.model or not self.tokenizer:
            return "❌ Model not loaded. Please check the model loading status."
        try:
            # Format the prompt to elicit mathematical reasoning
            formatted_prompt = self.format_reasoning_prompt(prompt)

            # Tokenize input; keep the attention mask so generate() handles padding correctly
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024
            ).to(self.device)
            prompt_length = inputs["input_ids"].shape[1]

            # Generate with parameters tuned for step-by-step reasoning.
            # Clamp max_new_tokens so a long prompt cannot drive it to zero or below.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max(max_length - prompt_length, 64),
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    num_return_sequences=num_return_sequences,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )

            # Decode only the newly generated tokens, so the prompt is never echoed back
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            return self.format_response(response)

        except Exception as e:
            return f"❌ Error generating response: {str(e)}"

    def format_reasoning_prompt(self, user_input: str) -> str:
        """Format the prompt to encourage step-by-step reasoning."""
        if any(keyword in user_input.lower() for keyword in ['solve', 'calculate', 'find', 'prove', 'show']):
            return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model specialized in mathematical problem-solving. Approach each problem step-by-step with clear reasoning. Show your work and explain each step.<|im_end|>
<|im_start|>user
{user_input}

Please solve this step-by-step:<|im_end|>
<|im_start|>assistant
I'll solve this step-by-step:

"""
        else:
            return f"""<|im_start|>system
You are POLARIS, an advanced reasoning model. Provide thoughtful, well-reasoned responses.<|im_end|>
<|im_start|>user
{user_input}<|im_end|>
<|im_start|>assistant
"""

    def format_response(self, response: str) -> str:
        """Clean and format the model response."""
        # Remove any leftover chat-template artifacts
        response = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', response, flags=re.DOTALL)
        response = response.strip()

        # Flag responses that contain LaTeX math markup
        if '$$' in response or '\\(' in response:
            response = "🧮 **Mathematical Solution:**\n\n" + response

        return response


# Initialize the model
polaris_model = PolarisModel()


def chat_with_polaris(
    message: str,
    history: List[Tuple[str, str]] = None,
    temperature: float = 0.7,
    max_length: int = 1024
) -> Tuple[str, List[Tuple[str, str]]]:
    """Main chat function for the Gradio interface."""
    if history is None:
        history = []

    if not message.strip():
        return "", history

    # Generate response
    response = polaris_model.generate_reasoning_response(
        message,
        temperature=temperature,
        max_length=max_length
    )

    # Update history
    history.append((message, response))
    return "", history


def clear_chat():
    """Clear the chat history and the input box."""
    return [], ""


def get_model_info():
    """Return information about the POLARIS model."""
    return """
## 🌠 POLARIS-4B-Preview

**POLARIS** is a post-training recipe for scaling reinforcement learning on advanced reasoning models.

### Key Features:
- **4B parameters** optimized for mathematical reasoning
- **Advanced chain-of-thought** reasoning capabilities
- **Superior performance** on mathematical benchmarks (AIME, AMC, Olympiad)
- **Outperforms larger models** through specialized RL training

### Benchmark Results:
- **AIME24**: 81.2% (avg@32)
- **AIME25**: 79.4% (avg@32)
- **AMC23**: 94.8% (avg@8)
- **Minerva Math**: 44.0% (avg@4)
- **Olympiad Bench**: 69.1% (avg@4)

### Best Use Cases:
- Mathematical problem solving
- Step-by-step reasoning tasks
- Competition math problems
- Logical reasoning challenges

Try asking mathematical questions or reasoning problems!
""" # Create example problems for the interface example_problems = [ "Solve: If x + y = 10 and x - y = 4, find the values of x and y.", "Find the derivative of f(x) = 3x² + 2x - 1", "Prove that the square root of 2 is irrational.", "A rectangle has a perimeter of 24 cm and an area of 35 cm². Find its dimensions.", "What is the sum of the first 100 positive integers?", "Solve the quadratic equation: 2x² - 7x + 3 = 0" ] # Create the Gradio interface with gr.Blocks( title="🌠 POLARIS-4B-Preview - Advanced Reasoning Model", theme=gr.themes.Soft(), css=""" .gradio-container { max-width: 1200px !important; } .chat-message { font-size: 16px !important; } """ ) as demo: gr.Markdown(""" # 🌠 POLARIS-4B-Preview ## Advanced Reasoning Model for Mathematical Problem Solving POLARIS uses reinforcement learning to achieve state-of-the-art performance on mathematical reasoning tasks. Try asking mathematical questions, logic problems, or step-by-step reasoning challenges! """) with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( height=600, show_label=False, container=True, bubble_full_width=False ) with gr.Row(): msg = gr.Textbox( placeholder="Enter your mathematical problem or reasoning question...", show_label=False, scale=5, container=False ) submit_btn = gr.Button("🚀 Solve", scale=1, variant="primary") clear_btn = gr.Button("🗑️ Clear", scale=1) with gr.Column(scale=1): gr.Markdown("### ⚙️ Settings") temperature = gr.Slider( minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature", info="Higher = more creative" ) max_length = gr.Slider( minimum=256, maximum=2048, value=1024, step=128, label="Max Response Length", info="Maximum tokens to generate" ) gr.Markdown("### 📚 Example Problems") for i, example in enumerate(example_problems): gr.Button( f"Example {i+1}", size="sm" ).click( lambda x=example: x, outputs=[msg] ) with gr.Row(): with gr.Column(): gr.Markdown("### 📊 Model Information") model_info = gr.Markdown(get_model_info()) # Event handlers submit_btn.click( chat_with_polaris, inputs=[msg, chatbot, temperature, max_length], outputs=[msg, chatbot] ) msg.submit( chat_with_polaris, inputs=[msg, chatbot, temperature, max_length], outputs=[msg, chatbot] ) clear_btn.click( clear_chat, outputs=[chatbot] ) # Launch configuration if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=True, show_error=True )