# text2sql_final_space / src / quantized_text2sql_engine.py
# (Hugging Face Space file header — commit f0e5200, "Step 2: added code folders")
from __future__ import annotations
import os
import sqlite3
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple
import torch
from src.quantization_utils import load_quant_artifact
from src.schema_encoder import SchemaEncoder
from src.sql_validator import validate_sql_schema
# ==========================================
# RELATIVE PATH RESOLUTION (GLOBAL)
# ==========================================
# The repository root sits two levels above this file (src/<this file>).
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Prefer the canonical "data/database" layout; otherwise fall back to the
# flat "final_databases" export folder.
_preferred_db_root = PROJECT_ROOT / "data" / "database"
DB_ROOT = _preferred_db_root if _preferred_db_root.exists() else PROJECT_ROOT / "final_databases"
class QuantizedText2SQLEngine:
    """Text-to-SQL inference engine backed by a quantized HF model artifact.

    Pipeline: build a schema-aware prompt, generate SQL (optionally with
    schema-constrained decoding), validate it against the target database's
    schema, then execute it on SQLite — with a thread pool for parallel
    execution, an LRU result cache, and a per-query wall-clock timeout
    enforced via SQLite's progress handler.
    """

    def __init__(
        self,
        artifact_dir: str,
        *,
        device: str = "cpu",
        use_constrained: bool = False,
        exec_workers: int | None = None,
        default_timeout_s: float = 2.0,
        use_cache: bool = True,
        cache_max_entries: int = 50_000,
    ):
        """Load the quantized artifact and set up the execution machinery.

        Args:
            artifact_dir: directory holding the quantized tokenizer/model/meta.
            device: torch device string for generation (e.g. "cpu", "cuda").
            use_constrained: use per-example schema-constrained decoding
                (slower than batched unconstrained generation).
            exec_workers: SQL execution thread-pool size; when None, falls
                back to the SQL_EXEC_WORKERS env var (default 8).
            default_timeout_s: per-query wall-clock budget in seconds.
            use_cache: memoize (db_path, sql) -> (rows, columns) results.
            cache_max_entries: LRU capacity of the result cache.
        """
        self.device = device
        self.use_constrained = bool(use_constrained)
        self.tokenizer, self.model, self.meta = load_quant_artifact(artifact_dir, device=device, local_only=True)
        self.schema_encoder = SchemaEncoder(DB_ROOT)
        if exec_workers is None:
            exec_workers = int(os.environ.get("SQL_EXEC_WORKERS", "8"))
        self.exec_pool = ThreadPoolExecutor(max_workers=int(exec_workers))
        self.default_timeout_s = float(default_timeout_s)
        self.use_cache = bool(use_cache)
        self.cache_max_entries = int(cache_max_entries)
        # LRU cache: (db_path, sql) -> (rows, columns). Guarded by _cache_lock.
        self._cache: "OrderedDict[tuple[str, str], tuple[list, list]]" = OrderedDict()
        self._cache_lock = threading.Lock()
        self._stats_lock = threading.Lock()
        self._exec_cache_hits = 0
        self._exec_cache_misses = 0
        self._exec_calls = 0
        # Per-thread sqlite3 connections: sqlite3 connections must not be
        # shared across threads, so each pool thread keeps its own map.
        self._tls = threading.local()

    def _get_db_path(self, db_id: str) -> str:
        """Resolve ``db_id`` to a .sqlite path, supporting both the nested
        (<root>/<db_id>/<db_id>.sqlite) and flat (<root>/<db_id>.sqlite) layouts.

        NOTE(review): falls through to the flat path even when neither file
        exists; a later sqlite3.connect would then create an empty database.
        """
        nested = DB_ROOT / db_id / f"{db_id}.sqlite"
        flat = DB_ROOT / f"{db_id}.sqlite"
        return str(nested) if nested.exists() else str(flat)

    def build_prompt(self, question: str, db_id: str) -> str:
        """Build the generation prompt: role line, db id, schema, question."""
        schema = self.schema_encoder.structured_schema(db_id)
        return (
            "You are a SQLite expert.\n\n"
            f"Database: {db_id}\n\n"
            "Schema:\n"
            f"{schema}\n\n"
            "Question:\n"
            f"{question}\n\n"
            "SQL:"
        )

    def generate_sql_batch(
        self,
        pairs: Sequence[Tuple[str, str]],
        *,
        max_new_tokens: int = 120,
        num_beams: int = 8,
        repetition_penalty: float = 1.2,
    ) -> List[str]:
        """Generate one SQL string per (question, db_id) pair.

        Constrained mode decodes examples one-by-one with a schema-aware
        logits processor; unconstrained mode batches all prompts together.
        Generation runs under ``torch.inference_mode()`` so no autograd
        state is built during decoding (the original omitted this).
        """
        prompts = [self.build_prompt(q, db_id) for q, db_id in pairs]
        if self.use_constrained:
            # Imported lazily: only needed on the constrained path.
            from transformers.generation.logits_process import LogitsProcessorList
            from src.constrained_decoding import SchemaConstrainedLogitsProcessor
            sqls: List[str] = []
            for (q, db_id), prompt in zip(pairs, prompts):
                db_path = self._get_db_path(db_id)
                enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.device)
                proc = LogitsProcessorList([SchemaConstrainedLogitsProcessor(self.tokenizer, db_path)])
                with torch.inference_mode():
                    out = self.model.generate(
                        **enc,
                        max_new_tokens=int(max_new_tokens),
                        num_beams=int(num_beams),
                        repetition_penalty=float(repetition_penalty),
                        logits_processor=proc,
                    )
                sqls.append(self.tokenizer.decode(out[0], skip_special_tokens=True).strip())
            return sqls
        enc = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.device)
        with torch.inference_mode():
            out = self.model.generate(
                **enc,
                max_new_tokens=int(max_new_tokens),
                num_beams=int(num_beams),
                repetition_penalty=float(repetition_penalty),
            )
        return [self.tokenizer.decode(x, skip_special_tokens=True).strip() for x in out]

    def _get_thread_conn(self, db_path: str) -> sqlite3.Connection:
        """Return this thread's cached connection for ``db_path``, creating it
        on first use (connections live for the lifetime of the pool thread)."""
        conns = getattr(self._tls, "conns", None)
        if conns is None:
            conns = {}
            self._tls.conns = conns
        conn = conns.get(db_path)
        if conn is None:
            conn = sqlite3.connect(db_path)
            # Some benchmark DBs contain non-UTF-8 text; decode leniently
            # instead of raising on bad bytes.
            conn.text_factory = lambda b: b.decode(errors="ignore")
            conns[db_path] = conn
        return conn

    def _cache_get(self, key: tuple[str, str]) -> tuple[list, list] | None:
        """LRU lookup; promotes the entry to most-recently-used on a hit."""
        if not self.use_cache:
            return None
        with self._cache_lock:
            hit = self._cache.get(key)
            if hit is None:
                return None
            self._cache.move_to_end(key)
            return hit

    def _cache_put(self, key: tuple[str, str], value: tuple[list, list]) -> None:
        """Insert/refresh an entry, evicting least-recently-used past capacity."""
        if not self.use_cache:
            return
        with self._cache_lock:
            self._cache[key] = value
            self._cache.move_to_end(key)
            while len(self._cache) > self.cache_max_entries:
                self._cache.popitem(last=False)

    def _execute_one(self, sql: str, db_path: str, timeout_s: float | None = None):
        """Execute ``sql`` on ``db_path`` with caching and a wall-clock timeout.

        Returns:
            (rows, columns) — fetched rows and the cursor's column names.

        Raises:
            sqlite3.OperationalError: when the progress handler aborts the
                statement after ``timeout_s`` seconds, or on SQL errors.
        """
        timeout_s = float(self.default_timeout_s if timeout_s is None else timeout_s)
        key = (db_path, sql)
        with self._stats_lock:
            self._exec_calls += 1
        cached = self._cache_get(key)
        if cached is not None:
            with self._stats_lock:
                self._exec_cache_hits += 1
            return cached
        with self._stats_lock:
            self._exec_cache_misses += 1
        conn = self._get_thread_conn(db_path)
        start_t = time.monotonic()

        def handler():
            # A non-zero return makes SQLite abort the running statement.
            return 1 if (time.monotonic() - start_t) > timeout_s else 0

        conn.set_progress_handler(handler, 10_000)
        cur = conn.cursor()
        try:
            cur.execute(sql)
            rows = cur.fetchall()
            cols = [d[0] for d in cur.description] if cur.description else []
        finally:
            # Fix: the original left the handler (with a stale deadline
            # closure) installed on the cached connection and never closed
            # the cursor. Clear both even when execution raises.
            conn.set_progress_handler(None, 10_000)
            cur.close()
        out = (rows, cols)
        self._cache_put(key, out)
        return out

    def stats(self) -> Dict[str, Any]:
        """Return a snapshot of execution/cache counters and pool size."""
        with self._stats_lock:
            calls, hits, misses = int(self._exec_calls), int(self._exec_cache_hits), int(self._exec_cache_misses)
        hit_rate = (hits / calls) if calls else 0.0
        return {
            "exec_calls": calls, "exec_cache_hits": hits, "exec_cache_misses": misses,
            "exec_cache_hit_rate": float(hit_rate), "use_cache": bool(self.use_cache),
            # _max_workers is a private ThreadPoolExecutor attribute; fall
            # back to 0 if the implementation ever drops it.
            "exec_workers": int(getattr(self.exec_pool, "_max_workers", 0) or 0),
        }

    def reset_stats(self) -> None:
        """Zero all execution/cache counters (cache contents are kept)."""
        with self._stats_lock:
            self._exec_calls = self._exec_cache_hits = self._exec_cache_misses = 0

    def execute_sql(self, sql: str, db_id: str, *, timeout_s: float | None = None, validate_schema: bool = True):
        """Optionally validate, then execute ``sql`` against database ``db_id``.

        Raises:
            ValueError: if schema validation fails (a crashing validator is
                treated as a failed validation, best-effort).
            sqlite3.Error: on execution failure or timeout interrupt.
        """
        db_path = self._get_db_path(db_id)
        if validate_schema:
            try:
                ok, _ = validate_sql_schema(sql, db_path)
            except Exception:
                ok = False
            if not ok:
                raise ValueError("Invalid schema")
        return self._execute_one(sql, db_path, timeout_s=timeout_s)

    def ask(
        self,
        question: str,
        db_id: str,
        *,
        max_new_tokens: int = 120,
        num_beams: int = 8,
        repetition_penalty: float = 1.2,
        timeout_s: float | None = None,
    ) -> Dict[str, Any]:
        """One-shot QA: generate SQL for ``question`` and execute it.

        Never raises; failures are reported via the ``error`` field of the
        returned dict ({"sql", "rows", "columns", "error"}).
        """
        sql = self.generate_sql_batch(
            [(question, db_id)],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            repetition_penalty=repetition_penalty,
        )[0]
        db_path = self._get_db_path(db_id)
        try:
            ok, _ = validate_sql_schema(sql, db_path)
        except Exception:
            ok = False
        if not ok:
            return {"sql": sql, "rows": [], "columns": [], "error": "Invalid schema"}
        try:
            rows, cols = self._execute_one(sql, db_path, timeout_s=timeout_s)
            return {"sql": sql, "rows": rows, "columns": cols, "error": None}
        except Exception as e:
            return {"sql": sql, "rows": [], "columns": [], "error": str(e)}

    def ask_batch_execute(self, pairs: Sequence[Tuple[str, str]]) -> List[Dict[str, Any]]:
        """Generate SQL for all pairs and execute them in parallel.

        Fix: results are returned in the same order as ``pairs`` (the
        original returned them in completion order, so they did not align
        with the input). Each result dict carries db_id/sql/rows/columns/error.
        """
        sqls = self.generate_sql_batch(pairs)
        futures: Dict[Any, Tuple[int, str, str]] = {}
        for idx, ((q, db_id), sql) in enumerate(zip(pairs, sqls)):
            db_path = self._get_db_path(db_id)
            futures[self.exec_pool.submit(self._execute_one, sql, db_path)] = (idx, sql, db_id)
        results: List[Dict[str, Any]] = [{} for _ in range(len(futures))]
        for fut in as_completed(futures):
            idx, sql, db_id = futures[fut]
            try:
                rows, cols = fut.result()
                results[idx] = {"db_id": db_id, "sql": sql, "rows": rows, "columns": cols, "error": None}
            except Exception as e:
                results[idx] = {"db_id": db_id, "sql": sql, "rows": [], "columns": [], "error": str(e)}
        return results