"""Quantized text-to-SQL inference engine with cached, pooled SQLite execution."""
| from __future__ import annotations | |
| import os | |
| import sqlite3 | |
| import threading | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from collections import OrderedDict | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Sequence, Tuple | |
| import torch | |
| from src.quantization_utils import load_quant_artifact | |
| from src.schema_encoder import SchemaEncoder | |
| from src.sql_validator import validate_sql_schema | |
# ==========================================
# RELATIVE PATH RESOLUTION (GLOBAL)
# ==========================================
# Repository root is two levels above this file; the SQLite databases live
# either under data/database (preferred) or under final_databases.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
_preferred_db_dir = PROJECT_ROOT / "data" / "database"
DB_ROOT = _preferred_db_dir if _preferred_db_dir.exists() else PROJECT_ROOT / "final_databases"
| class QuantizedText2SQLEngine: | |
| def __init__( | |
| self, | |
| artifact_dir: str, | |
| *, | |
| device: str = "cpu", | |
| use_constrained: bool = False, | |
| exec_workers: int | None = None, | |
| default_timeout_s: float = 2.0, | |
| use_cache: bool = True, | |
| cache_max_entries: int = 50_000, | |
| ): | |
| self.device = device | |
| self.use_constrained = bool(use_constrained) | |
| self.tokenizer, self.model, self.meta = load_quant_artifact(artifact_dir, device=device, local_only=True) | |
| self.schema_encoder = SchemaEncoder(DB_ROOT) | |
| if exec_workers is None: | |
| exec_workers = int(os.environ.get("SQL_EXEC_WORKERS", "8")) | |
| self.exec_pool = ThreadPoolExecutor(max_workers=int(exec_workers)) | |
| self.default_timeout_s = float(default_timeout_s) | |
| self.use_cache = bool(use_cache) | |
| self.cache_max_entries = int(cache_max_entries) | |
| self._cache: "OrderedDict[tuple[str, str], tuple[list, list]]" = OrderedDict() | |
| self._cache_lock = threading.Lock() | |
| self._stats_lock = threading.Lock() | |
| self._exec_cache_hits = 0 | |
| self._exec_cache_misses = 0 | |
| self._exec_calls = 0 | |
| self._tls = threading.local() | |
| def _get_db_path(self, db_id: str) -> str: | |
| """Smart resolver for flat vs nested database folders""" | |
| path1 = DB_ROOT / db_id / f"{db_id}.sqlite" | |
| path2 = DB_ROOT / f"{db_id}.sqlite" | |
| return str(path1) if path1.exists() else str(path2) | |
| def build_prompt(self, question: str, db_id: str) -> str: | |
| schema = self.schema_encoder.structured_schema(db_id) | |
| return ( | |
| "You are a SQLite expert.\n\n" | |
| f"Database: {db_id}\n\n" | |
| "Schema:\n" | |
| f"{schema}\n\n" | |
| "Question:\n" | |
| f"{question}\n\n" | |
| "SQL:" | |
| ) | |
| def generate_sql_batch( | |
| self, | |
| pairs: Sequence[Tuple[str, str]], | |
| *, | |
| max_new_tokens: int = 120, | |
| num_beams: int = 8, | |
| repetition_penalty: float = 1.2, | |
| ) -> List[str]: | |
| prompts = [self.build_prompt(q, db_id) for q, db_id in pairs] | |
| if self.use_constrained: | |
| from transformers.generation.logits_process import LogitsProcessorList | |
| from src.constrained_decoding import SchemaConstrainedLogitsProcessor | |
| sqls: List[str] = [] | |
| for (q, db_id), prompt in zip(pairs, prompts): | |
| db_path = self._get_db_path(db_id) | |
| enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device) | |
| proc = LogitsProcessorList([SchemaConstrainedLogitsProcessor(self.tokenizer, db_path)]) | |
| out = self.model.generate( | |
| **enc, | |
| max_new_tokens=int(max_new_tokens), | |
| num_beams=int(num_beams), | |
| repetition_penalty=float(repetition_penalty), | |
| logits_processor=proc, | |
| ) | |
| sqls.append(self.tokenizer.decode(out[0], skip_special_tokens=True).strip()) | |
| return sqls | |
| enc = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device) | |
| out = self.model.generate( | |
| **enc, | |
| max_new_tokens=int(max_new_tokens), | |
| num_beams=int(num_beams), | |
| repetition_penalty=float(repetition_penalty), | |
| ) | |
| return [self.tokenizer.decode(x, skip_special_tokens=True).strip() for x in out] | |
| def _get_thread_conn(self, db_path: str) -> sqlite3.Connection: | |
| conns = getattr(self._tls, "conns", None) | |
| if conns is None: | |
| conns = {} | |
| self._tls.conns = conns | |
| conn = conns.get(db_path) | |
| if conn is None: | |
| conn = sqlite3.connect(db_path) | |
| conn.text_factory = lambda b: b.decode(errors="ignore") | |
| conns[db_path] = conn | |
| return conn | |
| def _cache_get(self, key: tuple[str, str]) -> tuple[list, list] | None: | |
| if not self.use_cache: return None | |
| with self._cache_lock: | |
| hit = self._cache.get(key) | |
| if hit is None: return None | |
| self._cache.move_to_end(key) | |
| return hit | |
| def _cache_put(self, key: tuple[str, str], value: tuple[list, list]) -> None: | |
| if not self.use_cache: return | |
| with self._cache_lock: | |
| self._cache[key] = value | |
| self._cache.move_to_end(key) | |
| while len(self._cache) > self.cache_max_entries: | |
| self._cache.popitem(last=False) | |
| def _execute_one(self, sql: str, db_path: str, timeout_s: float | None = None): | |
| timeout_s = float(self.default_timeout_s if timeout_s is None else timeout_s) | |
| key = (db_path, sql) | |
| cached = self._cache_get(key) | |
| with self._stats_lock: self._exec_calls += 1 | |
| if cached is not None: | |
| with self._stats_lock: self._exec_cache_hits += 1 | |
| return cached | |
| with self._stats_lock: self._exec_cache_misses += 1 | |
| conn = self._get_thread_conn(db_path) | |
| start_t = time.monotonic() | |
| def handler(): | |
| return 1 if (time.monotonic() - start_t) > timeout_s else 0 | |
| conn.set_progress_handler(handler, 10_000) | |
| cur = conn.cursor() | |
| cur.execute(sql) | |
| rows = cur.fetchall() | |
| cols = [d[0] for d in cur.description] if cur.description else [] | |
| out = (rows, cols) | |
| self._cache_put(key, out) | |
| return out | |
| def stats(self) -> Dict[str, Any]: | |
| with self._stats_lock: | |
| calls, hits, misses = int(self._exec_calls), int(self._exec_cache_hits), int(self._exec_cache_misses) | |
| hit_rate = (hits / calls) if calls else 0.0 | |
| return { | |
| "exec_calls": calls, "exec_cache_hits": hits, "exec_cache_misses": misses, | |
| "exec_cache_hit_rate": float(hit_rate), "use_cache": bool(self.use_cache), | |
| "exec_workers": int(getattr(self.exec_pool, "_max_workers", 0) or 0), | |
| } | |
| def reset_stats(self) -> None: | |
| with self._stats_lock: | |
| self._exec_calls = self._exec_cache_hits = self._exec_cache_misses = 0 | |
| def execute_sql(self, sql: str, db_id: str, *, timeout_s: float | None = None, validate_schema: bool = True): | |
| db_path = self._get_db_path(db_id) | |
| if validate_schema: | |
| try: ok, _ = validate_sql_schema(sql, db_path) | |
| except Exception: ok = False | |
| if not ok: raise ValueError("Invalid schema") | |
| return self._execute_one(sql, db_path, timeout_s=timeout_s) | |
| def ask( | |
| self, | |
| question: str, | |
| db_id: str, | |
| *, | |
| max_new_tokens: int = 120, | |
| num_beams: int = 8, | |
| repetition_penalty: float = 1.2, | |
| timeout_s: float | None = None, | |
| ) -> Dict[str, Any]: | |
| sql = self.generate_sql_batch( | |
| [(question, db_id)], | |
| max_new_tokens=max_new_tokens, | |
| num_beams=num_beams, | |
| repetition_penalty=repetition_penalty, | |
| )[0] | |
| db_path = self._get_db_path(db_id) | |
| try: ok, _ = validate_sql_schema(sql, db_path) | |
| except Exception: ok = False | |
| if not ok: return {"sql": sql, "rows": [], "columns": [], "error": "Invalid schema"} | |
| try: | |
| rows, cols = self._execute_one(sql, db_path, timeout_s=timeout_s) | |
| return {"sql": sql, "rows": rows, "columns": cols, "error": None} | |
| except Exception as e: | |
| return {"sql": sql, "rows": [], "columns": [], "error": str(e)} | |
| def ask_batch_execute(self, pairs: Sequence[Tuple[str, str]]) -> List[Dict[str, Any]]: | |
| sqls = self.generate_sql_batch(pairs) | |
| results: List[Dict[str, Any]] = [] | |
| futures = {} | |
| for (q, db_id), sql in zip(pairs, sqls): | |
| db_path = self._get_db_path(db_id) | |
| futures[self.exec_pool.submit(self._execute_one, sql, db_path)] = (sql, db_id) | |
| for fut in as_completed(futures): | |
| sql, db_id = futures[fut] | |
| try: | |
| rows, cols = fut.result() | |
| results.append({"db_id": db_id, "sql": sql, "rows": rows, "columns": cols, "error": None}) | |
| except Exception as e: | |
| results.append({"db_id": db_id, "sql": sql, "rows": [], "columns": [], "error": str(e)}) | |
| return results |