Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import sqlite3 | |
| from contextlib import closing | |
| from typing import Dict, List | |
| import torch | |
| from datasets import load_dataset | |
| from peft import PeftModel | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from trl import AutoModelForSeq2SeqLMWithValueHead | |
| import sys | |
| PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| sys.path.append(PROJECT_ROOT) | |
| from src.execution_reward import execution_reward # noqa: E402 | |
# ---------------------------------------------------------------------------
# Evaluation configuration
# ---------------------------------------------------------------------------

# Hugging Face id of the base seq2seq checkpoint; override via $BASE_MODEL.
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")

# Root folder holding one sub-directory per database id.
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")

# Start from the generic RL output folder and switch to its "best_model"
# sub-directory when that exists.
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")
if os.path.isdir(os.path.join(RL_DIR, "best_model")):
    RL_DIR = os.path.join(RL_DIR, "best_model")

SPLIT = "train[:100]"  # small slice: quick sanity check only
MAX_NEW_TOKENS = 128  # generation budget per example
PREFIX = "translate English to SQL:"
MAX_SCHEMA_CHARS = 1500  # character cap applied to the schema text
MAX_INPUT_TOKENS = 512  # token cap for the encoded prompt (EOS included)

# Only sets the variable when the caller has not already chosen a value.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
print("Using device:", device)
def get_db_path(db_id: str) -> str:
    """Return the path ``DB_ROOT/<db_id>/<db_id>.sqlite`` for *db_id*."""
    sqlite_name = "{}.sqlite".format(db_id)
    return os.path.join(DB_ROOT, db_id, sqlite_name)
# Maps a database path to its full (untruncated) schema text. Failed reads are
# cached as "" so a broken database is only probed once per process.
_SCHEMA_CACHE: Dict[str, str] = {}


def _read_schema_text(db_path: str) -> str:
    """Read ``table(col, col) `` entries for every user table in *db_path*.

    Returns "" (after a warning on stderr) when the file is missing or
    SQLite cannot read it.
    """
    # sqlite3.connect() creates an empty database file when the path does not
    # exist; check first so a typo in a db id never litters the data folder.
    if not os.path.isfile(db_path):
        print(f"[schema] database file not found: {db_path}", file=sys.stderr)
        return ""

    entries = []
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            cur = conn.cursor()
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            for (tname,) in tables:
                # Double any embedded quote so the identifier stays well-formed.
                quoted = tname.replace('"', '""')
                cols = cur.execute(f'PRAGMA table_info("{quoted}")').fetchall()
                # Index 1 of each PRAGMA table_info row is the column name.
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                entries.append(f"{tname}({', '.join(col_names)}) ")
    except sqlite3.Error as exc:
        # Best effort: an unreadable database yields an empty schema rather
        # than aborting the run, but the reason is reported.
        print(f"[schema] could not read {db_path}: {exc}", file=sys.stderr)
        return ""
    return "".join(entries)


def get_db_schema_text(db_path: str, max_chars: int | None = None) -> str:
    """Return a compact one-line description of the schema stored at *db_path*.

    The text has the form ``"table_a(col1, col2) table_b(col3) "``: one entry
    per user table, each followed by a single space.

    Args:
        db_path: Path to a SQLite database file.
        max_chars: Maximum number of characters to return. ``None`` (the
            default) uses ``MAX_SCHEMA_CHARS``; negative values are treated
            as 0.

    Returns:
        The schema text cut to the character limit, or "" when the database
        is missing or unreadable.
    """
    limit = MAX_SCHEMA_CHARS if max_chars is None else max(0, max_chars)

    schema_text = _SCHEMA_CACHE.get(db_path)
    if schema_text is None:
        schema_text = _read_schema_text(db_path)
        # Cache the full text so different limits can share one read.
        _SCHEMA_CACHE[db_path] = schema_text
    return schema_text[:limit]
