| """ |
| Base class for benchmark implementations. |
| """ |
|
|
| import time |
| from abc import ABC, abstractmethod |
| from argparse import Namespace |
| from typing import Any, Callable, Dict, List, Optional, Tuple |
|
|
| from sglang import set_default_backend |
| from sglang.test.test_utils import select_sglang_backend |
|
|
| from .utils import compute_metrics |
|
|
|
|
| class Benchmarker(ABC): |
| """ |
| Base class for benchmark implementations. |
| |
| Subclasses should implement: |
| - load_data(): Load and preprocess dataset |
| - create_sgl_function(): Create the SGL function for inference |
| |
| Optional overrides: |
| - extract_answer(): Extract answer from model output (if needed) |
| - compute_accuracy(): Compute accuracy metric (if applicable) |
| - get_answer_keys(): Get list of answer keys for multi-turn conversations |
| |
| Args: |
| num_samples: The number of samples to run the benchmark on. If not provided, all questions will be used. |
| subset: The subset of the dataset to run the benchmark on. If not provided, all subsets will be used. |
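
    Example (an illustrative usage sketch; MyBenchmark, the server address, and
    the subset name are placeholders, not part of this module):

        bench = MyBenchmark(num_samples=200, subset=["validation"])
        metrics_per_run = bench.run(host="127.0.0.1", port=30000, batch_size=16)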
| """ |
|
|
| def __init__( |
| self, num_samples: Optional[int] = None, subset: Optional[List[str]] = None |
| ): |
| self.num_samples = num_samples |
| self.subset = subset |
|
|
| @abstractmethod |
| def load_data(self) -> Tuple[List[Dict[str, Any]], List[Any]]: |
| """ |
| Load and preprocess the dataset. |
| |
| Returns: |
| Tuple of (questions, labels) where: |
| - questions: List of question dicts for SGL function |
| - labels: List of ground truth labels (can be None if not applicable) |
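
        Example of a possible return value (an illustrative sketch; the
        "question" key is an assumption and must match the keyword arguments
        of the SGL function returned by create_sgl_function()):

            questions = [{"question": "What is 2 + 2?"}]
            labels = ["4"]
            return questions, labels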
| """ |
| raise NotImplementedError |
|
|
| @abstractmethod |
| def create_sgl_function(self) -> Callable: |
| """ |
| Create the SGL function for inference. |
| |
| Returns: |
| SGL function decorated with @sgl.function |
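
        Example (an illustrative single-turn sketch; the prompt wording is an
        assumption, and the generation is stored under "answer", the default
        key that run() reads back when get_answer_keys() returns None):

            import sglang as sgl

            def create_sgl_function(self):
                @sgl.function
                def answer_question(s, question):
                    s += "Question: " + question + " Answer:"
                    s += sgl.gen("answer")
                return answer_question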
| """ |
| raise NotImplementedError |
|
|
| def extract_answer(self, output: str, label: Optional[Any] = None) -> Optional[Any]: |
| """ |
| Extract answer from model output. |
| |
| Args: |
| output: Raw model output string |
| label: Optional ground truth label for reference |
| |
| Returns: |
| Extracted answer, or None if extraction fails |
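
        Example (an illustrative override that keeps the last integer in the
        output; the pattern is an assumption about the output format):

            import re

            def extract_answer(self, output, label=None):
                matches = re.findall(r"-?[0-9]+", output)
                return matches[-1] if matches else None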
| """ |
| return output |
|
|
| def compute_accuracy( |
| self, predictions: List[Any], labels: List[Any] |
| ) -> Optional[float]: |
| """ |
| Compute accuracy metric. |
| |
| Args: |
| predictions: List of predicted answers |
| labels: List of ground truth labels |
| |
| Returns: |
| Accuracy score (0-1), or None if not applicable |
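
        Example (an illustrative exact-match override; it assumes predictions
        and labels are directly comparable values):

            def compute_accuracy(self, predictions, labels):
                correct = sum(
                    pred is not None and pred == label
                    for pred, label in zip(predictions, labels)
                )
                return correct / len(labels)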
| """ |
| return None |
|
|
| def get_answer_keys(self) -> Optional[List[str]]: |
| """ |
| Get list of answer keys for multi-turn conversations. |
| |
| Returns: |
| List of answer keys (e.g., ["answer_1", "answer_2"]), or None for single-turn |
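
        Example (an illustrative override for a two-turn SGL function whose
        generations are named "answer_1" and "answer_2"):

            def get_answer_keys(self):
                return ["answer_1", "answer_2"]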
| """ |
| return None |
|
|
| def get_max_new_tokens(self) -> int: |
| """ |
| Get maximum number of new tokens to generate. |
| |
| Returns: |
| Maximum tokens (default: 2048) |
| """ |
| return 2048 |
|
|
    def run(
        self,
        host: str,
        port: int,
        batch_size: int,
        max_new_tokens: Optional[int] = None,
        num_runs: int = 1,
    ):
        """
        Run the benchmark evaluation.

        This method handles the common workflow:
        1. Initialize backend
        2. Load data
        3. Create SGL function
        4. Run inference loops
        5. Compute metrics
        6. Return results

        Args:
            host (str): The host of the SGLang server
            port (int): The port of the SGLang server
            batch_size (int): The number of prompts to process in parallel
            max_new_tokens (int): Maximum number of new tokens to generate.
                Defaults to get_max_new_tokens() (2048 unless overridden).
            num_runs (int): The number of times to run this benchmark, default
                is 1. Set it to a larger number for more stable results.

        Returns:
            A list of per-run metrics objects, or None if no valid questions
            were loaded.
        """
        if not host.startswith(("http://", "https://")):
            host = f"http://{host}"

        sglang_args = Namespace(host=host, port=port, backend="srt-no-parallel")
        set_default_backend(select_sglang_backend(sglang_args))

        questions, labels = self.load_data()
        if len(questions) == 0:
            print("No valid questions found. Please check the dataset format.")
            return

        sgl_function = self.create_sgl_function()

        metrics_list = []
        answer_keys = self.get_answer_keys()
        max_new_tokens = max_new_tokens or self.get_max_new_tokens()

        for _ in range(num_runs):
            tic = time.perf_counter()
            states = sgl_function.run_batch(
                questions,
                temperature=0,
                max_new_tokens=max_new_tokens,
                num_threads=batch_size,
                progress_bar=True,
            )
            latency = time.perf_counter() - tic

            # Extract a prediction from each state's primary answer key.
            predictions = []
            primary_answer_key = answer_keys[0] if answer_keys else "answer"
            for i in range(len(states)):
                output = states[i][primary_answer_key]
                if isinstance(output, str):
                    extracted = self.extract_answer(
                        output,
                        (labels[i] if labels and i < len(labels) else None),
                    )
                else:
                    extracted = output
                predictions.append(extracted)

            # Compute accuracy only when ground-truth labels are available.
            accuracy = None
            has_labels_list = labels is not None and len(labels) > 0
            if has_labels_list:
                accuracy = self.compute_accuracy(predictions, labels)
                if accuracy is not None:
                    valid_count = sum(1 for p in predictions if p is not None)
                    if valid_count < len(predictions):
                        print(
                            f"Warning: {len(predictions) - valid_count} predictions could not be extracted."
                        )

            metrics = compute_metrics(
                states,
                latency,
                answer_key=primary_answer_key,
                additional_answer_keys=(
                    answer_keys[1:] if answer_keys and len(answer_keys) > 1 else None
                ),
            )

            if has_labels_list:
                metrics.accuracy = accuracy
                if accuracy is not None:
                    metrics.num_valid_predictions = sum(
                        1 for p in predictions if p is not None
                    )

            metrics_list.append(metrics)
        return metrics_list
|
|