# Source snapshot: "Step 2: added code folders" (commit f0e5200, tjhalanigrid).
from __future__ import annotations
import os
import sqlite3
from contextlib import closing
from typing import Dict, List
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from trl import AutoModelForSeq2SeqLMWithValueHead
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PROJECT_ROOT)
from src.execution_reward import execution_reward # noqa: E402
# Base checkpoint to fall back on; overridable via the BASE_MODEL env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
# Root directory holding one sub-directory per Spider database id.
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
# Prefer RL best model if present; otherwise fall back.
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql", "best_model")
if not os.path.isdir(RL_DIR):
    RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")
SPLIT = "train[:100]"  # quick sanity check
MAX_NEW_TOKENS = 128  # generation budget per example
PREFIX = "translate English to SQL:"  # T5-style task prefix
MAX_SCHEMA_CHARS = 1500  # hard character cap on the serialized schema
MAX_INPUT_TOKENS = 512  # encoder input budget in tokens (incl. EOS)
# Silence the HF tokenizers fork-parallelism warning for this script.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Prefer Apple MPS when available (macOS dev machines); otherwise CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
def get_db_path(db_id: str) -> str:
    """Return the path of the SQLite file for a Spider database id."""
    sqlite_name = f"{db_id}.sqlite"
    return os.path.join(DB_ROOT, db_id, sqlite_name)
# Memoizes serialized schemas per database path (failures cached as "").
_SCHEMA_CACHE: Dict[str, str] = {}


def get_db_schema_text(db_path: str) -> str:
    """Serialize a SQLite database schema as ``table(col1, col2) ...`` text.

    The result is truncated to MAX_SCHEMA_CHARS and memoized per *db_path*
    (including failures, which are cached as the empty string). Returns ""
    when the database cannot be opened or read.
    """
    if db_path in _SCHEMA_CACHE:
        return _SCHEMA_CACHE[db_path]
    parts: List[str] = []
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            cur = conn.cursor()
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            for (tname,) in tables:
                # Escape embedded double quotes so an oddly named table cannot
                # break out of the quoted identifier in the PRAGMA statement.
                quoted = tname.replace('"', '""')
                cols = cur.execute(f'PRAGMA table_info("{quoted}")').fetchall()
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                parts.append(f"{tname}({', '.join(col_names)}) ")
    except Exception:
        # Best-effort: an unreadable database yields an empty schema.
        parts = []
    schema_text = "".join(parts)  # linear-time build; same text as before
    if len(schema_text) > MAX_SCHEMA_CHARS:
        schema_text = schema_text[:MAX_SCHEMA_CHARS]
    _SCHEMA_CACHE[db_path] = schema_text
    return schema_text
def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
    """Tokenize a schema-augmented prompt within the MAX_INPUT_TOKENS budget.

    The fixed parts (task prefix, separators, question) are kept in
    preference to the schema, which absorbs whatever token budget remains;
    an EOS token is appended when the tokenizer defines one. Returns a 1-D
    long tensor on `device`.
    """
    schema = (schema or "")[:MAX_SCHEMA_CHARS]

    def _enc(text: str) -> list:
        return tokenizer.encode(text, add_special_tokens=False)

    head_ids = _enc(f"{PREFIX}\n\nSchema:\n")
    sep_ids = _enc("\n\nQuestion:\n")
    tail_ids = _enc(f"{question}\n\nSQL:")
    schema_ids = _enc(schema)

    eos_id = tokenizer.eos_token_id
    budget = MAX_INPUT_TOKENS - (0 if eos_id is None else 1)

    fixed_len = len(head_ids) + len(sep_ids) + len(tail_ids)
    if fixed_len > budget:
        # Even without any schema the prompt overflows: trim the question.
        keep = max(0, budget - (len(head_ids) + len(sep_ids)))
        tail_ids = tail_ids[:keep]
        fixed_len = len(head_ids) + len(sep_ids) + len(tail_ids)

    schema_budget = max(0, budget - fixed_len)
    schema_ids = schema_ids[:schema_budget]

    ids = (head_ids + schema_ids + sep_ids + tail_ids)[:budget]
    if eos_id is not None:
        ids.append(eos_id)
    return torch.tensor(ids, dtype=torch.long).to(device)
def load_model_and_tokenizer():
    """Load the evaluation model, trying progressively older checkpoints.

    Order: (1) the PPO-saved value-head model in RL_DIR, (2) RL_DIR as a
    LoRA adapter applied to BASE_MODEL, (3) the SFT adapter as a final
    fallback. Returns a ``(tokenizer, model)`` pair with the model moved
    to `device`.
    """
    # Try loading the PPO-saved value-head model directly.
    try:
        tok = AutoTokenizer.from_pretrained(RL_DIR)
        if tok.pad_token_id is None:
            # Consistency fix: main() passes tokenizer.pad_token_id to
            # generate(), so ensure a pad token on this path too (previously
            # only the fallback path did this).
            tok.pad_token = tok.eos_token
        mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
        return tok, mdl
    except Exception:
        pass
    # Fallback: treat RL_DIR as a LoRA adapter directory.
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    try:
        base = PeftModel.from_pretrained(base, RL_DIR)
    except Exception:
        # Final fallback: use SFT adapter (if RL adapter not found)
        sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
        base = PeftModel.from_pretrained(base, sft_dir)
    return tok, base
def main() -> None:
    """Evaluate the RLHF text2sql checkpoint on a Spider slice.

    Prints execution accuracy and the fraction of generations that are at
    least valid, executable SQL, as judged by `execution_reward`.
    """
    tokenizer, model = load_model_and_tokenizer()
    model.eval()

    dataset = load_dataset("spider", split=SPLIT)
    total = len(dataset)
    n_correct = 0
    n_valid = 0

    for idx, example in enumerate(dataset, start=1):
        db_path = get_db_path(example["db_id"])
        schema_text = get_db_schema_text(db_path)
        input_ids = encode_prompt(tokenizer, example["question"], schema_text)

        # Greedy decoding: deterministic, single beam.
        with torch.no_grad():
            generated = model.generate(
                input_ids=input_ids.unsqueeze(0),
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                num_beams=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        predicted_sql = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Thresholds follow execution_reward's scale (presumably -1.0 means
        # invalid SQL and 1.0 means an execution match) — confirm against
        # src/execution_reward.py.
        reward = execution_reward(predicted_sql, db_path, example["query"])
        if reward > -1.0:
            n_valid += 1
        if reward >= 1.0:
            n_correct += 1
        if idx % 25 == 0:
            print(f"Evaluated {idx}/{total}")

    print("\nRESULTS")
    print(f"examples: {total}")
    print(f"execution_accuracy: {n_correct/total:.3f}")
    print(f"valid_sql_rate: {n_valid/total:.3f}")


if __name__ == "__main__":
    main()