| # from __future__ import annotations | |
| # import hashlib | |
| # import os | |
| # import queue | |
| # import re | |
| # import sqlite3 | |
| # import threading | |
| # import time | |
| # from concurrent.futures import ThreadPoolExecutor, as_completed | |
| # from dataclasses import dataclass | |
| # from typing import Dict, List, Optional, Sequence, Set, Tuple, Union | |
| # # --- CACHE CONTROL --- | |
| # USE_CACHE = True | |
| # _REWARD_CACHE: Dict[str, float] = {} | |
| # def set_use_cache(enabled: bool): | |
| # """Dynamically toggle the reward cache for benchmarks.""" | |
| # global USE_CACHE | |
| # USE_CACHE = enabled | |
| # def _normalize_sql(sql: str) -> str: | |
| # if not isinstance(sql, str): | |
| # return "" | |
| # s = sql.strip() | |
| # if s.startswith("```"): | |
| # s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip() | |
| # s = re.sub(r"\n?```$", "", s).strip() | |
| # if s.lower().startswith("sql:"): | |
| # s = s[4:].strip() | |
| # if ";" in s: | |
| # s = s.split(";", 1)[0].strip() | |
| # return s | |
| # def _connect_readonly(db_path: str) -> sqlite3.Connection: | |
| # uri = f"file:{os.path.abspath(db_path)}?mode=ro" | |
| # conn = sqlite3.connect(uri, uri=True, check_same_thread=False) | |
| # conn.execute("PRAGMA query_only = ON;") | |
| # conn.execute("PRAGMA foreign_keys = ON;") | |
| # return conn | |
| # DEFAULT_QUERY_TIMEOUT_S = 2.0 | |
| # def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S) -> None: | |
| # start = time.monotonic() | |
| # def _handler() -> int: | |
| # return 1 if (time.monotonic() - start) > timeout_s else 0 | |
| # conn.set_progress_handler(_handler, 10_000) | |
| # def _list_tables(conn: sqlite3.Connection) -> List[str]: | |
| # try: | |
| # cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';") | |
| # return [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)] | |
| # except sqlite3.Error: | |
| # return [] | |
| # def _contains_table_name(sql: str, table_names: Sequence[str]) -> bool: | |
| # s = sql.lower() | |
| # for t in table_names: | |
| # tl = t.lower() | |
| # if not tl: | |
| # continue | |
| # if re.search(rf"\b{re.escape(tl)}\b", s): | |
| # return True | |
| # return False | |
| # def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool: | |
| # try: | |
| # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S) | |
| # conn.execute(f"EXPLAIN QUERY PLAN {sql}") | |
| # return True | |
| # except sqlite3.Error: | |
| # return False | |
| # def _execute(conn: sqlite3.Connection, sql: str, max_rows: int = 1000) -> Tuple[bool, List[Tuple], Optional[str]]: | |
| # try: | |
| # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S) | |
| # cur = conn.execute(sql) | |
| # rows = cur.fetchmany(max_rows) | |
| # norm_rows = [tuple(r) for r in rows] | |
| # return True, norm_rows, None | |
| # except sqlite3.Error as e: | |
| # return False, [], str(e) | |
| # _SQL_KEYWORDS_TO_IGNORE = { | |
| # "select", "from", "where", "join", "inner", "left", "right", "full", "outer", | |
| # "on", "group", "by", "order", "limit", "having", "distinct", "union", "intersect", | |
| # "except", "as", "and", "or", "not", "in", "is", "null", "like", "between", "case", | |
| # "when", "then", "else", "end", "asc", "desc" | |
| # } | |
| # _SQL_FUNCTIONS_TO_IGNORE = { | |
| # "count", "avg", "min", "max", "sum", "lower", "upper", "substr", "coalesce", | |
| # "round", "date", "datetime", "strftime" | |
| # } | |
| # # --- LIGHTWEIGHT PARSING --- | |
| # def is_valid_select(sql: str): | |
| # sql = sql.strip().lower() | |
| # return sql.startswith("select") or sql.startswith("with") | |
| # def extract_tables(sql: str) -> List[str]: | |
| # sql = sql.lower() | |
| # if "join" not in sql: | |
| # tables = re.findall(r'from\s+(\w+)', sql) | |
| # return list(set(tables)) | |
| # tables = re.findall(r'from\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql) | |
| # joins = re.findall(r'join\s+([a-zA-Z_][a-zA-Z0-9_]*)', sql) | |
| # return list(set(tables + joins)) | |
| # def extract_columns(sql: str) -> List[str]: | |
| # sql = sql.lower() | |
| # match = re.search(r'select\s+(.*?)\s+from', sql) | |
| # if not match: | |
| # return [] | |
| # cols = match.group(1) | |
| # if cols.strip() == "*": | |
| # return ["*"] | |
| # return [c.strip() for c in cols.split(",")] | |
| # def _get_db_tables_and_columns(conn: sqlite3.Connection) -> Tuple[Set[str], Set[str]]: | |
| # tables = set() | |
| # columns = set() | |
| # for t in _list_tables(conn): | |
| # tl = t.lower() | |
| # if not tl: | |
| # continue | |
| # tables.add(tl) | |
| # try: | |
| # cur = conn.execute(f'PRAGMA table_info("{t}")') | |
| # for row in cur.fetchall(): | |
| # if row and isinstance(row[1], str): | |
| # columns.add(row[1].lower()) | |
| # except sqlite3.Error: | |
| # continue | |
| # return tables, columns | |
| # def _safe_results_equal(a: List[Tuple], b: List[Tuple]) -> bool: | |
| # return a == b | |
| # @dataclass | |
| # class RewardDebugStats: | |
| # total: int = 0 | |
| # parsed_ok: int = 0 | |
| # table_match: int = 0 | |
| # column_match: int = 0 | |
| # executed_ok: int = 0 | |
| # exact_match: int = 0 | |
| # _DEBUG = RewardDebugStats() | |
| # def reset_debug_metrics() -> None: | |
| # global _DEBUG | |
| # _DEBUG = RewardDebugStats() | |
| # def get_debug_metrics() -> dict: | |
| # denom = max(_DEBUG.total, 1) | |
| # return { | |
| # "valid_sql_rate": _DEBUG.parsed_ok / denom, | |
| # "table_match_rate": _DEBUG.table_match / denom, | |
| # "column_match_rate": _DEBUG.column_match / denom, | |
| # "execution_accuracy": _DEBUG.exact_match / denom, | |
| # } | |
| # EXECUTION_ERROR = "EXECUTION_ERROR" | |
| # _RESULT_CACHE_LOCK = threading.Lock() | |
| # _RESULT_CACHE: "Dict[str, Union[List[Tuple], str]]" = {} | |
| # _RESULT_CACHE_MAX = 100_000 | |
| # def clear_result_cache() -> None: | |
| # """Clear both DB query cache and reward cache.""" | |
| # with _RESULT_CACHE_LOCK: | |
| # _RESULT_CACHE.clear() | |
| # _REWARD_CACHE.clear() | |
| # def _db_state_fingerprint(db_path: str) -> str: | |
| # try: | |
| # st = os.stat(db_path) | |
| # return f"{st.st_mtime_ns}:{st.st_size}" | |
| # except OSError: | |
| # return "missing" | |
| # def _result_cache_key(db_path: str, sql: str) -> str: | |
| # fp = _db_state_fingerprint(db_path) | |
| # payload = f"{fp}\0{sql}".encode("utf-8", errors="ignore") | |
| # return hashlib.sha256(payload).hexdigest() | |
| # class _ConnectionPool: | |
| # def __init__(self, db_path: str, maxsize: int = 1) -> None: | |
| # self.db_path = db_path | |
| # self.pool = queue.LifoQueue(maxsize=maxsize) | |
| # self.lock = threading.Lock() | |
| # def acquire(self) -> sqlite3.Connection: | |
| # try: | |
| # return self.pool.get_nowait() | |
| # except queue.Empty: | |
| # with self.lock: | |
| # try: | |
| # return self.pool.get_nowait() | |
| # except queue.Empty: | |
| # return _connect_readonly(self.db_path) | |
| # def release(self, conn: sqlite3.Connection) -> None: | |
| # try: | |
| # self.pool.put_nowait(conn) | |
| # except queue.Full: | |
| # try: | |
| # conn.close() | |
| # except Exception: | |
| # pass | |
| # _POOL_LOCK = threading.Lock() | |
| # _POOLS: Dict[str, _ConnectionPool] = {} | |
| # def _get_pool(db_path: str) -> _ConnectionPool: | |
| # with _POOL_LOCK: | |
| # pool = _POOLS.get(db_path) | |
| # if pool is None: | |
| # pool = _ConnectionPool(db_path=db_path, maxsize=1) | |
| # _POOLS[db_path] = pool | |
| # return pool | |
| # class _PooledConnection: | |
| # def __init__(self, db_path: str) -> None: | |
| # self.db_path = db_path | |
| # self.pool = _get_pool(db_path) | |
| # self.conn: Optional[sqlite3.Connection] = None | |
| # def __enter__(self) -> sqlite3.Connection: | |
| # self.conn = self.pool.acquire() | |
| # return self.conn | |
| # def __exit__(self, exc_type, exc, tb) -> None: | |
| # if self.conn is not None: | |
| # self.pool.release(self.conn) | |
| # self.conn = None | |
| # def _cache_get(key: str) -> Optional[Union[List[Tuple], str]]: | |
| # with _RESULT_CACHE_LOCK: | |
| # return _RESULT_CACHE.get(key) | |
| # def _cache_put(key: str, value: Union[List[Tuple], str]) -> None: | |
| # with _RESULT_CACHE_LOCK: | |
| # if len(_RESULT_CACHE) >= _RESULT_CACHE_MAX: | |
| # _RESULT_CACHE.clear() | |
| # _RESULT_CACHE[key] = value | |
| # def execute_sql(conn: sqlite3.Connection, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]: | |
| # try: | |
| # _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S) | |
| # cur = conn.execute(sql) | |
| # rows = cur.fetchmany(max_rows) | |
| # return [tuple(r) for r in rows] | |
| # except Exception: | |
| # return EXECUTION_ERROR | |
| # def execute_sql_cached(db_path: str, sql: str, *, max_rows: int = 1000) -> Union[List[Tuple], str]: | |
| # if not USE_CACHE: | |
| # with _PooledConnection(db_path) as conn: | |
| # return execute_sql(conn, sql, max_rows=max_rows) | |
| # key = _result_cache_key(db_path, sql) | |
| # cached = _cache_get(key) | |
| # if cached is not None: | |
| # return cached | |
| # with _PooledConnection(db_path) as conn: | |
| # res = execute_sql(conn, sql, max_rows=max_rows) | |
| # _cache_put(key, res) | |
| # return res | |
| # def execution_reward_timed( | |
| # pred_sql: str, db_path: str, gold_sql: str, *, measure_plan: bool = False, | |
| # ) -> Tuple[float, Dict[str, float]]: | |
| # timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0} | |
| # t0 = time.perf_counter() | |
| # sql = _normalize_sql(pred_sql) | |
| # gold = _normalize_sql(gold_sql) | |
| # if not is_valid_select(sql): | |
| # timings["parse_s"] = time.perf_counter() - t0 | |
| # return 0.0, timings | |
| # t1 = time.perf_counter() | |
| # timings["parse_s"] = t1 - t0 | |
| # if measure_plan: | |
| # with _PooledConnection(db_path) as conn: | |
| # p0 = time.perf_counter() | |
| # _explain_query_plan(conn, sql) | |
| # _explain_query_plan(conn, gold) | |
| # timings["plan_s"] = time.perf_counter() - p0 | |
| # e0 = time.perf_counter() | |
| # pred_res = execute_sql_cached(db_path, sql) | |
| # if pred_res == EXECUTION_ERROR: | |
| # timings["exec_s"] = time.perf_counter() - e0 | |
| # return 0.0, timings | |
| # gold_res = execute_sql_cached(db_path, gold) | |
| # timings["exec_s"] = time.perf_counter() - e0 | |
| # if gold_res == EXECUTION_ERROR: | |
| # return 0.0, timings | |
| # reward = -0.2 | |
| # reward += 0.2 | |
| # if _safe_results_equal(pred_res, gold_res): | |
| # return 1.0, timings | |
| # return max(-1.0, min(1.0, reward)), timings | |
| # def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float: | |
| # try: | |
| # sql = _normalize_sql(pred_sql) | |
| # gold = _normalize_sql(gold_sql) | |
| # if not is_valid_select(sql): | |
| # return -1.0 | |
| # reward = -0.2 | |
| # pred_tables = set(extract_tables(sql)) | |
| # gold_tables = set(extract_tables(gold)) | |
| # if pred_tables == gold_tables and len(gold_tables) > 0: | |
| # reward += 0.3 | |
| # pred_cols = set(extract_columns(sql)) | |
| # gold_cols = set(extract_columns(gold)) | |
| # if gold_cols: | |
| # overlap = len(pred_cols & gold_cols) / len(gold_cols) | |
| # reward += 0.3 * overlap | |
| # pred_res = execute_sql_cached(db_path, sql) | |
| # if pred_res == EXECUTION_ERROR: | |
| # return 0.0 | |
| # reward += 0.2 | |
| # gold_res = execute_sql_cached(db_path, gold) | |
| # if gold_res == EXECUTION_ERROR: | |
| # return 0.0 | |
| # if _safe_results_equal(pred_res, gold_res): | |
| # return 1.0 | |
| # return max(-1.0, min(1.0, reward)) | |
| # except Exception: | |
| # return 0.0 | |
| # def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float: | |
| # if not USE_CACHE: | |
| # return execution_reward(pred_sql, db_path, gold_sql) | |
| # key = f"{db_path}|{pred_sql}|{gold_sql}" | |
| # if key not in _REWARD_CACHE: | |
| # _REWARD_CACHE[key] = execution_reward(pred_sql, db_path, gold_sql) | |
| # return _REWARD_CACHE[key] | |
| # def execution_reward_batch_sequential(rollouts: Sequence[Tuple[str, str, str]]) -> List[float]: | |
| # return [cached_execution_reward(pred_sql, db_path, gold_sql) for pred_sql, db_path, gold_sql in rollouts] | |
| # def execution_reward_batch_parallel(rollouts: Sequence[Tuple[str, str, str]], *, max_workers: int = 20) -> List[float]: | |
| # if not rollouts: | |
| # return [] | |
| # unique_dbs = {db_path for _, db_path, _ in rollouts} | |
| # worker_count = max(1, min(max_workers, len(unique_dbs))) | |
| # results: List[Optional[float]] = [None] * len(rollouts) | |
| # with ThreadPoolExecutor(max_workers=worker_count) as executor: | |
| # futures = { | |
| # executor.submit(cached_execution_reward, pred_sql, db_path, gold_sql): i | |
| # for i, (pred_sql, db_path, gold_sql) in enumerate(rollouts) | |
| # } | |
| # for fut in as_completed(futures): | |
| # idx = futures[fut] | |
| # try: | |
| # results[idx] = float(fut.result()) | |
| # except Exception: | |
| # results[idx] = 0.0 | |
| # return [r if r is not None else 0.0 for r in results] | |
| from __future__ import annotations | |
| import os | |
| import re | |
| import sqlite3 | |
| import threading | |
| import time | |
| import json | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from dataclasses import dataclass | |
| from typing import Dict, List | |
| from src.sql_validator import validate_sql_schema | |
| # ========================================================= | |
| # 🔥 CONFIG FLAGS | |
| # ========================================================= | |
# Feature toggles; flipped at runtime by set_use_schema_validation() / set_use_cache().
USE_SCHEMA_VALIDATION: bool = True
USE_CACHE: bool = True

