| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from datasets import load_dataset |
|
|
| from .base import Benchmarker |
| from .registry import BENCHMARKS |
| from .utils import create_simple_sgl_function |
|
|
|
|
def generate_question(row: Dict[str, Any]) -> str:
    """Return the question text for one dataset row.

    Args:
        row: Mapping that must contain a ``"problem"`` entry whose value
            supports ``.strip()`` (a string).

    Returns:
        The ``"problem"`` value with leading and trailing whitespace removed.

    Raises:
        KeyError: If ``row`` has no ``"problem"`` key.
    """
    return row["problem"].strip()
|
|
|
|
@BENCHMARKS.register("simpleqa")
class SimpleQABenchmarker(Benchmarker):
    """Benchmarker for the SimpleQA dataset, registered as ``"simpleqa"``."""

    def __init__(self, num_samples: Optional[int] = None):
        """Initialise the benchmarker.

        Args:
            num_samples: Upper bound on the number of questions produced by
                ``load_data``; ``None`` means no bound.
        """
        # The second positional argument of the base initialiser is passed as
        # None; its meaning is defined by ``Benchmarker`` (not visible here).
        super().__init__(num_samples, None)

    def load_data(self) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Load the questions from the ``test`` split of ``basicv8vc/SimpleQA``.

        Returns:
            A ``(questions, labels)`` pair of equal length. Each question is a
            dict ``{"question": <text>}``. Every label is ``None`` (no
            reference label is attached here), even though the return
            annotation names ``List[int]``.
        """
        dataset = load_dataset("basicv8vc/SimpleQA")["test"]

        # Work out how many rows to take up front instead of breaking out of
        # the loop; a non-positive ``num_samples`` yields an empty range.
        total = len(dataset)
        cap = self.num_samples
        limit = total if cap is None else min(cap, total)

        questions = [
            {"question": generate_question(dataset[idx])} for idx in range(limit)
        ]
        labels = [None] * len(questions)
        return questions, labels

    def create_sgl_function(self):
        """Build the generation function via ``create_simple_sgl_function``.

        The function is named ``"get_simpleqa_answer"``, uses ``"answer"`` as
        its answer key, and takes its token budget from
        ``self.get_max_new_tokens()``.
        """
        token_budget = self.get_max_new_tokens()
        return create_simple_sgl_function(
            function_name="get_simpleqa_answer",
            answer_key="answer",
            max_tokens=token_budget,
        )
|
|