| |
|
| | """ |
| | This script is adapted from Qwen2.5-Math |
| | https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/grader.py |
| | """ |
| |
|
| | import re |
| | import regex |
| | import multiprocessing |
| | from math import isclose |
| | from typing import Union |
| | from collections import defaultdict |
| |
|
| | from sympy import simplify, N |
| | from sympy.parsing.sympy_parser import parse_expr |
| | from sympy.parsing.latex import parse_latex |
| |
|
| |
|
def latex2sympy(sympy: str, variable_values=None):
    """Parse a LaTeX string with the PS lexer/parser grammar and convert it.

    The input is first normalised: ``\\dfrac``/``\\tfrac`` become ``\\frac``,
    ``\\mathrm{T}``/``\\mathrm{d}`` are unwrapped, ``\\left[\\begin{matrix}``
    becomes ``bmatrix``, ``(n)_{k}`` becomes a falling-factorial fraction and
    spacing commands / ``$`` are replaced by spaces.

    Side effects: sets the module globals ``frac_type`` (only when a frac
    variant is present) and ``VARIABLE_VALUES``.

    Args:
        sympy: LaTeX source text.
        variable_values: optional mapping stored in ``VARIABLE_VALUES``; an
            empty or missing mapping stores a fresh ``{}``.

    Returns:
        A list with one converted object per relation when the input parses as
        a relation list, otherwise the single converted relation.

    NOTE(review): ``MathErrorListener``, ``InputStream``, ``PSLexer``,
    ``CommonTokenStream``, ``PSParser`` and ``convert_relation`` are not
    imported in this file as shown (they look like latex2sympy2 / antlr4
    internals). Without them every call raises NameError, which
    ``symbolic_equal``'s parser fallback silently swallows -- confirm.
    """
    global frac_type
    if sympy.find(r'\frac') != -1:
        frac_type = r'\frac'
    if sympy.find(r'\dfrac') != -1:
        frac_type = r'\dfrac'
    if sympy.find(r'\tfrac') != -1:
        frac_type = r'\tfrac'
    sympy = sympy.replace(r'\dfrac', r'\frac')
    sympy = sympy.replace(r'\tfrac', r'\frac')

    sympy = sympy.replace(r'\mathrm{T}', 'T')
    sympy = sympy.replace(r'\mathrm{d}', 'd').replace(r'{\rm d}', 'd')
    sympy = sympy.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}').replace(r'\end{matrix}\right]', r'\end{bmatrix}')
    # (n)_{k} -> n! / (n - k)!
    sympy = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}", r"\\frac{(\1)!}{((\1)-(\2))!}", sympy)
    sympy = sympy.replace(r'\displaystyle', ' ')
    sympy = sympy.replace(r'\quad', ' ').replace(r'\qquad', ' ').replace(r'~', ' ').replace(r'\,', ' ')
    sympy = sympy.replace(r'$', ' ')

    # A None default avoids the shared-mutable-default pitfall. A non-empty
    # mapping is stored as-is (not copied); otherwise a fresh dict is used.
    global VARIABLE_VALUES
    VARIABLE_VALUES = variable_values if variable_values else {}

    matherror = MathErrorListener(sympy)

    stream = InputStream(sympy)
    lex = PSLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = CommonTokenStream(lex)
    parser = PSParser(tokens)
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    math = parser.math()

    if math.relation_list():
        relation_list = math.relation_list().relation_list_content()
        return [convert_relation(list_item) for list_item in relation_list.relation()]

    return convert_relation(math.relation())