# Per-query wall-clock budget in seconds, enforced by _with_timeout().
DEFAULT_QUERY_TIMEOUT_S: float = 2.0

# Sentinel that execute_sql() returns instead of a row list when a query fails.
EXECUTION_ERROR: str = "EXECUTION_ERROR"

# Memoized rewards, filled by cached_execution_reward() and emptied by clear_result_cache().
_REWARD_CACHE: Dict[str, float] = {}

# =========================================================
# 🔥 TASK 2: ERROR ANALYSIS + LOGGING
# =========================================================
# JSON file that log_error() appends its records to.
ERROR_LOG_FILE: str = "results/error_logs.json"
# Keyword patterns for classify_error(). They match on word boundaries and allow
# any whitespace between keywords, so multi-line SQL is handled.
_CLS_JOIN_RE = re.compile(r"\bjoin\b")
_CLS_ON_RE = re.compile(r"\bon\b")
_CLS_WHERE_RE = re.compile(r"\bwhere\b")
_CLS_PREDICATE_RE = re.compile(r"[=<>]|\b(?:like|glob|in|between|is|exists)\b")
_CLS_NULL_RE = re.compile(r"\b(?:null|ifnull|nullif)\b")
_CLS_GROUP_BY_RE = re.compile(r"\bgroup\s+by\b")
_CLS_AGGREGATE_RE = re.compile(r"\b(?:count|sum|avg|min|max|total|group_concat)\s*\(")


