Lekr0's picture
Add files using upload-large-folder tool
212a146 verified
import random
import re
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset

from .base import Benchmarker
from .registry import BENCHMARKS
from .utils import create_simple_sgl_function
# Prompt template for one GPQA multiple-choice question.
# It is filled with str.format() using the placeholders {Question}, {A}, {B},
# {C} and {D}. The template instructs the model to finish its response with a
# line of the form "Answer: $LETTER"; answer extraction in this module looks
# for that "Answer: " marker. Surrounding whitespace of the literal is removed
# by the trailing .strip().
GPQA_QUERY_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
def generate_question(
    row: Dict[str, Any], rng: Optional[random.Random] = None
) -> Tuple[str, str]:
    """Build a shuffled multiple-choice prompt from one GPQA row.

    The correct answer is inserted at a random position among the three
    incorrect answers, so the gold letter varies between questions.

    Args:
        row: Mapping with the keys "Question", "Correct Answer" and
            "Incorrect Answer 1" through "Incorrect Answer 3". The values
            must be strings (``.strip()`` is called on each of them).
        rng: Optional random number generator used to choose the position of
            the correct answer. Pass a seeded ``random.Random`` to obtain a
            reproducible ordering. When omitted, the module-level ``random``
            state is used, which is the previous behavior.

    Returns:
        A ``(question, answer)`` tuple: the formatted prompt text and the
        letter ("A", "B", "C" or "D") under which the correct answer appears.

    Raises:
        KeyError: If ``row`` lacks one of the expected keys.
    """
    letters = ("A", "B", "C", "D")

    # Fall back to the shared module-level generator when no rng is supplied.
    source = random if rng is None else rng
    gold_index = source.randint(0, len(letters) - 1)

    choices = [
        row["Incorrect Answer 1"],
        row["Incorrect Answer 2"],
        row["Incorrect Answer 3"],
    ]
    choices.insert(gold_index, row["Correct Answer"])

    question = GPQA_QUERY_TEMPLATE.format(
        Question=row["Question"].strip(),
        A=choices[0].strip(),
        B=choices[1].strip(),
        C=choices[2].strip(),
        D=choices[3].strip(),
    )

    # Index 0 maps to "A", 1 to "B", 2 to "C", 3 to "D".
    return question, letters[gold_index]
@BENCHMARKS.register("gpqa")
class GPQABenchmarker(Benchmarker):
    """GPQA benchmark implementation.

    Loads GPQA multiple-choice questions, extracts the letter the model chose
    from its free-form output, and scores it against the gold letter.
    """

    # Matches the "Answer: <LETTER>" marker requested by the prompt. Optional
    # whitespace and common wrappers ("$", "(", "*") may precede the letter,
    # and the negative lookahead rejects words that merely start with A-D
    # (e.g. "Answer: Based on ...").
    _ANSWER_PATTERN = re.compile(r"Answer:[\s$(*]*([A-D])(?![A-Za-z])")

    def __init__(self, num_samples: Optional[int] = None):
        """Initialize the benchmarker.

        Args:
            num_samples: Maximum number of dataset rows to load, or None to
                load all of them. It is forwarded to the base class; the
                second base-class argument is passed as None (its meaning is
                defined by ``Benchmarker``).
        """
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Load and preprocess the GPQA dataset.

        Returns:
            A ``(questions, labels)`` tuple. ``questions`` is a list of
            ``{"question": prompt_text}`` dicts and ``labels`` is the list of
            gold answer letters ("A"-"D") in the same order.
        """
        ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]

        # Cap the number of rows up front instead of breaking out of the loop.
        total = len(ds)
        limit = total if self.num_samples is None else min(self.num_samples, total)

        questions: List[Dict[str, Any]] = []
        labels: List[str] = []
        for i in range(limit):
            question_text, answer = generate_question(ds[i])
            questions.append({"question": question_text})
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the chosen answer letter from a model response.

        The prompt asks for the answer on the last line, so when the marker
        occurs several times the last occurrence wins.

        Args:
            output: Raw text generated by the model.
            label: Unused; kept for interface compatibility.

        Returns:
            The letter "A", "B", "C" or "D", or None when no well-formed
            "Answer: <LETTER>" marker is present.
        """
        if not output:
            return None
        matches = self._ANSWER_PATTERN.findall(output)
        return matches[-1] if matches else None

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Compute the fraction of predictions that equal their label.

        Args:
            predictions: Extracted answers, aligned with ``labels``.
            labels: Gold answers.

        Returns:
            ``correct / len(labels)``, or None when ``labels`` is empty or
            None. Predictions beyond the length of ``labels`` are ignored and
            missing predictions count as incorrect, because pairs are formed
            with ``zip``.
        """
        if not labels:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels)

    def create_sgl_function(self):
        """Create the SGL function used to query the model for an answer.

        Returns:
            The result of ``create_simple_sgl_function`` configured with the
            function name "get_gpqa_mcq_answer", the answer key "answer" and
            the token limit from ``self.get_max_new_tokens()``.
        """
        return create_simple_sgl_function(
            function_name="get_gpqa_mcq_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )