"""Reward-model proxy server: scores RL rollouts by asking a vLLM-served judge
whether each final answer matches the reference answer."""
import argparse
import asyncio
import json
import math
import re
from collections import defaultdict

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from openai import AsyncOpenAI
from transformers import AutoTokenizer


PROMPT_critic_updated = '''
Given a problem, determine whether the final answer in the provided (incomplete) solution process matches the reference answer.
The reference answer may be a single option letter (e.g., A, B, C, D), a numerical value, an expression, or a list of answers if multiple questions are involved.
**The reference answer may be in Chinese or another language, but your evaluation should be language-agnostic.**

Your task:
- Compare the final output of the solution process with the reference answer.
- If they **match exactly**, output **YES**.
- If they **do not match**, output **NO**.
- If the solution process is unclear, incomplete, or ambiguous, assume it is incorrect and output **NO**.

Your output must be strictly **YES** or **NO**, with no additional words, punctuation, or explanation.

---

**Question:**
{question}

**Solution Process (Final Step Only):**
{response}

**Reference Answer:**
{reference}

**Output:**
'''


def parse_im_sections(text):
    """Split ChatML-formatted text into a {role: content} dict."""
    sections = re.findall(r"<\|im_start\|>(.*?)<\|im_end\|>", text, re.DOTALL)
    parsed = {}
    for section in sections:
        try:
            # The first line of a section is the role; everything after is content.
            role, content = section.split("\n", 1)
            parsed[role.strip()] = content.strip()
        except ValueError:
            print(f"Skipping malformed section: {section}")
    return parsed


def extract_last_non_empty_line(text, role="assistant"):
    """Return the last non-empty line of the given role's turn, or "" if none."""
    pattern = fr"<\|im_start\|>{role}(.*?)(?:<\|im_start\|>|<\|endoftext\|>|<\|eot_id\|>|$)"
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        return ""
    lines = [line for line in match.group(1).strip().splitlines() if line.strip()]
    return lines[-1] if lines else ""


def reward_normalization(rewards):
    """Z-score rewards across the whole batch; a lone reward maps to 0.0."""
    if len(rewards) == 1:
        return [0.0]
    rewards = torch.tensor(rewards, dtype=torch.float64)
    if rewards.std() == 0:
        normalized_rewards = torch.zeros_like(rewards)
    else:
        normalized_rewards = (rewards - rewards.mean()) / rewards.std()
    return normalized_rewards.tolist()
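
# e.g., reward_normalization([1.0, 0.0, 0.0, 1.0]) returns approximately
# [0.87, -0.87, -0.87, 0.87]: values are mean-centred and divided by the
# sample standard deviation.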


def strip_sequence(text, pad_token, eos_token):
    """Strip leading and trailing runs of pad/eos tokens from decoded text."""
    pad_token_escaped = re.escape(pad_token)
    eos_token_escaped = re.escape(eos_token)

    # Remove leading pad/eos tokens.
    pattern = f"^({eos_token_escaped}|{pad_token_escaped})+"
    text = re.sub(pattern, "", text)

    # Remove trailing pad/eos tokens.
    pattern = f"({eos_token_escaped}|{pad_token_escaped})+$"
    text = re.sub(pattern, "", text)
    return text


def group_reward_normalization(rewards, n_samples_per_prompt=4):
    """Z-score rewards within each group of n_samples_per_prompt rollouts."""
    rewards = torch.tensor(rewards, dtype=torch.float64)
    rewards = rewards.reshape(-1, n_samples_per_prompt)

    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)

    # Zero-variance groups normalize to all zeros instead of dividing by zero.
    normalized_rewards = torch.where(std == 0, torch.zeros_like(rewards), (rewards - mean) / std)

    return normalized_rewards.flatten().tolist()
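
# e.g., with n_samples_per_prompt=4, group_reward_normalization applied to
# [1, 0, 0, 1, 1, 1, 1, 1] yields roughly [0.87, -0.87, -0.87, 0.87, 0, 0, 0, 0]:
# the zero-variance second group collapses to zeros.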


class RewardModelProxy:
    def __init__(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
        self.normalize_reward = args.normalize_reward
        self.group_normalize_reward = args.group_normalize_reward
        self.qa_dict = defaultdict(str)
        self.load_dict(args.answer_path)
        # Greedy decoding of a single token: the judge only ever answers YES/NO.
        self.temperature = 0
        self.stop = [self.tokenizer.eos_token, "<|im_end|>"]
        self.max_tokens = 1
        self.prob_reward = args.prob_reward
        self.log_path = args.log_path
        self.vllm_model = args.vllm_model

    def load_dict(self, path):
        """Load question -> reference-answer pairs from the answer file."""
        with open(path, "r", encoding="utf-8") as file:
            data = json.load(file)
        for unit in data:
            # query[1] is assumed to be the user turn (query[0] being the system turn).
            question = unit["query"][1]["content"]
            label = unit["label"]
            self.qa_dict[question] = label

        if self.qa_dict:
            sample_question, sample_label = next(iter(self.qa_dict.items()))
            print("Sample Question:", sample_question)
            print("Sample Label:", sample_label)
        else:
            print("qa_dict is empty.")

    async def process_sample(self, query):
        """Score one rollout: 0.0 for an empty answer, else ask the vLLM judge."""
        query = strip_sequence(query, self.tokenizer.pad_token, self.tokenizer.eos_token) + self.tokenizer.eos_token
        question = parse_im_sections(query)["user"]
        answer = extract_last_non_empty_line(query, role="assistant")
        if not answer.strip():
            return 0.0
        # qa_dict is a defaultdict(str), so an unseen question yields an empty reference.
        prompt_question = PROMPT_critic_updated.format(question=question, reference=self.qa_dict[question], response=answer)
        return await self.get_reward_from_vllm(prompt_question)

    async def get_reward_from_vllm(self, query):
        """Ask the judge model (via the module-level AsyncOpenAI client) and
        convert its YES/NO log-probabilities into a reward."""
        max_retries = 10
        delay = 10
        for attempt in range(max_retries):
            try:
                response = await client.chat.completions.create(
                    model=self.vllm_model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": query},
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    stop=self.stop,
                    logprobs=True,
                    top_logprobs=10,
                )
                return self.calculate_reward_from_logprobs(response)
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {e}, retrying in {delay} seconds...")
                await asyncio.sleep(delay)
        print(f"Failed after {max_retries} retries, query content: {query[:200]}...")
        return 0.0

    def calculate_reward_from_logprobs(self, response):
        """Calculate the reward from YES/NO probabilities in the judge's top logprobs."""
        logprobs = response.choices[0].logprobs.content[0].top_logprobs
        token_probs = {token.token: math.exp(token.logprob) for token in logprobs}

        # Pool probability mass over case/whitespace variants of "yes" and "no".
        yes_prob = sum(prob for token, prob in token_probs.items() if token.lower().strip() == "yes")
        no_prob = sum(prob for token, prob in token_probs.items() if token.lower().strip() == "no")
        total = yes_prob + no_prob
        if total == 0:
            return 0.0
        if self.prob_reward:
            print(yes_prob / total)
            return yes_prob / total
        return 1.0 if yes_prob > no_prob else 0.0
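
    # Design note: with --prob_reward the server emits a soft reward,
    # P(YES) / (P(YES) + P(NO)), from the top-10 logprobs of the judge's first
    # token; otherwise it emits a hard 0/1 reward for whichever is more likely.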

    async def get_reward(self, queries):
        print("Processing queries[0]: {}".format(queries[0]))
        # Score all rollouts in the batch concurrently.
        tasks = [self.process_sample(query) for query in queries]
        scores = await asyncio.gather(*tasks)
        print("Generated scores: {}".format(scores))
        if self.log_path:
            # Append one JSON line per batch for offline inspection.
            with open(self.log_path, "a", encoding="utf-8") as f:
                unit = {
                    "query_list": queries if isinstance(queries, list) else [],
                    "hard_score_list": scores if isinstance(scores, list) else [],
                }
                json.dump(unit, f, ensure_ascii=False)
                f.write("\n")
        if self.normalize_reward:
            return reward_normalization(scores)
        elif self.group_normalize_reward:
            return group_reward_normalization(scores)
        else:
            return scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--answer_path", type=str, default=None)
    parser.add_argument("--prob_reward", action="store_true", default=False)
    parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable reward normalization")
    parser.add_argument("--group_normalize_reward", action="store_true", default=False, help="Enable group reward normalization")
    parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")
    parser.add_argument("--log_path", type=str, default=None)
    parser.add_argument("--vllm_url", type=str, default=None)
    parser.add_argument("--vllm_model", type=str, default=None)
    args = parser.parse_args()

    # vLLM's OpenAI-compatible server does not check the API key, so any
    # placeholder value works here.
    openai_api_key = "EMPTY"
    openai_api_base = args.vllm_url

    client = AsyncOpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    reward_model = RewardModelProxy(args)
    app = FastAPI()

    @app.post("/get_reward")
    async def get_reward(request: Request):
        data = await request.json()
        queries = data.get("query")
        rewards = await reward_model.get_reward(queries)
        result = {"rewards": rewards}
        print(f"Sent JSON response: {result}")
        return JSONResponse(result)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
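
    # Example launch (script name, paths, URL, and model name are illustrative):
    #   python reward_server.py \
    #       --tokenizer_path /path/to/tokenizer \
    #       --answer_path answers.json \
    #       --vllm_url http://localhost:8000/v1 \
    #       --vllm_model judge-model \
    #       --prob_reward
    #
    # Example request:
    #   curl -X POST http://localhost:5000/get_reward \
    #        -H "Content-Type: application/json" \
    #        -d '{"query": ["<|im_start|>user\nQ...<|im_end|>\n<|im_start|>assistant\nA...<|im_end|>"]}'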