| |
|
| |
|
def math_answer_cleaning(answer, dataset_name):
    r"""
    remove irrelevant strings and unify the answer format before checking whether the answers are equal

    Steps applied to every dataset, in order:
    unwrap an answer fully wrapped in ``\text{...}``; drop ``,\!``, ``{,}``
    and ``\$``; ``dfrac``/``tfrac`` -> ``frac``; drop degree marks, ``\quad``,
    spaces and newlines; ``a\times10^{b}`` -> ``aeb``; drop ``\text{...}``
    units; ``d^{e}`` -> ``d^e``; lowercase.
    Dataset-specific steps follow for "collegemath", "gsm8k" and
    "gaokao2023en".

    Args:
        answer: answer text (a str).
        dataset_name: dataset key selecting the extra cleaning rules.

    Returns:
        The cleaned answer string.
    """
    def _is_completely_wrapped_by_text(input_string):
        # Return the inner text (without "(", ")" and ",") when the whole
        # answer is a single \text{...}; otherwise None.
        pattern = r'^\\text{(.*)}$'
        match = re.match(pattern, input_string)
        if match:
            extracted_content = match.group(1)
            extracted_content = extracted_content.replace("(", "").replace(")", "").replace(",", "")
            return extracted_content
        else:
            return None

    extracted_content = _is_completely_wrapped_by_text(answer)
    answer = extracted_content if extracted_content else answer

    # Raw strings: same values as before, without invalid-escape SyntaxWarnings.
    # e.g. 5,\!460 -> 5460; 14{,}916 -> 14916; \$4 -> 4
    answer = answer.replace(r",\!", "").replace("{,}", "").replace(r"\$", "")
    # e.g. \dfrac{3}{2} -> \frac{3}{2}
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    # e.g. 121^\circ -> 121
    answer = answer.replace(r"^\circ", "")
    answer = answer.replace(r"^{\circ}", "")
    answer = answer.replace(r"\quad", "")
    answer = answer.replace(" ", "")
    # Real newlines and literal "\n" sequences.
    answer = answer.replace("\n", "").replace("\\n", "")
    # e.g. 3.54\times10^{10} -> 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    # e.g. 3.54\times10^10 -> 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    # e.g. 558\,\text{nm} -> 558
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    # e.g. 558\text{nm} -> 558
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    # e.g. 2^{10} -> 2^10
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    answer = answer.lower()

    if dataset_name == "collegemath":
        # e.g. 558\mathrm{ft} -> 558
        answer = re.sub(r'\\mathrm\{.*?\}', '', answer)
        answer = re.sub(r'\$\([^)]*\)', '', answer)
        if answer.endswith("-"):
            answer = answer[:-1]
        if answer.endswith("."):
            answer = answer[:-1]
        if answer.endswith("hours"):
            answer = answer[:-len("hours")]
        # Keep only what follows the first '=' and then the first ':'.
        if "=" in answer:
            answer = answer.split("=", 1)[1]
        if ":" in answer:
            answer = answer.split(":", 1)[1]
        # \emptyset and \oslash both denote the empty set.
        answer = answer.replace("\\emptyset", "\\oslash")
    if dataset_name == "gsm8k":
        # e.g. 5,600 -> 5600
        answer = answer.replace(',', '')
    if dataset_name == "gaokao2023en":
        # The answer is already lowercased, so every unit must be lowercase
        # too ('degreesontheBreadusscale' could never match).
        unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesonthebreadusscale', '$', 'a.m.', 'am', 'minutes']
        for unit in unit_strings:
            answer = answer.replace(unit, "")

    return answer
| |
|
| |
|
def extract_final_answer(output):
    r"""Return the content of the last ``\boxed{...}`` in *output*.

    Braces are matched at any nesting depth, so answers such as
    ``\boxed{\frac{1}{\sqrt{x^{2}}}}`` are captured whole. Matches do not
    overlap: scanning resumes after each complete box. A ``\boxed{`` whose
    braces never close is skipped, and scanning continues inside it.

    Args:
        output: model completion text; ``None`` or ``""`` yields no match.

    Returns:
        ``(extracted_answer, all_matches)``. ``all_matches`` lists the content
        of every complete ``\boxed{...}`` in order. ``extracted_answer`` is
        its last element, or ``None`` when there is none.
    """
    if not output:
        return None, []

    marker = "\\boxed{"
    all_matches = []
    search_from = 0
    while True:
        start = output.find(marker, search_from)
        if start == -1:
            break
        content_start = start + len(marker)
        depth = 1
        pos = content_start
        while pos < len(output) and depth:
            char = output[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            # pos is one past the closing brace of this box.
            all_matches.append(output[content_start:pos - 1])
            search_from = pos
        else:
            # Unbalanced box: a later \boxed{...} may still be complete.
            search_from = content_start

    extracted_answer = all_matches[-1] if all_matches else None
    return extracted_answer, all_matches
| |
|
| |
|
def round_number(answer):
    """Round numbers of magnitude below 1 to 2 significant digits.

    Small values like "0.1234" and "0.12" then compare equal as strings.
    Only ``abs(value) < 1`` is rounded. The previous ``value < 1`` test also
    rounded every negative number, so e.g. "-123" and "-120" both became
    "-1.2e+02".

    Args:
        answer: a value ``float()`` may accept (usually a str).

    Returns:
        ``f"{float(answer):.2g}"`` when the magnitude is below 1, otherwise
        *answer* unchanged (including when it is not numeric).
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        return answer

    if abs(value) < 1:
        return f"{value:.2g}"
    return answer
| |
|
| |
|
def choice_answer_clean(pred: str):
    """Reduce a free-form prediction to one multiple-choice option.

    Surrounding newlines, trailing "." / "/", spaces and a leading ":" are
    trimmed. If standalone letters A-E appear (case-insensitively), the last
    one is returned in upper case. Otherwise the trimmed text is returned
    without surrounding whitespace, dots or trailing slashes.
    """
    cleaned = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    letters = re.findall(r"\b(A|B|C|D|E)\b", cleaned.upper())
    if letters:
        candidate = letters[-1]
    else:
        candidate = cleaned.strip().strip(".")
    # Trailing punctuation may remain on the fallback path; trim it again.
    return candidate.rstrip(".").rstrip("/")
| |
|
| |
|
def parse_digits(num):
    r"""Parse *num* as a float, understanding thousands separators and percents.

    Commas are removed first. If the text is not a float but ends with "%"
    (optionally preceded by a LaTeX escape backslash, as in ``50\%``), the
    number before it is divided by 100.

    Args:
        num: any value; it is converted with ``str`` first.

    Returns:
        The parsed float, or ``None`` when the text is not numeric.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        pass

    if text.endswith("%"):
        text = text[:-1]
        if text.endswith("\\"):
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
| |
|
| |
|
def is_digit(num):
    """Return True when ``parse_digits`` can turn *num* into a float."""
    parsed = parse_digits(num)
    return parsed is not None
| |
|
| |
|
def str_to_pmatrix(input_str):
    r"""Rewrite brace-delimited vectors such as ``{1,2}`` as pmatrix LaTeX.

    Each comma inside the braces becomes a LaTeX row separator ``\\``, giving
    ``\begin{pmatrix}1\\2\end{pmatrix}``. ``math_equal`` splits matrix rows on
    ``\\``; the previous single-backslash separator meant converted references
    never split into rows.

    Args:
        input_str: text containing ``{...,...}`` groups. The regex is greedy,
            so it spans from the first "{" to the last "}" that has a comma
            somewhere between them.

    Returns:
        The converted matrices joined by ", ", or "" when nothing matched.
    """
    input_str = input_str.strip()
    matrix_strs = re.findall(r"\{.*,.*\}", input_str)
    pmatrix_list = [
        r"\begin{pmatrix}" + m.strip("{}").replace(",", "\\\\") + r"\end{pmatrix}"
        for m in matrix_strs
    ]
    return ", ".join(pmatrix_list)
| |
|
| |
|
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    Checks, in order: case-insensitive text equality, multiple-choice letter,
    ``fraction_equal``, 2-significant-digit rounding of small numbers, numeric
    comparison (optionally allowing a factor of 100 for percentages), bracket-
    insensitive text equality, component-wise tuple / interval comparison,
    element-wise pmatrix/bmatrix comparison, equations, and finally sympy
    ``symbolic_equal``.

    Args:
        prediction: model answer. Non-str values allowed by the type hint are
            converted with ``str`` wherever a string operation is needed.
        reference: ground-truth answer.
        include_percentage: numerically also accept ``reference / 100`` and
            ``reference * 100``.
        is_close: compare numbers with ``numeric_equal`` rather than ``==``.
        timeout: run the final sympy comparison in a subprocess via
            ``call_with_timeout``. Recursive calls do not pass it on.

    Returns:
        True when any check judges the answers equal, else False.
    """
    if prediction is None or reference is None:
        return False
    # str() first: bool/float inputs (allowed by the type hints) have no .strip().
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(str(prediction)) == reference
    ):
        return True

    # Fractions / plain arithmetic, e.g. "\frac{1}{2}" vs "0.5".
    if fraction_equal(str(prediction), str(reference)):
        return True

    # 1. Numerical equality.
    try:
        if round_number(prediction) == round_number(reference):
            return True
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            # Both sides are plain numbers: no symbolic check can do better.
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. Symbolic equality.
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # A "{1,2}"-style reference is converted when the prediction is a pmatrix.
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    # Compare after dropping outer brackets and all braces / parentheses.
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    # [a, b] vs. [c, d]: equal when a == c and b == d.
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            math_equal(pred_part, ref_part, include_percentage, is_close)
            for pred_part, ref_part in zip(pred_parts, ref_parts)
        ):
            return True

    # pmatrix / bmatrix: compare row by row, cell by cell.
    matrix_begins = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_ends = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begins)
        and prediction.endswith(matrix_ends)
        and reference.startswith(matrix_begins)
        and reference.endswith(matrix_ends)
    ):
        def _matrix_rows(matrix):
            # pmatrix and bmatrix markers have equal length, so one slice fits both.
            body = matrix[len("\\begin{pmatrix}"):-len("\\end{pmatrix}")]
            return [line.strip() for line in body.split("\\\\") if line.strip()]

        pred_lines = _matrix_rows(prediction)
        ref_lines = _matrix_rows(reference)
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) != len(ref_parts) or not all(
                    math_equal(pred_part, ref_part, include_percentage, is_close)
                    for pred_part, ref_part in zip(pred_parts, ref_parts)
                ):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        # Compare "lhs - (rhs)" of both equations, allowing a sign flip.
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # e.g. "x=5" vs "5": compare the right-hand side only.
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # Sympy comparison, optionally in a time-limited subprocess.
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
| |
|
| |
|
def numeric_equal(prediction: float, reference: float, rel_tol: float = 1e-4, abs_tol: float = 0.0):
    """Return True when two numbers agree within the given tolerances.

    Delegates to ``math.isclose(reference, prediction, ...)``. With the default
    ``abs_tol=0.0`` the comparison is purely relative, so ``0`` only matches an
    exact ``0``. Pass ``abs_tol`` to accept tiny values near zero.

    Args:
        prediction: predicted number.
        reference: reference number.
        rel_tol: relative tolerance (default 1e-4, the previous hard-coded value).
        abs_tol: absolute tolerance (default 0.0).
    """
    return isclose(reference, prediction, rel_tol=rel_tol, abs_tol=abs_tol)
| |
|
| |
|
def fraction_equal(prediction, reference):
    r"""Return True when two answers denote the same plain arithmetic value.

    Every ``\frac{a}{b}`` (non-nested) is rewritten to ``(a/b)``. The answers
    are then equal if the rewritten strings match, or if both evaluate to the
    same value. Evaluation accepts only numeric literals, unary ``+``/``-``,
    binary ``+ - * /`` and tuples/lists of those. Anything else (names, calls,
    ``^``, ``**``, strings, ...) does not evaluate and can only match textually.

    Args:
        prediction: predicted answer string.
        reference: reference answer string.
    """
    # A restricted AST evaluator instead of eval(): the inputs are
    # model-generated text, and eval() would run arbitrary code and read "^"
    # as bitwise XOR (so "2^10" evaluated to 8).
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _evaluate(node):
        # bool is excluded on purpose: type(True) is bool, not int.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float, complex):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](_evaluate(node.left), _evaluate(node.right))
        if isinstance(node, ast.Tuple):
            return tuple(_evaluate(elt) for elt in node.elts)
        if isinstance(node, ast.List):
            return [_evaluate(elt) for elt in node.elts]
        raise ValueError("unsupported expression")

    def _calculate_numbers(input_string):
        # Best effort: any parse/evaluation failure (syntax, unsupported
        # node, division by zero, overflow, ...) means "not a number".
        try:
            return _evaluate(ast.parse(input_string.strip(), mode="eval").body)
        except Exception:
            return None

    reference = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', reference)
    prediction = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', prediction)

    if reference == prediction:
        return True

    reference = _calculate_numbers(reference)
    prediction = _calculate_numbers(prediction)

    # "is not None" so a reference that evaluates to 0 can still match.
    return reference is not None and reference == prediction
| |
|
def symbolic_equal(a, b):
    """Return True when *a* and *b* are judged equal after sympy parsing.

    Each input is parsed with the first of ``parse_latex``, ``parse_expr`` and
    ``latex2sympy`` that succeeds. Each parser is tried with doubled
    backslashes collapsed, then on the raw text. If all fail, the raw string
    is kept. The results are then compared by, in order:

    - string or ``==`` equality;
    - ``equals`` or ``simplify(a - b) == 0``;
    - ``|lhs - rhs|`` of equations;
    - numeric evaluation via ``numeric_equal``;
    - element-wise comparison of same-shape matrices rounded to 3 decimals.

    Any exception inside a parser or a check just moves on to the next one.

    NOTE(review): sympy's ``parse_expr`` evaluates its input with ``eval``.
    Callers pass model-generated text here, so this is not safe on untrusted
    input; run it sandboxed.
    """
    def _parse(s):
        for f in [parse_latex, parse_expr, latex2sympy]:
            try:
                return f(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return f(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    checks = (
        # direct equality
        lambda: str(a) == str(b) or a == b,
        # simplification
        lambda: a.equals(b) or simplify(a - b) == 0,
        # equations: same |lhs - rhs|
        lambda: (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)),
        # numeric evaluation
        lambda: numeric_equal(float(N(a)), float(N(b))),
        # matrices of equal shape, rounded element-wise to 3 decimals
        lambda: a.shape == b.shape
        and a.applyfunc(lambda x: round(x, 3)).equals(b.applyfunc(lambda x: round(x, 3))),
    )
    for check in checks:
        # "except Exception" rather than a bare except, so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            if check():
                return True
        except Exception:
            continue

    return False
| |
|
| |
|
def symbolic_equal_process(a, b, output_queue):
    """Subprocess target for ``call_with_timeout``: queue ``symbolic_equal(a, b)``."""
    output_queue.put(symbolic_equal(a, b))
| |
|
| |
|
def math_equal_process(prediction, reference, output_queue):
    """Subprocess target for ``call_with_timeout``: queue the ``math_equal`` verdict.

    ``timeout=True`` makes ``math_equal`` run its final sympy comparison
    through ``call_with_timeout`` as well.
    """
    verdict = math_equal(prediction, reference, timeout=True)
    output_queue.put(verdict)
| |
|
| |
|
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, output_queue, **kwargs)`` in a subprocess with a time limit.

    *func* must put its result on the queue it receives as its last positional
    argument.

    Args:
        func: subprocess target, e.g. ``math_equal_process``.
        *args: positional arguments passed before the queue.
        timeout: seconds to wait for the subprocess, and then for its result.
        **kwargs: keyword arguments for *func*.

    Returns:
        The first value *func* put on the queue. Returns False when the
        subprocess is still running after *timeout* (it is terminated), exits
        with a non-zero code, or exits without producing a result.
    """
    import queue

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    # A crashed child never puts a result; an unbounded get() would block forever.
    if process.exitcode != 0:
        return False
    try:
        return output_queue.get(timeout=timeout)
    except queue.Empty:
        return False
| |
|
| |
|
def check_correctness_of_multiple_answer_cases(prediction, reference, all_matches):
    """Fallback grading for answers that may consist of several parts.

    Returns True in three cases:
    - the answers match after dropping "," and "$";
    - they agree after the last "=" and the reference has no "," / "or" /
      "and";
    - the reference splits (on the first "or", else "and", else ",") into
      two parts, and those parts match the last two boxed answers in either
      order.

    Args:
        prediction: cleaned final boxed answer.
        reference: cleaned reference answer.
        all_matches: raw contents of every ``\\boxed{...}`` in the output;
            the second-to-last one is used as the other part.
    """
    def _equal(pred, ref):
        # Same time-limited path as is_equal, so a slow sympy call cannot
        # stall grading.
        return call_with_timeout(math_equal_process, pred, ref)

    if prediction.replace(",", "").replace("$", "") == reference.replace(",", "").replace("$", ""):
        return True

    if prediction.split("=")[-1] != reference.split("=")[-1].replace("$", ""):
        return False

    if not ("," in reference or "or" in reference or "and" in reference):
        return True

    # Multiple answers are expected, so at least two boxed answers are needed.
    if len(all_matches) <= 1:
        return False

    prediction1 = prediction.split("=")[-1]
    prediction2 = all_matches[-2].split("=")[-1]
    reference = reference.replace("$", "")
    if "or" in reference:
        gold_list = reference.split("or", 1)
    elif "and" in reference:
        gold_list = reference.split("and", 1)
    else:
        gold_list = reference.split(",", 1)

    reference1 = gold_list[-1].split("=")[-1]
    reference2 = gold_list[-2].split("=")[-1]

    if _equal(prediction1, reference1) and _equal(prediction2, reference2):
        return True
    if _equal(prediction2, reference1) and _equal(prediction1, reference2):
        return True
    return False
| |
|
| |
|
def is_equal(model_output, reference, dataset_name):
    """Grade *model_output* against *reference* for *dataset_name*.

    Steps:
    1. Take the last ``\\boxed{...}`` from the output.
    2. Clean it and the reference with ``math_answer_cleaning``.
    3. Compare them with ``math_equal`` in a time-limited subprocess.
    4. For "collegemath" only, fall back to
       ``check_correctness_of_multiple_answer_cases``.

    Returns:
        True when judged correct. False when there is no boxed answer, the
        reference is None, or no check succeeds.
    """
    extracted_model_answer, all_matches = extract_final_answer(model_output)
    if extracted_model_answer is None or reference is None:
        return False

    # References may be numbers (e.g. gsm8k labels); the cleaner needs text.
    reference = str(reference)
    extracted_model_answer = math_answer_cleaning(extracted_model_answer, dataset_name)
    reference = math_answer_cleaning(reference, dataset_name)

    if call_with_timeout(math_equal_process, extracted_model_answer, reference):
        return True

    if dataset_name == "collegemath":
        return check_correctness_of_multiple_answer_cases(extracted_model_answer, reference, all_matches)

    return False
| |
|