# text2sql_tani / src / execution_reward.py
# (Hugging Face upload header: tjhalanigrid — "Added full project", commit cf17729)
# =========================================================
# LEGACY IMPLEMENTATION (superseded; kept commented out for reference)
# =========================================================
# from __future__ import annotations
# import hashlib
# import os
# import queue
# import re
# import sqlite3
# import threading
# import time
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from dataclasses import dataclass
# from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
# # --- CACHE CONTROL ---
# USE_CACHE = True
# _REWARD_CACHE: Dict[str, float] = {}
# def set_use_cache(enabled: bool):
# """Dynamically toggle the reward cache for benchmarks."""
# global USE_CACHE
# USE_CACHE = enabled
# def _normalize_sql(sql: str) -> str:
# if not isinstance(sql, str):
# return ""
# s = sql.strip()
# if s.startswith("```"):
# s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
# s = re.sub(r"\n?```$", "", s).strip()
# if s.lower().startswith("sql:"):
# s = s[4:].strip()
# if ";" in s:
# s = s.split(";", 1)[0].strip()
# return s
# def _connect_readonly(db_path: str) -> sqlite3.Connection:
# uri = f"file:{os.path.abspath(db_path)}?mode=ro"
# conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
# conn.execute("PRAGMA query_only = ON;")
# conn.execute("PRAGMA foreign_keys = ON;")
# return conn
# DEFAULT_QUERY_TIMEOUT_S = 2.0
# def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None:
# start = time.monotonic()
# def _handler() -> int:
# return 1 if (time.monotonic() - start) > timeout_s else 0
# conn.set_progress_handler(_handler, 10_000)
# def _list_tables(conn: sqlite3.Connection) -> List[str]:
# try:
# cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
# return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
# except sqlite3.Error:
# return []
# def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool:
# s = sql.lower()
# for t in table_names:
# tl = t.lower()
# if not tl:
# continue
# if re.search(rf"\b{re.escape(tl)}\b", s):
# return True
# return False
# def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
# try:
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
# conn.execute(f"EXPLAIN QUERY PLAN {sql}")
# return True
# except sqlite3.Error:
# return False
# def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]:
# try:
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
# cur = conn.execute(sql)
# rows = cur.fetchmany(max_rows)
# norm_rows = [tuple(r) for r in rows]
# return True, norm_rows, None
# except sqlite3.Error as e:
# return False, [], str(e)
# _SQL_KEYWORDS_TO_IGNORE = {
# "select", "from", "where", "join", "inner", "left", "right", "full", "outer",
# "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect",
# "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case",
# "when", "then", "else", "end", "asc", "desc"
# }
# _SQL_FUNCTIONS_TO_IGNORE = {
# "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce",
# "round", "date", "datetime", "strftime"
# }
# # --- LIGHTWEIGHT PARSING ---
# def is_valid_select(sql: str):
# sql = sql.strip().lower()
# return sql.startswith("select") or sql.startswith("with")
# def extract_tables(sql: str) -> List[str]:
# sql = sql.lower()
# if "join" not in sql:
# tables = re.findall(r'from\s+(\w+)', sql)
# return list(set(tables))
# tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
# joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql)
# return list(set(tables + joins))
# def extract_columns(sql: str) -> List[str]:
# sql = sql.lower()
# match = re.search(r'select\s+(.*?)\s+from', sql)
# if not match:
# return []
# cols = match.group(1)
# if cols.strip() == "*":
# return ["*"]
# return [c.strip() for c in cols.split(",")]
# def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]:
# tables = set()
# columns = set()
# for t in _list_tables(conn):
# tl = t.lower()
# if not tl:
# continue
# tables.add(tl)
# try:
# cur = conn.execute(f'PRAGMA table_info("{t}")')
# for row in cur.fetchall():
# if row and isinstance(row[1], str):
# columns.add(row[1].lower())
# except sqlite3.Error:
# continue
# return tables, columns
# def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool:
# return a == b
# @dataclass
# class RewardDebugStats:
# total: int = 0
# parsed_ok: int = 0
# table_match: int = 0
# column_match: int = 0
# executed_ok: int = 0
# exact_match: int = 0
# _DEBUG = RewardDebugStats()
# def reset_debug_metrics() -> None:
# global _DEBUG
# _DEBUG = RewardDebugStats()
# def get_debug_metrics() -> dict:
# denom = max(_DEBUG.total, 1)
# return {
# "valid_sql_rate": _DEBUG.parsed_ok / denom,
# "table_match_rate": _DEBUG.table_match / denom,
# "column_match_rate": _DEBUG.column_match / denom,
# "execution_accuracy": _DEBUG.exact_match / denom,
# }
# EXECUTION_ERROR = "EXECUTION_ERROR"
# _RESULT_CACHE_LOCK = threading.Lock()
# _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {}
# _RESULT_CACHE_MAX = 100_000
# def clear_result_cache() -> None:
# """Clear both DB query cache and reward cache."""
# with _RESULT_CACHE_LOCK:
# _RESULT_CACHE.clear()
# _REWARD_CACHE.clear()
# def _db_state_fingerprint(db_path: str) -> str:
# try:
# st = os.stat(db_path)
# return f"{st.st_mtime_ns}:{st.st_size}"
# except OSError:
# return "missing"
# def _result_cache_key(db_path: str, sql: str) -> str:
# fp = _db_state_fingerprint(db_path)
# payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore")
# return hashlib.sha256(payload).hexdigest()
# class _ConnectionPool:
# def __init__(self, db_path: str, maxsize: int = 1) -> None:
# self.db_path = db_path
# self.pool = queue.LifoQueue(maxsize=maxsize)
# self.lock = threading.Lock()
# def acquire(self) -> sqlite3.Connection:
# try:
# return self.pool.get_nowait()
# except queue.Empty:
# with self.lock:
# try:
# return self.pool.get_nowait()
# except queue.Empty:
# return _connect_readonly(self.db_path)
# def release(self, conn: sqlite3.Connection) -> None:
# try:
# self.pool.put_nowait(conn)
# except queue.Full:
# try:
# conn.close()
# except Exception:
# pass
# _POOL_LOCK = threading.Lock()
# _POOLS: Dict[str, _ConnectionPool] = {}
# def _get_pool(db_path: str) -> _ConnectionPool:
# with _POOL_LOCK:
# pool = _POOLS.get(db_path)
# if pool is None:
# pool = _ConnectionPool(db_path=db_path, maxsize=1)
# _POOLS[db_path] = pool
# return pool
# class _PooledConnection:
# def __init__(self, db_path: str) -> None:
# self.db_path = db_path
# self.pool = _get_pool(db_path)
# self.conn: Optional[sqlite3.Connection] = None
# def __enter__(self) -> sqlite3.Connection:
# self.conn = self.pool.acquire()
# return self.conn
# def __exit__(self, exc_type, exc, tb) -> None:
# if self.conn is not None:
# self.pool.release(self.conn)
# self.conn = None
# def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]:
# with _RESULT_CACHE_LOCK:
# return _RESULT_CACHE.get(key)
# def _cache_put(key: str, value: Union[List[Tuple], str]) -> None:
# with _RESULT_CACHE_LOCK:
# if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX:
# _RESULT_CACHE.clear()
# _RESULT_CACHE[key] = value
# def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
# try:
# _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
# cur = conn.execute(sql)
# rows = cur.fetchmany(max_rows)
# return [tuple(r) for r in rows]
# except Exception:
# return EXECUTION_ERROR
# def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]:
# if not USE_CACHE:
# with _PooledConnection(db_path) as conn:
# return execute_sql(conn, sql, max_rows=max_rows)
# key = _result_cache_key(db_path, sql)
# cached = _cache_get(key)
# if cached is not None:
# return cached
# with _PooledConnection(db_path) as conn:
# res = execute_sql(conn, sql, max_rows=max_rows)
# _cache_put(key, res)
# return res
# def execution_reward_timed(
# pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False,
# ) -> Tuple[float, Dict[str, float]]:
# timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
# t0 = time.perf_counter()
# sql = _normalize_sql(pred_sql)
# gold = _normalize_sql(gold_sql)
# if not is_valid_select(sql):
# timings["parse_s"] = time.perf_counter() - t0
# return 0.0, timings
# t1 = time.perf_counter()
# timings["parse_s"] = t1 - t0
# if measure_plan:
# with _PooledConnection(db_path) as conn:
# p0 = time.perf_counter()
# _explain_query_plan(conn, sql)
# _explain_query_plan(conn, gold)
# timings["plan_s"] = time.perf_counter() - p0
# e0 = time.perf_counter()
# pred_res = execute_sql_cached(db_path, sql)
# if pred_res == EXECUTION_ERROR:
# timings["exec_s"] = time.perf_counter() - e0
# return 0.0, timings
# gold_res = execute_sql_cached(db_path, gold)
# timings["exec_s"] = time.perf_counter() - e0
# if gold_res == EXECUTION_ERROR:
# return 0.0, timings
# reward = -0.2
# reward += 0.2
# if _safe_results_equal(pred_res, gold_res):
# return 1.0, timings
# return max(-1.0, min(1.0, reward)), timings
# def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
# try:
# sql = _normalize_sql(pred_sql)
# gold = _normalize_sql(gold_sql)
# if not is_valid_select(sql):
# return -1.0
# reward = -0.2
# pred_tables = set(extract_tables(sql))
# gold_tables = set(extract_tables(gold))
# if pred_tables == gold_tables and len(gold_tables) > 0:
# reward += 0.3
# pred_cols = set(extract_columns(sql))
# gold_cols = set(extract_columns(gold))
# if gold_cols:
# overlap = len(pred_cols & gold_cols) / len(gold_cols)
# reward += 0.3 * overlap
# pred_res = execute_sql_cached(db_path, sql)
# if pred_res == EXECUTION_ERROR:
# return 0.0
# reward += 0.2
# gold_res = execute_sql_cached(db_path, gold)
# if gold_res == EXECUTION_ERROR:
# return 0.0
# if _safe_results_equal(pred_res, gold_res):
# return 1.0
# return max(-1.0, min(1.0, reward))
# except Exception:
# return 0.0
# def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
# if not USE_CACHE:
# return execution_reward(pred_sql, db_path, gold_sql)
# key = f"{db_path}|{pred_sql}|{gold_sql}"
# if key not in _REWARD_CACHE:
# _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql)
# return _REWARD_CACHE[key]
# def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]:
# return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts]
# def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]:
# if not rollouts:
# return []
# unique_dbs = {db_path for _, db_path, _ in rollouts}
# worker_count = max(1, min(max_workers, len(unique_dbs)))
# results: List[Optional[float]] = [None] * len(rollouts)
# with ThreadPoolExecutor(max_workers=worker_count) as executor:
# futures = {
# executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i
# for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts)
# }
# for fut in as_completed(futures):
# idx = futures[fut]
# try:
# results[idx] = float(fut.result())
# except Exception:
# results[idx] = 0.0
# return [r if r is not None else 0.0 for r in results]
from __future__ import annotations
import os
import re
import sqlite3
import threading
import time
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Dict, List
from src.sql_validator import validate_sql_schema
# =========================================================
# 🔥 CONFIG FLAGS
# =========================================================
USE_SCHEMA_VALIDATION = True
USE_CACHE = True
DEFAULT_QUERY_TIMEOUT_S = 2.0
EXECUTION_ERROR = "EXECUTION_ERROR"
_REWARD_CACHE: Dict[str, float] = {}
# =========================================================
# 🔥 TASK 2: ERROR ANALYSIS + LOGGING
# =========================================================
ERROR_LOG_FILE = "results/error_logs.json"
def classify_error(sql: str) -> str:
sql = sql.lower()
if "join" in sql and " on " not in sql:
return "missing_join"
if "where" in sql and "=" not in sql and ">" not in sql and "<" not in sql:
return "wrong_where"
if "null" in sql:
return "null_handling"
if "group by" in sql and "count" not in sql:
return "wrong_groupby"
return "other"
def get_hint(error_type: str) -> str:
hints = {
"missing_join": "Add proper JOIN condition using ON.",
"wrong_where": "Check WHERE clause conditions.",
"null_handling": "Handle NULL values using IS NULL.",
"wrong_groupby": "Use aggregation functions with GROUP BY.",
"other": "Check SQL syntax and logic."
}
return hints.get(error_type, "Check query.")
def log_error(question: str, sql: str, error: str, error_type: str):
os.makedirs("results", exist_ok=True)
entry = {
"question": question,
"sql": sql,
"error": error,
"error_type": error_type,
"timestamp": time.time()
}
if os.path.exists(ERROR_LOG_FILE):
with open(ERROR_LOG_FILE, "r") as f:
logs = json.load(f)
else:
logs = []
logs.append(entry)
with open(ERROR_LOG_FILE, "w") as f:
json.dump(logs, f, indent=2)
# =========================================================
# CACHE/VALIDATION TOGGLES (Task 1)
# =========================================================
def set_use_cache(enabled: bool) -> None:
global USE_CACHE
USE_CACHE = bool(enabled)
def set_use_schema_validation(enabled: bool) -> None:
global USE_SCHEMA_VALIDATION
USE_SCHEMA_VALIDATION = bool(enabled)
# =========================================================
# SQL CLEANING
# =========================================================
def _normalize_sql(sql: str) -> str:
if not isinstance(sql, str):
return ""
s = sql.strip()
if s.startswith("```"):
s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip()
s = re.sub(r"\n?```$", "", s).strip()
if s.lower().startswith("sql:"):
s = s[4:].strip()
if ";" in s:
s = s.split(";", 1)[0].strip()
return s
# =========================================================
# DB EXECUTION
# =========================================================
def _connect_readonly(db_path: str):
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
conn.execute("PRAGMA query_only = ON;")
conn.execute("PRAGMA foreign_keys = ON;")
return conn
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S):
start = time.monotonic()
def handler():
return 1 if (time.monotonic() - start) > timeout_s else 0
conn.set_progress_handler(handler, 10_000)
def execute_sql(conn, sql):
try:
_with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
cur = conn.execute(sql)
return cur.fetchall()
except Exception:
return EXECUTION_ERROR
_RESULT_CACHE = {}
_RESULT_LOCK = threading.Lock()
def execute_sql_cached(db_path, sql):
key = f"{db_path}|{sql}"
if USE_CACHE:
with _RESULT_LOCK:
if key in _RESULT_CACHE:
return _RESULT_CACHE[key]
conn = _connect_readonly(db_path)
result = execute_sql(conn, sql)
conn.close()
if USE_CACHE:
with _RESULT_LOCK:
_RESULT_CACHE[key] = result
return result
def execute_sql_cached_conn(conn: sqlite3.Connection, db_path: str, sql: str):
"""
Like execute_sql_cached(), but reuses an existing connection.
Intended for 1-thread-per-DB workloads (Task 1).
"""
key = f"{db_path}|{sql}"
if USE_CACHE:
with _RESULT_LOCK:
if key in _RESULT_CACHE:
return _RESULT_CACHE[key]
result = execute_sql(conn, sql)
if USE_CACHE:
with _RESULT_LOCK:
_RESULT_CACHE[key] = result
return result
def clear_result_cache() -> None:
global _RESULT_CACHE, _REWARD_CACHE
with _RESULT_LOCK:
_RESULT_CACHE.clear()
_REWARD_CACHE.clear()
# =========================================================
# SQL PARSING
# =========================================================
def is_valid_select(sql):
return sql.lower().startswith("select") or sql.lower().startswith("with")
def extract_tables(sql):
return re.findall(r'from\s+(\w+)', sql.lower())
def extract_columns(sql):
match = re.search(r'select\s+(.*?)\s+from', sql.lower())
if not match:
return []
cols = match.group(1)
return ["*"] if cols.strip() == "*" else [c.strip() for c in cols.split(",")]
def get_sql_operations(sql: str):
sql = sql.lower()
ops = []
if "select" in sql: ops.append("SELECT")
if "where" in sql: ops.append("WHERE")
if "join" in sql: ops.append("JOIN")
if "group by" in sql: ops.append("GROUP_BY")
if "order by" in sql: ops.append("ORDER_BY")
return ops
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
try:
_with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
conn.execute(f"EXPLAIN QUERY PLAN {sql}")
return True
except Exception:
return False
def execution_reward_timed(pred_sql: str, db_path: str, gold_sql: str, measure_plan: bool = False):
"""
Returns (reward, timings) where timings keys: parse_s, plan_s, exec_s.
Used by Task-1 benchmark to profile bottlenecks.
"""
timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}
t0 = time.perf_counter()
sql = _normalize_sql(pred_sql)
gold = _normalize_sql(gold_sql)
if not is_valid_select(sql):
timings["parse_s"] = time.perf_counter() - t0
return 0.0, timings
t1 = time.perf_counter()
timings["parse_s"] = t1 - t0
conn = _connect_readonly(db_path)
try:
if measure_plan:
p0 = time.perf_counter()
_explain_query_plan(conn, sql)
_explain_query_plan(conn, gold)
timings["plan_s"] = time.perf_counter() - p0
e0 = time.perf_counter()
pred_res = execute_sql_cached_conn(conn, db_path, sql)
if pred_res == EXECUTION_ERROR:
timings["exec_s"] = time.perf_counter() - e0
return 0.0, timings
gold_res = execute_sql_cached_conn(conn, db_path, gold)
timings["exec_s"] = time.perf_counter() - e0
if gold_res == EXECUTION_ERROR:
return 0.0, timings
reward = -0.2 + 0.2
if pred_res == gold_res:
return 1.0, timings
return max(-1.0, min(1.0, reward)), timings
finally:
try:
conn.close()
except Exception:
pass
# =========================================================
# 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED)
# =========================================================
def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
try:
sql = _normalize_sql(pred_sql)
gold = _normalize_sql(gold_sql)
if not is_valid_select(sql):
return -1.0
reward = -0.2
# =========================
# SCHEMA VALIDATION (Task 3)
# =========================
if USE_SCHEMA_VALIDATION:
valid, _ = validate_sql_schema(sql, db_path)
if not valid:
error_type = classify_error(sql)
log_error("UNKNOWN", sql, "schema_invalid", error_type)
return 0.1
# =========================
# EXECUTION
# =========================
pred_res = execute_sql_cached(db_path, sql)
if pred_res == "EXECUTION_ERROR":
error_type = classify_error(sql)
log_error(
question="UNKNOWN",
sql=sql,
error="execution_error",
error_type=error_type
)
print(f"[ERROR] {error_type}")
print(f"[HINT] {get_hint(error_type)}")
return 0.1
reward += 0.2
gold_res = execute_sql_cached(db_path, gold)
if gold_res == "EXECUTION_ERROR":
return 0.1
if pred_res == gold_res:
return 1.0
return max(-1.0, min(1.0, reward))
except Exception as e:
log_error("UNKNOWN", pred_sql, str(e), "runtime_error")
return 0.0
# =========================================================
# BATCH EXECUTION (Task 1)
# =========================================================
def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
if not USE_CACHE:
return float(execution_reward(pred_sql, db_path, gold_sql))
key = f"{db_path}|{pred_sql}|{gold_sql}"
if key in _REWARD_CACHE:
return float(_REWARD_CACHE[key])
r = float(execution_reward(pred_sql, db_path, gold_sql))
_REWARD_CACHE[key] = r
return r
def execution_reward_batch_sequential(rollouts):
return [cached_execution_reward(p, d, g) for (p, d, g) in rollouts]
def execution_reward_batch_parallel(rollouts, max_workers=10):
results = [0.0] * len(rollouts)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(cached_execution_reward, p, d, g): i
for i, (p, d, g) in enumerate(rollouts)
}
for fut in as_completed(futures):
idx = futures[fut]
try:
results[idx] = fut.result()
except Exception:
results[idx] = 0.0
return results
def execution_reward_batch_parallel_by_db(rollouts, max_workers: int = 20):
"""
1 thread per DB path. Reuses a single readonly connection per DB worker.
Preserves input order.
"""
if not rollouts:
return []
by_db = {}
for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))
results = [0.0 for _ in range(len(rollouts))]
def _reward_with_conn(conn: sqlite3.Connection, pred_sql: str, db_path: str, gold_sql: str) -> float:
try:
sql = _normalize_sql(pred_sql)
gold = _normalize_sql(gold_sql)
if not is_valid_select(sql):
return -1.0
reward = -0.2
if USE_SCHEMA_VALIDATION:
valid, _ = validate_sql_schema(sql, db_path)
if not valid:
error_type = classify_error(sql)
log_error("UNKNOWN", sql, "schema_invalid", error_type)
return 0.1
pred_res = execute_sql_cached_conn(conn, db_path, sql)
if pred_res == EXECUTION_ERROR:
error_type = classify_error(sql)
log_error("UNKNOWN", sql, "execution_error", error_type)
return 0.1
reward += 0.2
gold_res = execute_sql_cached_conn(conn, db_path, gold)
if gold_res == EXECUTION_ERROR:
return 0.1
if pred_res == gold_res:
return 1.0
return max(-1.0, min(1.0, reward))
except Exception:
return 0.0
def _worker(db_path: str, items):
conn = _connect_readonly(db_path)
try:
for idx, pred, gold in items:
results[idx] = _reward_with_conn(conn, pred, db_path, gold)
finally:
try:
conn.close()
except Exception:
pass
with ThreadPoolExecutor(max_workers=int(max_workers)) as ex:
futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
for fut in as_completed(futures):
fut.result()
return results