| import random |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from datasets import load_dataset |
|
|
| from .base import Benchmarker |
| from .registry import BENCHMARKS |
| from .utils import create_simple_sgl_function |
|
|
# Prompt skeleton for one GPQA multiple-choice question. It is filled through
# ``str.format`` with the fields ``Question``, ``A``, ``B``, ``C`` and ``D``.
# The text has no leading or trailing newline.
GPQA_QUERY_TEMPLATE = "\n".join(
    (
        "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.",
        "",
        "{Question}",
        "",
        "A) {A}",
        "B) {B}",
        "C) {C}",
        "D) {D}",
    )
)
|
|
|
|
def generate_question(
    row: Dict[str, Any], rng: Optional[random.Random] = None
) -> Tuple[str, str]:
    """Render one GPQA row as a four-option prompt with a randomly placed answer.

    Args:
        row: Mapping that provides the keys ``"Question"``, ``"Correct Answer"``,
            ``"Incorrect Answer 1"``, ``"Incorrect Answer 2"`` and
            ``"Incorrect Answer 3"``. ``.strip()`` is called on every value, so
            each one must be a string.
        rng: Optional ``random.Random`` instance used to pick the slot of the
            correct answer, which makes the shuffling reproducible. When
            ``None`` the module-level ``random`` functions are used, exactly as
            before this parameter existed.

    Returns:
        A ``(question, answer)`` tuple: the prompt built from
        ``GPQA_QUERY_TEMPLATE`` and the letter (``"A"``-``"D"``) under which the
        correct answer was placed.
    """
    picker = random if rng is None else rng
    gold_index = picker.randint(0, 3)

    # The three distractors keep their relative order; the correct answer is
    # inserted at the randomly drawn position.
    choices = [
        row["Incorrect Answer 1"],
        row["Incorrect Answer 2"],
        row["Incorrect Answer 3"],
    ]
    choices.insert(gold_index, row["Correct Answer"])

    question = GPQA_QUERY_TEMPLATE.format(
        Question=row["Question"].strip(),
        A=choices[0].strip(),
        B=choices[1].strip(),
        C=choices[2].strip(),
        D=choices[3].strip(),
    )

    answer = "ABCD"[gold_index]
    return question, answer
|
|
|
|
@BENCHMARKS.register("gpqa")
class GPQABenchmarker(Benchmarker):
    """GPQA multiple-choice benchmark.

    Questions come from the ``train`` split of the ``gpqa_main`` configuration
    of ``Idavidrein/gpqa``. Each row is rendered into a four-option prompt by
    ``generate_question``, and a prediction counts as correct when it equals the
    letter under which the correct option was placed.
    """

    def __init__(self, num_samples: Optional[int] = None):
        """Initialise the benchmarker.

        Args:
            num_samples: Forwarded to the base-class initialiser. ``load_data``
                reads ``self.num_samples`` (presumably stored by the base
                class) as an upper bound on the number of rows converted;
                ``None`` means no bound.
        """
        # The second positional argument of the base initialiser is always None here.
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Load the GPQA rows and turn them into prompts and gold letters.

        Returns:
            A ``(questions, labels)`` tuple of equal length. ``questions`` holds
            dicts of the form ``{"question": <prompt text>}`` and ``labels``
            holds the matching gold letter (``"A"``-``"D"``) for each prompt.
        """
        ds = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]

        # Work out the number of rows once instead of testing the cap on every
        # iteration. A zero or negative cap yields an empty result.
        total = len(ds)
        limit = total if self.num_samples is None else min(self.num_samples, total)

        questions: List[Dict[str, Any]] = []
        labels: List[str] = []
        for i in range(limit):
            question_text, answer = generate_question(ds[i])
            questions.append({"question": question_text})
            labels.append(answer)
        return questions, labels

    def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[str]:
        """Extract the chosen option letter from a model response.

        The prompt asks the model to finish with a line of the form
        ``Answer: $LETTER``. The last such marker in ``output`` is used, so an
        earlier mention of "Answer:" inside the reasoning does not shadow the
        final answer. Text after the letter (punctuation, an explanation) is
        ignored, and a ``$`` echoed from the prompt's placeholder is tolerated.

        Args:
            output: Raw text produced by the model.
            label: Unused by this benchmark; accepted for interface compatibility.

        Returns:
            The letter ``"A"``, ``"B"``, ``"C"`` or ``"D"``, or ``None`` when no
            ``Answer:`` marker followed by one of those letters is found.
        """
        # Function-scope import: ``re`` is not part of this module's import block.
        import re

        # ``\b`` rejects words that merely start with A-D (e.g. "Answer: Both").
        letters = re.findall(r"Answer:[ \t]*\$?([A-D])\b", output)
        return letters[-1] if letters else None

    def compute_accuracy(
        self, predictions: List[Any], labels: List[Any]
    ) -> Optional[float]:
        """Return the fraction of labels matched by the paired prediction.

        Predictions and labels are paired positionally with ``zip``; the
        denominator is always ``len(labels)``, so labels without a paired
        prediction count as incorrect.

        Args:
            predictions: Extracted answers, one per sample.
            labels: Gold answers, one per sample.

        Returns:
            The accuracy in ``[0, 1]``, or ``None`` when ``labels`` is empty.
        """
        if not labels:
            return None
        correct = sum(1 for pred, label in zip(predictions, labels) if pred == label)
        return correct / len(labels)

    def create_sgl_function(self):
        """Build the generation function via ``create_simple_sgl_function``.

        The function is named ``get_gpqa_mcq_answer``, uses ``"answer"`` as its
        answer key and takes its token limit from ``self.get_max_new_tokens()``.
        """
        return create_simple_sgl_function(
            function_name="get_gpqa_mcq_answer",
            answer_key="answer",
            max_tokens=self.get_max_new_tokens(),
        )
|
|