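"""Reward model proxy server.

Exposes a FastAPI endpoint (POST /get_reward) that scores policy rollouts by
asking an OpenAI-compatible vLLM "judge" model whether the final answer in each
rollout matches a reference answer loaded from --answer_path. Scores can be
returned raw, batch-normalized, or group-normalized per prompt.
"""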
import argparse
import asyncio
import json
import math
import re
from collections import defaultdict

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from openai import AsyncOpenAI
from transformers import AutoTokenizer


PROMPT_critic_updated = '''
Given a problem, determine whether the final answer in the provided (incomplete) solution process matches the reference answer.
The reference answer may be one single option character (e.g., A, B, C, D), a numerical value, an expression, or a list of answers if multiple questions are involved.
**The reference answer may be in Chinese or another language, but your evaluation should be language-agnostic.**

Your task:
- Compare the final output of the solution process with the reference answer.
- If they **match exactly**, output **YES**.
- If they **do not match**, output **NO**.
- If the solution process is unclear, incomplete, or ambiguous, assume it is incorrect and output **NO**.

Your output must be strictly **'YES'** or **'NO'**, with no additional words, punctuation, or explanation.

---

**Question:**
{question}

**Solution Process (Final Step Only):**
{response}

**Reference Answer:**
{reference}

**Output:**
'''


def parse_im_sections(text):
    """Split a ChatML-formatted string into a {role: content} dict."""
    sections = re.findall(r"<\|im_start\|>(.*?)<\|im_end\|>", text, re.DOTALL)
    parsed = {}
    for section in sections:
        try:
            role, content = section.split("\n", 1)
            parsed[role.strip()] = content.strip()
        except ValueError:
            print(f"Skipping malformed section: {section}")
    return parsed


def extract_last_non_empty_line(text, role="assistant"):
    """Return the last non-empty line of the given role's final message, or ""."""
    pattern = fr"<\|im_start\|>{role}(.*?)(?:<\|im_start\|>|<\|endoftext\|>|<\|eot_id\|>|$)"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        content = match.group(1).strip()
        lines = [line for line in content.splitlines() if line.strip()]
        if lines:
            return lines[-1]
    return ""


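# Illustrative query format (assumed ChatML-style, matching the <|im_start|> markers
# handled above; the exact special tokens depend on the policy model's chat template):
#   "<|im_start|>user\n<question><|im_end|>\n<|im_start|>assistant\n...\nFinal answer: 42<|im_end|>"
# parse_im_sections() maps roles to contents, and extract_last_non_empty_line()
# returns the final non-empty line of the assistant turn ("Final answer: 42").

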
def reward_normalization(rewards):
    """Standardize rewards across the whole batch (zero mean, unit std)."""
    if len(rewards) == 1:
        return [0.0]
    rewards = torch.tensor(rewards, dtype=torch.float64)
    if rewards.std() == 0:
        normalized_rewards = torch.zeros_like(rewards)
    else:
        normalized_rewards = (rewards - rewards.mean()) / rewards.std()
    return normalized_rewards.tolist()


def strip_sequence(text, pad_token, eos_token):
    """Strip leading and trailing pad/eos tokens from a decoded sequence."""
    pad_token_escaped = re.escape(pad_token)
    eos_token_escaped = re.escape(eos_token)

    # Remove leading pad/eos tokens.
    pattern = f"^({eos_token_escaped}|{pad_token_escaped})+"
    text = re.sub(pattern, "", text)

    # Remove trailing pad/eos tokens.
    pattern = f"({eos_token_escaped}|{pad_token_escaped})+$"
    text = re.sub(pattern, "", text)
    return text


def group_reward_normalization(rewards, n_samples_per_prompt=4):
    """Standardize rewards within each group of samples drawn from the same prompt."""
    rewards = torch.tensor(rewards, dtype=torch.float64)
    rewards = rewards.reshape(-1, n_samples_per_prompt)

    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)

    # Groups with zero variance get all-zero rewards instead of NaNs.
    normalized_rewards = torch.where(std == 0, torch.zeros_like(rewards), (rewards - mean) / std)
    return normalized_rewards.flatten().tolist()

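# Worked example (illustrative; values use torch's default unbiased std):
#   group_reward_normalization([1, 0, 1, 1, 0, 0, 1, 0], n_samples_per_prompt=4)
#   -> group 1 = [1, 0, 1, 1]: mean 0.75, std 0.5 -> [0.5, -1.5, 0.5, 0.5]
#      group 2 = [0, 0, 1, 0]: mean 0.25, std 0.5 -> [-0.5, -0.5, 1.5, -0.5]
# reward_normalization() applies the same standardization over the whole batch instead.

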

class RewardModelProxy:
    def __init__(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
        self.normalize_reward = args.normalize_reward
        self.group_normalize_reward = args.group_normalize_reward
        self.qa_dict = defaultdict(str)
        self.load_dict(args.answer_path)
        # Judge decoding settings: greedy, single-token verdict ("YES"/"NO").
        self.temperature = 0
        self.stop = [self.tokenizer.eos_token, "<|im_end|>"]
        self.max_tokens = 1
        self.prob_reward = args.prob_reward
        self.log_path = args.log_path
        self.vllm_model = args.vllm_model

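    # Expected --answer_path format (a sketch inferred from load_dict below; fields
    # other than "query" and "label" are up to the data pipeline):
    # [
    #   {"query": [{"role": "system", "content": "..."},
    #              {"role": "user", "content": "<question text>"}],
    #    "label": "<reference answer>"},
    #   ...
    # ]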
    def load_dict(self, path):
        """Load question -> reference-answer pairs from a JSON file."""
        with open(path, "r", encoding="utf-8") as file:
            data = json.load(file)
        for unit in data:
            question = unit["query"][1]["content"]
            label = unit["label"]
            self.qa_dict[question] = label

        if self.qa_dict:
            sample_question, sample_label = next(iter(self.qa_dict.items()))
            print("Sample Question:", sample_question)
            print("Sample Label:", sample_label)
        else:
            print("qa_dict is empty.")

    async def process_sample(self, query):
        """Score one rollout: 0.0 if the assistant produced no answer,
        otherwise the judge's verdict on its last non-empty line."""
        query = strip_sequence(query, self.tokenizer.pad_token, self.tokenizer.eos_token) + self.tokenizer.eos_token
        question = parse_im_sections(query)["user"]
        answer = extract_last_non_empty_line(query, role="assistant")
        if not answer.strip():
            return 0.0
        prompt_question = PROMPT_critic_updated.format(
            question=question, reference=self.qa_dict[question], response=answer
        )
        return await self.get_reward_from_vllm(prompt_question)

    async def get_reward_from_vllm(self, query):
        """Query the vLLM judge and return a reward, retrying on transient errors."""
        max_retries = 10
        delay = 10
        for attempt in range(max_retries):
            try:
                response = await client.chat.completions.create(
                    model=self.vllm_model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": query},
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    stop=self.stop,
                    logprobs=True,
                    top_logprobs=10,
                )
                return self.calculate_reward_from_logprobs(response)
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}, retrying in {delay} seconds...")
                await asyncio.sleep(delay)
        print(f"Failed after {max_retries} retries, query content: {query[:200]}...")
        return 0.0

    def calculate_reward_from_logprobs(self, response):
        """Turn the judge's YES/NO token probabilities into a reward."""
        logprobs = response.choices[0].logprobs.content[0].top_logprobs
        token_probs = {token.token: math.exp(token.logprob) for token in logprobs}

        yes_prob = sum(prob for token, prob in token_probs.items() if token.lower().strip() == "yes")
        no_prob = sum(prob for token, prob in token_probs.items() if token.lower().strip() == "no")
        total = yes_prob + no_prob
        if total == 0:
            return 0.0
        if self.prob_reward:
            # Soft reward: probability mass on "yes", renormalized over {yes, no}.
            print(yes_prob / total)
            return yes_prob / total
        # Hard reward: 1.0 if "yes" is more likely than "no", else 0.0.
        return 1.0 if yes_prob > no_prob else 0.0

    async def get_reward(self, queries):
        """Score a batch of rollouts concurrently, then optionally log and normalize the results."""
        print("Processing queries[0]: {}".format(queries[0]))
        tasks = [self.process_sample(query) for query in queries]
        scores = await asyncio.gather(*tasks)
        print("Generated scores: {}".format(scores))
        if self.log_path:
            with open(self.log_path, "a", encoding="utf-8") as f:
                unit = {
                    "query_list": queries if isinstance(queries, list) else [],
                    "hard_score_list": scores if isinstance(scores, list) else [],
                }
                json.dump(unit, f, ensure_ascii=False)
                f.write("\n")
        if self.normalize_reward:
            return reward_normalization(scores)
        elif self.group_normalize_reward:
            return group_reward_normalization(scores)
        else:
            return scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--answer_path", type=str, default=None)
    parser.add_argument("--prob_reward", action="store_true", default=False)
    parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable reward normalization")
    parser.add_argument("--group_normalize_reward", action="store_true", default=False, help="Enable group reward normalization")
    parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")
    parser.add_argument("--log_path", type=str, default=None)
    parser.add_argument("--vllm_url", type=str, default=None)
    parser.add_argument("--vllm_model", type=str, default=None)
    args = parser.parse_args()

    # The vLLM server exposes an OpenAI-compatible API; no real key is needed.
    openai_api_key = "EMPTY"
    openai_api_base = args.vllm_url

    client = AsyncOpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    reward_model = RewardModelProxy(args)
    app = FastAPI()

    @app.post("/get_reward")
    async def get_reward(request: Request):
        data = await request.json()
        queries = data.get("query")
        rewards = await reward_model.get_reward(queries)
        result = {"rewards": rewards}
        print(f"Sent JSON response: {result}")
        return JSONResponse(result)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
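# Example usage (illustrative; the script name, judge endpoint, and model name are assumptions):
#
#   python reward_server.py \
#       --tokenizer_path /path/to/policy_tokenizer \
#       --answer_path answers.json \
#       --vllm_url http://localhost:8000/v1 \
#       --vllm_model my-judge-model \
#       --prob_reward --port 5000
#
#   curl -X POST http://localhost:5000/get_reward \
#       -H "Content-Type: application/json" \
#       -d '{"query": ["<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n...<|im_end|>"]}'
#
# The response is a JSON object of the form {"rewards": [...]}.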