def classify_error(sql: str) -> str:
    """Heuristically label a failing SQL query with a coarse error category.

    Checks run in priority order and the first match wins:

    * ``missing_join``  - a JOIN keyword appears but no ON keyword does.
    * ``wrong_where``   - the text after WHERE contains no comparison operator
      (``=``, ``<``, ``>``) and no predicate keyword (LIKE, GLOB, IN,
      BETWEEN, IS, EXISTS).
    * ``null_handling`` - the query mentions NULL (or IFNULL/NULLIF).
    * ``wrong_groupby`` - GROUP BY is present without any aggregate call.
    * ``other``         - none of the above; also returned for non-string input.

    Matching is case-insensitive and uses word boundaries with arbitrary
    whitespace between multi-word keywords. This means "JOIN ...\\nON",
    "GROUP\\nBY" and identifiers such as ``somewhere`` or ``nullable`` are
    classified correctly.
    """
    if not isinstance(sql, str):
        return "other"
    s = sql.lower()
    if _CLS_JOIN_RE.search(s) and not _CLS_ON_RE.search(s):
        return "missing_join"
    where = _CLS_WHERE_RE.search(s)
    # Only look for a predicate in the text that follows WHERE. A JOIN's "ON a = b"
    # says nothing about whether the WHERE clause itself has a condition.
    if where and not _CLS_PREDICATE_RE.search(s, where.end()):
        return "wrong_where"
    if _CLS_NULL_RE.search(s):
        return "null_handling"
    # Any aggregate satisfies GROUP BY, not just COUNT.
    if _CLS_GROUP_BY_RE.search(s) and not _CLS_AGGREGATE_RE.search(s):
        return "wrong_groupby"
    return "other"
# Remediation hint for each classify_error() category; built once at import time.
_ERROR_HINTS: Dict[str, str] = {
    "missing_join": "Add proper JOIN condition using ON.",
    "wrong_where": "Check WHERE clause conditions.",
    "null_handling": "Handle NULL values using IS NULL.",
    "wrong_groupby": "Use aggregation functions with GROUP BY.",
    "other": "Check SQL syntax and logic.",
}


def get_hint(error_type: str) -> str:
    """Return a short human-readable fix suggestion for *error_type*.

    Categories not in the table fall back to a generic hint.
    """
    return _ERROR_HINTS.get(error_type, "Check query.")
# Serializes the read-modify-write of the JSON log across threads in this process.
_ERROR_LOG_LOCK = threading.Lock()


def log_error(
    question: str,
    sql: str,
    error: str,
    error_type: str,
    *,
    log_file: str | None = None,
):
    """Append one error record to the JSON error log.

    The log file holds a single JSON array of objects with the keys
    ``question``, ``sql``, ``error``, ``error_type`` and ``timestamp``
    (seconds since the epoch, from ``time.time()``).

    Args:
        question: Natural-language question (callers in this module pass "UNKNOWN").
        sql: The offending SQL text.
        error: Short error description.
        error_type: Category label, e.g. from classify_error().
        log_file: Path of the JSON log; defaults to ERROR_LOG_FILE.

    Reward functions call this from worker threads, so the whole update runs
    under a process-wide lock. The new array is written to a temporary file and
    swapped in with ``os.replace``, so the log is never left half-written. A log
    that cannot be decoded, or that does not hold a JSON array, is replaced by a
    fresh array instead of raising. This prevents a damaged log from breaking
    reward computation.
    """
    path = ERROR_LOG_FILE if log_file is None else log_file
    entry = {
        "question": question,
        "sql": sql,
        "error": error,
        "error_type": error_type,
        "timestamp": time.time(),
    }
    directory = os.path.dirname(path)
    with _ERROR_LOG_LOCK:
        if directory:
            os.makedirs(directory, exist_ok=True)
        logs = []
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if isinstance(loaded, list):
                logs = loaded
        except FileNotFoundError:
            pass  # first entry: start a new log
        except ValueError:
            # JSONDecodeError / UnicodeDecodeError: the previous log is unreadable.
            pass
        logs.append(entry)
        # Unique per process/thread so concurrent writers never share a temp file.
        tmp_path = f"{path}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(logs, f, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
| # ========================================================= | |
| # CACHE/VALIDATION TOGGLES (Task 1) | |
| # ========================================================= | |
def set_use_cache(enabled: bool) -> None:
    """Switch memoization of query results and rewards on or off.

    Any truthy value enables caching; the module flag is stored as a real bool.
    """
    global USE_CACHE
    USE_CACHE = True if enabled else False
def set_use_schema_validation(enabled: bool) -> None:
    """Switch the pre-execution schema check on or off.

    Any truthy value enables it; the module flag is stored as a real bool.
    """
    global USE_SCHEMA_VALIDATION
    USE_SCHEMA_VALIDATION = True if enabled else False
| # ========================================================= | |
| # SQL CLEANING | |
| # ========================================================= | |
| def _normalize_sql(sql: str) -> str: | |
| if not isinstance(sql, str): | |
| return "" | |
| s = sql.strip() | |
| if s.startswith("```"): | |
| s = re.sub(r"^```[a-zA-Z0-9_+-]*\n?", "", s).strip() | |
| s = re.sub(r"\n?```$", "", s).strip() | |
| if s.lower().startswith("sql:"): | |
| s = s[4:].strip() | |
| if ";" in s: | |
| s = s.split(";", 1)[0].strip() | |
| return s | |
| # ========================================================= | |
| # DB EXECUTION | |
| # ========================================================= | |
| def _connect_readonly(db_path: str): | |
| uri = f"file:{os.path.abspath(db_path)}?mode=ro" | |
| conn = sqlite3.connect(uri, uri=True, check_same_thread=False) | |
| conn.execute("PRAGMA query_only = ON;") | |
| conn.execute("PRAGMA foreign_keys = ON;") | |
| return conn | |
def _with_timeout(conn: sqlite3.Connection, timeout_s: float = DEFAULT_QUERY_TIMEOUT_S):
    """Install a wall-clock deadline on *conn*, measured from this call.

    SQLite calls the handler every 10,000 virtual-machine instructions. Once
    more than *timeout_s* seconds have elapsed, the handler returns a non-zero
    value, which makes SQLite abort the statement running on *conn*. The
    handler stays installed until replaced; callers in this module re-arm it
    before each statement.
    """
    armed_at = time.monotonic()

    def _deadline_passed():
        elapsed = time.monotonic() - armed_at
        return 1 if elapsed > timeout_s else 0

    conn.set_progress_handler(_deadline_passed, 10_000)
def execute_sql(conn, sql):
    """Run *sql* on *conn* under the default timeout and return every row.

    Returns the list from ``cursor.fetchall()`` on success. If anything raises
    (bad syntax, unknown table, timeout interrupt, ...), returns the
    EXECUTION_ERROR sentinel instead.
    """
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        rows = conn.execute(sql).fetchall()
    except Exception:
        return EXECUTION_ERROR
    return rows
# Memoized query outcomes keyed by "<db_path>|<sql>". Values are fetchall() row
# lists or the EXECUTION_ERROR sentinel. Every access goes through _RESULT_LOCK.
_RESULT_CACHE: Dict[str, object] = {}
_RESULT_LOCK = threading.Lock()
def execute_sql_cached(db_path, sql):
    """Run *sql* against the database at *db_path*, memoizing when USE_CACHE is on.

    On a cache miss, a fresh read-only connection is opened for this one query
    and closed in a ``finally`` block, so it is released even if execution is
    interrupted by an exception. Row lists and the EXECUTION_ERROR sentinel are
    both cached under ``"<db_path>|<sql>"``. execute_sql_cached_conn uses the
    same key, so the two functions share entries.

    If the database cannot be opened, the error from _connect_readonly
    propagates and nothing is cached.
    """
    key = f"{db_path}|{sql}"
    if USE_CACHE:
        with _RESULT_LOCK:
            if key in _RESULT_CACHE:
                return _RESULT_CACHE[key]
    conn = _connect_readonly(db_path)
    try:
        result = execute_sql(conn, sql)
    finally:
        conn.close()
    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[key] = result
    return result
def execute_sql_cached_conn(conn: sqlite3.Connection, db_path: str, sql: str):
    """Run *sql* on an already-open *conn*, sharing the module result cache.

    Like execute_sql_cached(), but reuses an existing connection instead of
    opening one per query. It is intended for 1-thread-per-DB workloads (Task 1).
    *db_path* is used only to build the cache key ``"<db_path>|<sql>"``. This
    function never closes *conn*; that is the caller's job.
    """
    cache_key = f"{db_path}|{sql}"

    # Fast path: serve a previously computed outcome.
    if USE_CACHE:
        with _RESULT_LOCK:
            if cache_key in _RESULT_CACHE:
                return _RESULT_CACHE[cache_key]

    outcome = execute_sql(conn, sql)

    # Remember the outcome (rows or EXECUTION_ERROR) for later calls.
    if USE_CACHE:
        with _RESULT_LOCK:
            _RESULT_CACHE[cache_key] = outcome
    return outcome
def clear_result_cache() -> None:
    """Empty both the query-result cache and the reward cache.

    Both dicts are cleared in place (never rebound), so no ``global``
    declaration is required.
    """
    with _RESULT_LOCK:
        _RESULT_CACHE.clear()
        _REWARD_CACHE.clear()
| # ========================================================= | |
| # SQL PARSING | |
| # ========================================================= | |
def is_valid_select(sql):
    """Return True when *sql* begins with SELECT or WITH (case-insensitive).

    No whitespace is stripped here. Callers in this module pass text that has
    already been cleaned by _normalize_sql.
    """
    return sql.lower().startswith(("select", "with"))
# Identifier following FROM or JOIN. One optional opening quote/bracket is skipped.
_TABLE_REF_RE = re.compile(r'\b(?:from|join)\s+[`"\[]?(\w+)')


def extract_tables(sql):
    """Return the lower-cased table names referenced after FROM or JOIN.

    Names appear in order of occurrence, with duplicates kept. Quoted
    identifiers (`` `t` ``, ``"t"``, ``[t]``) are recognized. Derived tables
    such as ``FROM (SELECT ...)`` are not reported, because no name follows
    the keyword directly.
    """
    return _TABLE_REF_RE.findall(sql.lower())
def extract_columns(sql):
    """Return the lower-cased expressions in the SELECT list of *sql*.

    Takes the text between the first SELECT and the following FROM. The match
    spans newlines, so multi-line queries work. That text is split on
    top-level commas only, so ``COALESCE(a, b)`` stays one expression.

    Returns ``["*"]`` for ``SELECT *``, and ``[]`` when no SELECT ... FROM span
    exists.
    """
    match = re.search(r"\bselect\s+(.*?)\s+from\b", sql.lower(), re.DOTALL)
    if not match:
        return []
    cols = match.group(1)
    if cols.strip() == "*":
        return ["*"]
    parts, depth, start = [], 0, 0
    for i, ch in enumerate(cols):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(depth - 1, 0)  # tolerate unbalanced ')' in malformed SQL
        elif ch == "," and depth == 0:
            parts.append(cols[start:i].strip())
            start = i + 1
    parts.append(cols[start:].strip())
    return parts
# (label, pattern) pairs, checked in this order. Word boundaries stop identifiers
# like "somewhere" from matching, and \s+ accepts newlines/tabs inside
# "GROUP BY" / "ORDER BY".
_OPERATION_PATTERNS = (
    ("SELECT", re.compile(r"\bselect\b")),
    ("WHERE", re.compile(r"\bwhere\b")),
    ("JOIN", re.compile(r"\bjoin\b")),
    ("GROUP_BY", re.compile(r"\bgroup\s+by\b")),
    ("ORDER_BY", re.compile(r"\border\s+by\b")),
)


def get_sql_operations(sql: str):
    """List which of SELECT / WHERE / JOIN / GROUP_BY / ORDER_BY appear in *sql*.

    Labels are returned in that fixed order. Matching is case-insensitive and
    keyword-based, so column names that merely contain a keyword do not count.
    """
    lowered = sql.lower()
    return [label for label, pattern in _OPERATION_PATTERNS if pattern.search(lowered)]
def _explain_query_plan(conn: sqlite3.Connection, sql: str) -> bool:
    """Ask SQLite for the query plan of *sql* (used for timing the planner).

    Returns True when ``EXPLAIN QUERY PLAN`` succeeds within the default
    timeout, and False if anything raises.
    """
    try:
        _with_timeout(conn, timeout_s=DEFAULT_QUERY_TIMEOUT_S)
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception:
        return False
    return True
def execution_reward_timed(pred_sql: str, db_path: str, gold_sql: str, measure_plan: bool = False):
    """
    Returns (reward, timings) where timings keys: parse_s, plan_s, exec_s.
    Used by Task-1 benchmark to profile bottlenecks.

    Reward is 1.0 when the prediction's rows equal the gold rows, and 0.0
    otherwise. 0.0 is also returned for non-SELECT predictions and for execution
    errors on either side. plan_s stays 0.0 unless *measure_plan* is set. One
    read-only connection is used for both queries and is always closed.
    """
    timings = {"parse_s": 0.0, "plan_s": 0.0, "exec_s": 0.0}

    parse_start = time.perf_counter()
    sql = _normalize_sql(pred_sql)
    gold = _normalize_sql(gold_sql)
    valid = is_valid_select(sql)
    timings["parse_s"] = time.perf_counter() - parse_start
    if not valid:
        return 0.0, timings

    conn = _connect_readonly(db_path)
    try:
        if measure_plan:
            plan_start = time.perf_counter()
            for query in (sql, gold):
                _explain_query_plan(conn, query)
            timings["plan_s"] = time.perf_counter() - plan_start

        exec_start = time.perf_counter()
        pred_res = execute_sql_cached_conn(conn, db_path, sql)
        if pred_res == EXECUTION_ERROR:
            timings["exec_s"] = time.perf_counter() - exec_start
            return 0.0, timings
        gold_res = execute_sql_cached_conn(conn, db_path, gold)
        timings["exec_s"] = time.perf_counter() - exec_start
        if gold_res == EXECUTION_ERROR:
            return 0.0, timings

        # A result mismatch scores 0.0: base -0.2 plus 0.2 for executing, clipped to [-1, 1].
        return (1.0 if pred_res == gold_res else 0.0), timings
    finally:
        try:
            conn.close()
        except Exception:
            pass
| # ========================================================= | |
| # 🔥 FINAL REWARD FUNCTION (TASK 2 INTEGRATED) | |
| # ========================================================= | |
def _log_error_quietly(question: str, sql: str, error: str, error_type: str) -> None:
    """Call log_error, but report and swallow any failure of the logging itself.

    Reward computation must always yield a number. An unwritable or damaged
    error log must neither change a reward nor escape the reward function.
    """
    try:
        log_error(question, sql, error, error_type)
    except Exception as log_exc:  # best-effort logging by design
        print(f"[WARN] could not write error log: {log_exc}")


def execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Score a predicted SQL query against a gold query by executing both.

    Reward schedule:
      * -1.0: the prediction is not a SELECT/WITH statement.
      *  0.1: the prediction fails schema validation (when
        USE_SCHEMA_VALIDATION is on), the prediction fails to execute, or the
        gold query fails to execute.
      *  1.0: prediction and gold return equal row lists.
      *  0.0: both execute but results differ (base -0.2, plus 0.2 for
        executing), or an unexpected exception occurs.

    Schema, execution and runtime failures are appended to the error log. A
    prediction that fails to execute also prints its error type and a hint.
    """
    try:
        sql = _normalize_sql(pred_sql)
        gold = _normalize_sql(gold_sql)
        if not is_valid_select(sql):
            return -1.0
        reward = -0.2

        # Schema validation (Task 3)
        if USE_SCHEMA_VALIDATION:
            valid, _ = validate_sql_schema(sql, db_path)
            if not valid:
                _log_error_quietly("UNKNOWN", sql, "schema_invalid", classify_error(sql))
                return 0.1

        # Execution
        pred_res = execute_sql_cached(db_path, sql)
        if pred_res == EXECUTION_ERROR:
            error_type = classify_error(sql)
            _log_error_quietly("UNKNOWN", sql, "execution_error", error_type)
            print(f"[ERROR] {error_type}")
            print(f"[HINT] {get_hint(error_type)}")
            return 0.1
        reward += 0.2

        gold_res = execute_sql_cached(db_path, gold)
        if gold_res == EXECUTION_ERROR:
            return 0.1
        if pred_res == gold_res:
            return 1.0
        # NOTE(review): an executable but wrong query scores 0.0 here, which is
        # below the 0.1 given to queries that fail validation/execution.
        # Confirm this ordering is intended.
        return max(-1.0, min(1.0, reward))
    except Exception as e:
        _log_error_quietly("UNKNOWN", pred_sql, str(e), "runtime_error")
        return 0.0
| # ========================================================= | |
| # BATCH EXECUTION (Task 1) | |
| # ========================================================= | |
def cached_execution_reward(pred_sql: str, db_path: str, gold_sql: str) -> float:
    """Memoized wrapper around execution_reward; the cache is bypassed when USE_CACHE is off.

    The cache key is the repr of (db_path, pred_sql, gold_sql,
    USE_SCHEMA_VALIDATION):
      * repr quotes and escapes each string. SQL containing ``|`` (such as the
        ``||`` concatenation operator) therefore cannot make two different
        rollouts share a key, which a plain "|"-joined key could.
      * The schema-validation flag is part of the key because it changes the
        reward (0.1 for schema-invalid SQL vs. an execution-based score).
        Toggling it with set_use_schema_validation() must not return stale values.
    """
    if not USE_CACHE:
        return float(execution_reward(pred_sql, db_path, gold_sql))
    key = repr((db_path, pred_sql, gold_sql, USE_SCHEMA_VALIDATION))
    cached = _REWARD_CACHE.get(key)
    if cached is not None:
        return float(cached)
    r = float(execution_reward(pred_sql, db_path, gold_sql))
    _REWARD_CACHE[key] = r
    return r
def execution_reward_batch_sequential(rollouts):
    """Score each (pred_sql, db_path, gold_sql) rollout in order on the calling thread.

    Returns a list of rewards aligned with *rollouts*.
    """
    return [cached_execution_reward(pred, db, gold) for pred, db, gold in rollouts]
def execution_reward_batch_parallel(rollouts, max_workers=10):
    """Score rollouts concurrently on a thread pool, preserving input order.

    Each (pred_sql, db_path, gold_sql) rollout is scored with
    cached_execution_reward. A rollout whose task raises gets 0.0.
    """
    results = [0.0] * len(rollouts)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        index_of = {}
        for position, (pred, db, gold) in enumerate(rollouts):
            index_of[pool.submit(cached_execution_reward, pred, db, gold)] = position
        for done in as_completed(index_of):
            position = index_of[done]
            try:
                results[position] = done.result()
            except Exception:
                results[position] = 0.0
    return results
def execution_reward_batch_parallel_by_db(rollouts, max_workers: int = 20):
    """
    Score rollouts with one worker task per distinct DB path. Each task reuses a
    single read-only connection for all of that DB's rollouts.
    Preserves input order.

    Scoring mirrors execution_reward: -1.0 for non-SELECT, 0.1 for a schema or
    execution failure, 1.0 for an exact result match, 0.0 otherwise. Unlike
    execution_reward, it uses the shared connection, prints no hints, and
    bypasses the reward cache.

    Robustness:
      * A database that cannot be opened leaves its rollouts at 0.0 instead of
        aborting the whole batch.
      * The pool is sized to min(max_workers, number of DBs), and is at least 1.
      * A failure to write the error log does not change a reward.
    """
    if not rollouts:
        return []
    by_db = {}
    for idx, (pred_sql, db_path, gold_sql) in enumerate(rollouts):
        by_db.setdefault(db_path, []).append((idx, pred_sql, gold_sql))
    results = [0.0] * len(rollouts)

    def _log_quietly(sql, error, error_type):
        # Logging is best-effort: report a log-write failure, never propagate it.
        try:
            log_error("UNKNOWN", sql, error, error_type)
        except Exception as log_exc:
            print(f"[WARN] could not write error log: {log_exc}")

    def _reward_with_conn(conn: sqlite3.Connection, pred_sql: str, db_path: str, gold_sql: str) -> float:
        try:
            sql = _normalize_sql(pred_sql)
            gold = _normalize_sql(gold_sql)
            if not is_valid_select(sql):
                return -1.0
            reward = -0.2
            if USE_SCHEMA_VALIDATION:
                valid, _ = validate_sql_schema(sql, db_path)
                if not valid:
                    _log_quietly(sql, "schema_invalid", classify_error(sql))
                    return 0.1
            pred_res = execute_sql_cached_conn(conn, db_path, sql)
            if pred_res == EXECUTION_ERROR:
                _log_quietly(sql, "execution_error", classify_error(sql))
                return 0.1
            reward += 0.2
            gold_res = execute_sql_cached_conn(conn, db_path, gold)
            if gold_res == EXECUTION_ERROR:
                return 0.1
            if pred_res == gold_res:
                return 1.0
            return max(-1.0, min(1.0, reward))
        except Exception:
            return 0.0

    def _worker(db_path: str, items):
        try:
            conn = _connect_readonly(db_path)
        except sqlite3.Error:
            # Unopenable DB (e.g. missing file): keep the default 0.0 for its rollouts.
            return
        try:
            for idx, pred, gold in items:
                results[idx] = _reward_with_conn(conn, pred, db_path, gold)
        finally:
            try:
                conn.close()
            except Exception:
                pass

    # More threads than distinct DBs would sit idle.
    worker_count = max(1, min(int(max_workers), len(by_db)))
    with ThreadPoolExecutor(max_workers=worker_count) as ex:
        futures = [ex.submit(_worker, db_path, items) for db_path, items in by_db.items()]
        for fut in as_completed(futures):
            fut.result()
    return results