def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
    """Tokenize the text-to-SQL prompt into a 1-D tensor of token ids.

    The prompt is ``PREFIX``, a schema section, a question section and a
    trailing ``SQL:`` cue. The total length never exceeds
    ``MAX_INPUT_TOKENS`` (EOS included). When the fixed parts alone are too
    long, the question/cue segment is cut from its end; the schema then gets
    whatever room is left.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``
            and ``eos_token_id``.
        question: Natural-language question.
        schema: Schema text; ``None`` or "" is treated as an empty schema.

    Returns:
        A ``torch.long`` tensor of token ids placed on the module-level
        ``device``.
    """

    def to_ids(text):
        return tokenizer.encode(text, add_special_tokens=False)

    schema_text = schema[:MAX_SCHEMA_CHARS] if schema else ""

    head = to_ids(f"{PREFIX}\n\nSchema:\n")
    schema_tokens = to_ids(schema_text)
    bridge = to_ids("\n\nQuestion:\n")
    tail = to_ids(f"{question}\n\nSQL:")

    # Reserve one slot for EOS when the tokenizer defines one.
    eos = tokenizer.eos_token_id
    budget = MAX_INPUT_TOKENS if eos is None else MAX_INPUT_TOKENS - 1

    # The head and bridge are never shortened here; the tail gives way first.
    scaffold_len = len(head) + len(bridge)
    if scaffold_len + len(tail) > budget:
        tail = tail[: max(0, budget - scaffold_len)]

    schema_room = max(0, budget - scaffold_len - len(tail))

    # The final slice covers the case where head + bridge alone exceed the budget.
    tokens = (head + schema_tokens[:schema_room] + bridge + tail)[:budget]
    if eos is not None:
        tokens.append(eos)

    return torch.tensor(tokens, dtype=torch.long).to(device)
def load_model_and_tokenizer():
    """Load the tokenizer and the model to evaluate, trying three sources.

    Sources are tried in this order:

    1. ``RL_DIR`` as a complete value-head checkpoint.
    2. ``BASE_MODEL`` with ``RL_DIR`` applied as a PEFT adapter.
    3. ``BASE_MODEL`` with the SFT adapter from ``checkpoints/sft_adapter``.

    Each fallback is reported on stderr, so the evaluation log shows which
    weights were actually used.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Exception: Whatever the final SFT-adapter load raises when all three
            sources fail.
    """
    # Attempt 1: the PPO-saved value-head model. The catch is deliberately
    # broad: any loading failure should move on to the next source.
    try:
        tok = AutoTokenizer.from_pretrained(RL_DIR)
        mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
    except Exception as exc:
        print(
            f"[load] value-head model not loadable from {RL_DIR!r} "
            f"({type(exc).__name__}: {exc}); trying {BASE_MODEL!r} + adapter",
            file=sys.stderr,
        )
    else:
        return tok, mdl

    # Attempt 2: base model with RL_DIR treated as a LoRA adapter directory.
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    try:
        model = PeftModel.from_pretrained(base, RL_DIR)
    except Exception as exc:
        # Attempt 3: fall back to the SFT adapter; errors here propagate.
        sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
        print(
            f"[load] RL adapter not loadable from {RL_DIR!r} "
            f"({type(exc).__name__}: {exc}); using SFT adapter {sft_dir!r}",
            file=sys.stderr,
        )
        model = PeftModel.from_pretrained(base, sft_dir)
    return tok, model
def main() -> None:
    """Evaluate the loaded model on ``SPLIT`` of Spider and print two rates.

    Each example is decoded greedily and its SQL is scored with
    ``execution_reward``. A reward above -1.0 counts toward
    ``valid_sql_rate`` and a reward of at least 1.0 counts toward
    ``execution_accuracy``. Both rates are printed as 0.000 when the split
    is empty.
    """
    tokenizer, model = load_model_and_tokenizer()
    model.eval()

    ds = load_dataset("spider", split=SPLIT)
    total = len(ds)  # loop-invariant; reused for progress and the summary

    correct = 0
    valid = 0
    for i, ex in enumerate(ds, start=1):
        db_path = get_db_path(ex["db_id"])
        schema = get_db_schema_text(db_path)
        inp = encode_prompt(tokenizer, ex["question"], schema)

        with torch.no_grad():
            out = model.generate(
                input_ids=inp.unsqueeze(0),  # add the batch dimension
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                num_beams=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        pred_sql = tokenizer.decode(out[0], skip_special_tokens=True)

        reward = execution_reward(pred_sql, db_path, ex["query"])
        if reward > -1.0:
            valid += 1
        if reward >= 1.0:
            correct += 1

        if i % 25 == 0:
            print(f"Evaluated {i}/{total}")

    # Guard the division: an empty split would otherwise raise
    # ZeroDivisionError after all the work above.
    accuracy = correct / total if total else 0.0
    valid_rate = valid / total if total else 0.0

    print("\nRESULTS")
    print(f"examples: {total}")
    print(f"execution_accuracy: {accuracy:.3f}")
    print(f"valid_sql_rate: {valid_rate:.3f}")


if __name__ == "__main__":
    main()