---
base_model: Qwen/Qwen2.5-0.5B-Instruct
library_name: transformers
tags:
- qwen2
- math
- reasoning
- ppo
- rlhf
- gsm8k
- merged-model
- verl
language:
- en
pipeline_tag: text-generation
---

# Qwen2.5-0.5B-Instruct-GSM8K-PPO-Merged

Possibly state-of-the-art on GSM8K among math models with fewer than 0.5 billion parameters.

## 📊 Model Overview

This model is a **merged version** of several high-performing checkpoints obtained by fine-tuning **Qwen2.5-0.5B-Instruct** with **PPO (Proximal Policy Optimization)** on the **GSM8K** mathematical reasoning dataset.

### 🎯 Key Features

- **Base Model**: Qwen/Qwen2.5-0.5B-Instruct (494M parameters)
- **Training Algorithm**: PPO via the [VERL](https://github.com/volcengine/verl) framework
- **Specialization**: Mathematical reasoning and problem-solving
- **Model Merging**: Weighted average of the 3 best-performing checkpoints, merged with [mergekit](https://github.com/cg123/mergekit)

## 📈 Performance

| Dataset | Score | Improvement |
|---------|-------|-------------|
| **GSM8K** | **58.91%** | **+9.31 percentage points** over Qwen2.5-0.5B-Instruct |

> This is a substantial gain in mathematical reasoning for a 0.5B-parameter model.

## 🔧 Usage

### Quick Start

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model and tokenizer
model_name = "alphadl/ppo-gsm8k-0.5b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Example: Mathematical reasoning
prompt = """Solve this step by step:
Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Let's think step by step and output the final answer after "####"."""

# Format the prompt with the chat template (the base is an Instruct model)
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Put the inputs on the model's device (needed with device_map="auto")
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding, so no temperature is needed
        pad_token_id=tokenizer.eos_token_id
    )

# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(response)
```

### Expected Output Format

The model is trained to produce step-by-step mathematical reasoning followed by the final answer in the format:

```
#### [numerical_answer]
```

## 🛠️ Training Details

### Training Framework

- **Framework**: [VERL (Volcano Engine Reinforcement Learning)](https://github.com/volcengine/verl)
- **Algorithm**: PPO (Proximal Policy Optimization)
- **Data Source**: GSM8K mathematical reasoning dataset

### Model Merging Strategy

This model was created by merging the 3 best-performing checkpoints using **linear interpolation**, i.e. a weighted average of their parameters:

| Checkpoint | GSM8K Score | Weight |
|------------|-------------|--------|
| global_step_5000 | 58.3% | 33% |
| global_step_6000 | 58.7% | 33% |
| global_step_7500 | 58.9% | 34% |

**Result**: The merged model scores **58.91%**, on par with or slightly above the best individual checkpoint (58.9% at global_step_7500).
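Linear merging amounts to a per-parameter weighted average across checkpoints. The sketch below illustrates that idea in plain Python; it is not mergekit's implementation. The function name `linear_merge` and its `normalize` flag are hypothetical, and dividing by the sum of the weights is assumed to mirror the default behavior of mergekit's `linear` method.

```python
from typing import Any, Mapping, Sequence


def linear_merge(
    state_dicts: Sequence[Mapping[str, Any]],
    weights: Sequence[float],
    normalize: bool = True,
) -> dict:
    """Hypothetical helper: weighted average of same-named parameters.

    Values can be floats, NumPy arrays, or torch tensors: anything that
    supports multiplication by a float and element-wise addition.
    """
    if not state_dicts or len(state_dicts) != len(weights):
        raise ValueError("need exactly one weight per checkpoint")
    keys = state_dicts[0].keys()
    # All checkpoints come from the same architecture, so names must match.
    if any(sd.keys() != keys for sd in state_dicts[1:]):
        raise ValueError("checkpoints must have identical parameter names")
    # Assumption: rescale weights so they sum to 1 (e.g. 0.33/0.33/0.34).
    total = sum(weights) if normalize else 1.0
    merged = {}
    for name in keys:
        acc = state_dicts[0][name] * (weights[0] / total)
        for sd, w in zip(state_dicts[1:], weights[1:]):
            acc = acc + sd[name] * (w / total)
        merged[name] = acc
    return merged


# Toy "checkpoints" with a single scalar parameter (illustrative values only).
toy_ckpts = [{"w": 1.0}, {"w": 2.0}, {"w": 4.0}]
print(linear_merge(toy_ckpts, [0.33, 0.33, 0.34]))
```

With PyTorch, the same function could be applied to the `state_dict()` of each checkpoint loaded via `AutoModelForCausalLM.from_pretrained`, and the merged result loaded back into one of the models with `load_state_dict`.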
### Training Configuration

- **Base Model**: Qwen/Qwen2.5-0.5B-Instruct
- **Training Steps**: 7,500+
- **Validation Frequency**: Every 1,000 steps
- **Optimization**: AdamW with learning rate scheduling

## 🎯 Use Cases

This model is best suited for:

- **Mathematical Problem Solving**: Arithmetic, algebra, basic geometry
- **Step-by-Step Reasoning**: Breaking down multi-step word problems
- **Educational Applications**: Math tutoring and explanation
- **Computational Tasks**: Basic calculations with reasoning

## ⚠️ Limitations

- **Model Size**: As a 0.5B-parameter model, it may struggle with complex mathematical concepts
- **Domain Specificity**: Optimized for GSM8K-style problems; it may perform worse in other domains
- **Context Length**: Limited by the base model's context window (32K tokens)

## 📄 License

This model inherits the license of the base Qwen2.5-0.5B-Instruct model. Please refer to the [original model card](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) for licensing details.

## 🙏 Acknowledgments

- **Base Model**: [Qwen Team](https://github.com/QwenLM/Qwen2.5) for Qwen2.5-0.5B-Instruct
- **Training Framework**: [VERL Team](https://github.com/volcengine/verl) for the PPO implementation
- **Model Merging**: [mergekit](https://github.com/cg123/mergekit) for the checkpoint-averaging tools
- **Dataset**: [GSM8K](https://github.com/openai/grade-school-math) for the mathematical reasoning data

## 📬 Contact

For questions or issues, please open an issue in the repository or contact the model author.

---

*This model was trained with the VERL framework and merged with mergekit to improve mathematical reasoning performance.*