"""Quick execution-based evaluation of an RLHF text-to-SQL model on a Spider slice."""
from __future__ import annotations
import os
import sqlite3
from contextlib import closing
from typing import Dict, List
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from trl import AutoModelForSeq2SeqLMWithValueHead
import sys
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PROJECT_ROOT)
from src.execution_reward import execution_reward # noqa: E402
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")
# Prefer RL best model if present; otherwise fall back.
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql", "best_model")
if not os.path.isdir(RL_DIR):
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")
SPLIT = "train[:100]" # quick sanity check
MAX_NEW_TOKENS = 128
PREFIX = "translate English to SQL:"
MAX_SCHEMA_CHARS = 1500
MAX_INPUT_TOKENS = 512
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
def get_db_path(db_id: str) -> str:
return os.path.join(DB_ROOT, db_id, f"{db_id}.sqlite")
_SCHEMA_CACHE: Dict[str, str] = {}
def get_db_schema_text(db_path: str) -> str:
if db_path in _SCHEMA_CACHE:
return _SCHEMA_CACHE[db_path]
schema_text = ""
try:
with closing(sqlite3.connect(db_path)) as conn:
cur = conn.cursor()
tables = cur.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
).fetchall()
for (tname,) in tables:
cols = cur.execute(f'PRAGMA table_info(\"{tname}\")').fetchall()
col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
schema_text += f"{tname}({', '.join(col_names)}) "
except Exception:
schema_text = ""
if len(schema_text) > MAX_SCHEMA_CHARS:
schema_text = schema_text[:MAX_SCHEMA_CHARS]
_SCHEMA_CACHE[db_path] = schema_text
return schema_text
def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
schema = (schema or "")[:MAX_SCHEMA_CHARS]
prefix_schema = f"{PREFIX}\n\nSchema:\n"
mid = "\n\nQuestion:\n"
suffix = f"{question}\n\nSQL:"
prefix_ids = tokenizer.encode(prefix_schema, add_special_tokens=False)
schema_ids = tokenizer.encode(schema, add_special_tokens=False)
mid_ids = tokenizer.encode(mid, add_special_tokens=False)
suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)
eos_id = tokenizer.eos_token_id
max_without_eos = MAX_INPUT_TOKENS - (1 if eos_id is not None else 0)
fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
if fixed_len > max_without_eos:
keep = max(0, max_without_eos - (len(prefix_ids) + len(mid_ids)))
suffix_ids = suffix_ids[:keep]
fixed_len = len(prefix_ids) + len(mid_ids) + len(suffix_ids)
remaining_for_schema = max_without_eos - fixed_len
if remaining_for_schema < 0:
remaining_for_schema = 0
schema_ids = schema_ids[:remaining_for_schema]
ids = (prefix_ids + schema_ids + mid_ids + suffix_ids)[:max_without_eos]
if eos_id is not None:
ids = ids + [eos_id]
return torch.tensor(ids, dtype=torch.long).to(device)
def load_model_and_tokenizer():
# Try loading the PPO-saved value-head model directly.
try:
tok = AutoTokenizer.from_pretrained(RL_DIR)
mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
return tok, mdl
except Exception:
pass
# Fallback: treat RL_DIR as a LoRA adapter directory.
tok = AutoTokenizer.from_pretrained(BASE_MODEL)
if tok.pad_token_id is None:
tok.pad_token = tok.eos_token
base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
try:
base = PeftModel.from_pretrained(base, RL_DIR)
except Exception:
# Final fallback: use SFT adapter (if RL adapter not found)
sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
base = PeftModel.from_pretrained(base, sft_dir)
return tok, base
def main() -> None:
tokenizer, model = load_model_and_tokenizer()
model.eval()
ds = load_dataset("spider", split=SPLIT)
correct = 0
valid = 0
for i, ex in enumerate(ds, start=1):
question = ex["question"]
gold_sql = ex["query"]
db_id = ex["db_id"]
db_path = get_db_path(db_id)
schema = get_db_schema_text(db_path)
inp = encode_prompt(tokenizer, question, schema)
with torch.no_grad():
out = model.generate(
input_ids=inp.unsqueeze(0),
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
num_beams=1,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
pred_sql = tokenizer.decode(out[0], skip_special_tokens=True)
r = execution_reward(pred_sql, db_path, gold_sql)
if r > -1.0:
valid += 1
if r >= 1.0:
correct += 1
if i % 25 == 0:
print(f"Evaluated {i}/{len(ds)}")
n = len(ds)
print("\nRESULTS")
print(f"examples: {n}")
print(f"execution_accuracy: {correct/n:.3f}")
print(f"valid_sql_rate: {valid/n:.3f}")
if __name__ == "__main__":
main()
|