brendanm12345 committed · Commit 41e4cf8 · verified · 1 Parent(s): ba2e06b

Update README.md

  1. README.md +85 -36
README.md CHANGED
@@ -1,61 +1,110 @@
  ---
  license: mit
  ---
- # Weaver Distilled - All Datasets (gte-Qwen2-1.5B-instruct)

- This is a distilled cross-encoder model based on [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct), trained to predict the correctness of answers across multiple domains: [MATH500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500), [GPQA](https://huggingface.co/datasets/Idavidrein/gpqa), and [MMLU Pro](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro). This general-purpose verifier was trained on Weaver scores aggregated over 35 different verifiers and reward models.

  ## Model Details

- - **Base Model**: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)
  - **Architecture**: Cross-encoder with MLP head (1536 → 768 → 384 → 1)
- - **Max Sequence Length**: 4096
- - **Training Data**: Combined dataset of [MATH500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500), [GPQA](https://huggingface.co/datasets/Idavidrein/gpqa), and [MMLU Pro](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) from 35 different LM Judges and reward models aggregated with Weaver
- - **Training Objective**: Binary classification (correct/incorrect answer prediction)

- ## Usage
 
- ```python
- from custom_crossencoder import CustomCrossEncoder, TrainingConfig

- # Initialize model
- config = TrainingConfig(
-     model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
-     max_length=4096,
-     mlp_hidden_dims=[1536, 768, 384]
- )
- model = CustomCrossEncoder(config)
-
- # Load checkpoint
- model.load_state_dict(torch.load("hazyresearch/Weaver_Distilled_All_Datasets_gte-Qwen2-1.5B-instruct"))
- model.eval()
-
- # Get prediction
- instruction = "Your instruction here"
- answer = "Your answer here"
- encoded = model.tokenizer(
-     text=instruction,
-     text_pair=answer,
      truncation=True,
      max_length=4096,
-     padding="max_length",
      return_tensors="pt"
  )
  with torch.no_grad():
-     prediction = model(encoded["input_ids"], encoded["attention_mask"])
  ```
 
 
- ## Running Evaluation

- TODO: ADD EVALUATION_SIMPLE COMMAND HERE

- ## License

- [Your chosen license]

- ## Citation

- If you use this model in your research, please cite:

- TODO
 
  ---
  license: mit
+ pipeline_tag: text-classification
+ library_name: transformers
+ base_model: Alibaba-NLP/gte-Qwen2-1.5B-instruct
+ tags:
+ - math
+ - science
+ - academic
+ - reasoning
+ - verification
+ - weaver
+ - cross-encoder
+ - multi-domain
+ language:
+ - en
  ---

+ # Weaver Distilled for All Datasets (gte-Qwen2-1.5B-instruct)
+
+ A general-purpose distilled cross-encoder model that captures 98.7% of Weaver's accuracy while reducing verification compute by 99.97%. This model is fine-tuned from gte-Qwen2-1.5B-instruct to predict the correctness of reasoning responses across multiple domains: mathematics (MATH500), science (GPQA), and academic knowledge (MMLU-Pro).

  ## Model Details

+ - **Base Model**: [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) (1.5B parameters)
  - **Architecture**: Cross-encoder with MLP head (1536 → 768 → 384 → 1)
+ - **Max Sequence Length**: 4096 tokens
+ - **Training Data**: Combined MATH500, GPQA, and MMLU-Pro with Weaver scores from 35 LM judges and reward models
+ - **Task**: Binary classification for answer correctness prediction across domains
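The head dimensions listed under Model Details (1536 → 768 → 384 → 1) describe a small classifier on top of the 1.5B-parameter encoder. As a rough sketch, assuming each arrow is a plain fully connected layer with a bias term (the README does not state the layer type, activation, or dropout), the head's size can be counted as follows. The names `head_dims` and `linear_params` are illustrative and not part of the released code.

```python
# Layer widths taken from the README: 1536 -> 768 -> 384 -> 1.
head_dims = [1536, 768, 384, 1]


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one dense layer (assumed layer type)."""
    return n_in * n_out + n_out


# One entry per consecutive pair of widths.
layer_sizes = [linear_params(a, b) for a, b in zip(head_dims, head_dims[1:])]
head_total = sum(layer_sizes)

print(layer_sizes)  # [1180416, 295296, 385]
print(head_total)   # 1476097
```

Under these assumptions the head adds roughly 1.5M parameters, about 0.1% of the base model's 1.5B.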
 
+ ## Performance

+ Multi-domain performance with Llama 3.1 70B generations:
+ <!-- TODO: Update with actual performance numbers -->
+ - **Weaver (Full)**: XX.X% accuracy, high compute cost
+ - **Weaver (Distilled)**: XX.X% accuracy, 99.97% compute reduction
+ - **Majority Voting**: XX.X% accuracy
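To put the two headline percentages in concrete terms, here is a small arithmetic sketch. It uses only the figures quoted in the README (98.7% of Weaver's accuracy retained, 99.97% compute reduction); the variable names are illustrative, and the unfilled XX.X% accuracy placeholders are not needed for it.

```python
# Figures quoted in the README.
compute_reduction = 0.9997   # fraction of verification compute removed
accuracy_retained = 0.987    # fraction of Weaver's accuracy kept

remaining_fraction = 1.0 - compute_reduction   # share of the original compute still spent
cost_ratio = 1.0 / remaining_fraction          # how many times cheaper verification becomes
accuracy_gap = 1.0 - accuracy_retained         # relative accuracy given up

print(f"remaining compute: {remaining_fraction:.2%}")   # remaining compute: 0.03%
print(f"cost ratio: about {cost_ratio:.0f}x")           # cost ratio: about 3333x
print(f"relative accuracy gap: {accuracy_gap:.1%}")     # relative accuracy gap: 1.3%
```

In other words, the stated reduction corresponds to spending about 1/3333 of the full Weaver verification compute while giving up about 1.3% of its accuracy in relative terms.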
 
+ ## Quick Start
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+
+ # Load model and tokenizer
+ model_name = "hazyresearch/Weaver_Distilled_All_Datasets_gte-Qwen2-1.5B-instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
+
+ # Example usage - works across math, science, and academic domains
+ instruction = "What is the derivative of f(x) = 3x² + 2x - 1?"
+ response = "Using the power rule: f'(x) = 6x + 2. The derivative of 3x² is 6x, the derivative of 2x is 2, and the derivative of -1 is 0."
+
+ # Tokenize input pair
+ inputs = tokenizer(
+     instruction,
+     response,
      truncation=True,
      max_length=4096,
+     padding=True,
      return_tensors="pt"
  )
+
+ # Get correctness score
  with torch.no_grad():
+     outputs = model(**inputs)
+     score = torch.sigmoid(outputs.logits).item()
+
+ print(f"Correctness score: {score:.3f}")
+ print(f"Prediction: {'Correct' if score > 0.5 else 'Incorrect'}")
  ```
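The last lines of the Quick Start turn the model's logit into a probability with a sigmoid and threshold it at 0.5. Below is a minimal pure-Python sketch of that step, assuming the model emits a single logit per (instruction, response) pair, as the `.item()` call implies. `correctness_score` and `is_correct` are hypothetical helper names, not functions from the repository.

```python
import math


def correctness_score(logit: float) -> float:
    """Map a raw logit to a score in (0, 1) with the logistic sigmoid.

    Fine for moderate logits; very large negative values would overflow math.exp.
    """
    return 1.0 / (1.0 + math.exp(-logit))


def is_correct(logit: float, threshold: float = 0.5) -> bool:
    """Apply the same `score > threshold` rule as the Quick Start."""
    return correctness_score(logit) > threshold


for logit in (-2.0, 0.0, 2.0):
    label = "Correct" if is_correct(logit) else "Incorrect"
    print(f"{logit:+.1f} -> {correctness_score(logit):.3f} {label}")
```

With the default threshold, a score above 0.5 is equivalent to a positive logit, so a logit of exactly 0 is labeled Incorrect.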
 
+ ## Training Details

+ This model was trained using the [Weaver distillation pipeline](https://github.com/ScalingIntelligence/scaling-verification/tree/main/distillation) on a combined dataset spanning multiple reasoning domains. For training your own distilled models, see the [distillation README](https://github.com/ScalingIntelligence/scaling-verification/blob/main/distillation/README.md).

+ ## Evaluation

+ Evaluate this model on different datasets:

+ ```bash
+ # MATH500
+ python evaluate_crossencoder.py \
+     --model_name "Alibaba-NLP/gte-Qwen2-1.5B-instruct" \
+     --checkpoint_path "hazyresearch/Weaver_Distilled_All_Datasets_gte-Qwen2-1.5B-instruct" \
+     --dataset_path "hazyresearch/MATH500_with_Llama_3.1_70B_Instruct_v1" \
+     --dataset_split "data" \
+     --max_length 4096 \
+     --batch_size 64

+ # GPQA
+ python evaluate_crossencoder.py \
+     --model_name "Alibaba-NLP/gte-Qwen2-1.5B-instruct" \
+     --checkpoint_path "hazyresearch/Weaver_Distilled_All_Datasets_gte-Qwen2-1.5B-instruct" \
+     --dataset_path "hazyresearch/GPQA_with_Llama_3.1_70B_Instruct_v1" \
+     --dataset_split "data" \
+     --max_length 4096 \
+     --batch_size 64
+ ```
 
+ ## Citation

+ ```bibtex
+ @article{weaver2025,
+   title={Weaver: Shrinking the Generation-Verification Gap with Weak Verifiers},
+   author={},
+   journal={arXiv preprint},
+   year={2025}
+ }
+ ```