File size: 5,533 Bytes
f0e5200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from __future__ import annotations

import os
import sqlite3
from contextlib import closing
from typing import Dict, List

import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from trl import AutoModelForSeq2SeqLMWithValueHead

import sys

# Project root: two directory levels above this file (presumably this script
# lives in a subdirectory of the repo, e.g. scripts/ — confirm).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the project root importable so `src.execution_reward` resolves when this
# file is run directly as a script (hence the late import / noqa below).
sys.path.append(PROJECT_ROOT)
from src.execution_reward import execution_reward  # noqa: E402


# Base seq2seq model id/path; overridable via the BASE_MODEL environment variable.
BASE_MODEL = os.environ.get("BASE_MODEL", "t5-small")
# Per-database SQLite files live at <DB_ROOT>/<db_id>/<db_id>.sqlite.
DB_ROOT = os.path.join(PROJECT_ROOT, "data", "database")

# Prefer the RL run's best_model/ checkpoint directory when it exists;
# otherwise use the RL run's top-level output directory.
RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql", "best_model")
if not os.path.isdir(RL_DIR):
    RL_DIR = os.path.join(PROJECT_ROOT, "outputs", "rlhf_text2sql")

# NOTE(review): this is a slice of the *train* split, so the numbers printed by
# main() are not a held-out evaluation.
SPLIT = "train[:100]"  # quick sanity check
MAX_NEW_TOKENS = 128  # generation budget (new tokens) per example

# Prompt-construction settings. NOTE(review): presumably these must match the
# values used at training time — confirm against the training script.
PREFIX = "translate English to SQL:"  # task prefix at the start of every prompt
MAX_SCHEMA_CHARS = 1500  # schema text is cut to this many characters
MAX_INPUT_TOKENS = 512  # cap on encoder input length, EOS token included


# Turn off Hugging Face tokenizers' internal parallelism; setdefault keeps any
# value the user already exported.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Apple Metal (MPS) backend when available, otherwise CPU; CUDA is not checked.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)


def get_db_path(db_id: str) -> str:
    """Return the SQLite file path for database ``db_id``.

    Layout: ``<DB_ROOT>/<db_id>/<db_id>.sqlite``.
    """
    sqlite_file = db_id + ".sqlite"
    return os.path.join(DB_ROOT, db_id, sqlite_file)


_SCHEMA_CACHE: Dict[str, str] = {}  # db_path -> rendered schema text ("" on failure)


def get_db_schema_text(db_path: str) -> str:
    """Return a compact one-line description of the user tables in a SQLite DB.

    The text looks like ``"t1(c1, c2) t2(c3) "``: one ``name(cols) `` group per
    table listed in ``sqlite_master`` (internal ``sqlite_*`` tables excluded).
    It is truncated to ``MAX_SCHEMA_CHARS`` characters. Results, including
    failures, are memoized per ``db_path`` in ``_SCHEMA_CACHE``.

    The database is opened read-only. A missing ``db_path`` therefore yields an
    empty schema, instead of ``sqlite3.connect`` silently creating an empty
    database file at that location.

    Args:
        db_path: Filesystem path to the ``.sqlite`` file.

    Returns:
        The schema text, or ``""`` if the database cannot be opened or read.
    """
    cached = _SCHEMA_CACHE.get(db_path)
    if cached is not None:
        return cached

    from pathlib import Path  # local import: only needed to build the URI

    parts: List[str] = []
    try:
        # mode=ro: never create or modify the file. as_uri() percent-encodes the
        # path, so spaces, '?' or '#' in it cannot be misread as URI syntax.
        uri = Path(db_path).absolute().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            cur = conn.cursor()
            tables = cur.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            ).fetchall()
            for (tname,) in tables:
                # Quote as an SQL identifier; embedded double quotes must be doubled.
                quoted = '"' + str(tname).replace('"', '""') + '"'
                cols = cur.execute(f"PRAGMA table_info({quoted})").fetchall()
                # PRAGMA table_info rows: index 1 is the column name.
                col_names = [c[1] for c in cols if c and isinstance(c[1], str)]
                parts.append(f"{tname}({', '.join(col_names)}) ")
    except (sqlite3.Error, ValueError):
        # Best effort: an unreadable DB yields an empty schema rather than
        # aborting the whole evaluation run.
        parts = []

    schema_text = "".join(parts)[:MAX_SCHEMA_CHARS]
    _SCHEMA_CACHE[db_path] = schema_text
    return schema_text


def encode_prompt(tokenizer, question: str, schema: str) -> torch.Tensor:
    r"""Tokenize a schema-aware text-to-SQL prompt into a 1-D id tensor on ``device``.

    Prompt layout (segments tokenized separately, without special tokens):
    ``{PREFIX}\n\nSchema:\n`` + schema + ``\n\nQuestion:\n`` + ``{question}\n\nSQL:``,
    followed by the tokenizer's EOS id when it defines one.

    The result never exceeds ``MAX_INPUT_TOKENS`` ids (EOS included). The
    question/``SQL:`` tail is trimmed only if the fixed segments alone overflow
    the budget. The schema (first cut to ``MAX_SCHEMA_CHARS`` characters) then
    receives whatever token room remains.
    """
    def _ids(text: str) -> List[int]:
        return tokenizer.encode(text, add_special_tokens=False)

    head = _ids(f"{PREFIX}\n\nSchema:\n")
    schema_part = _ids((schema or "")[:MAX_SCHEMA_CHARS])
    separator = _ids("\n\nQuestion:\n")
    tail = _ids(f"{question}\n\nSQL:")

    eos = tokenizer.eos_token_id
    # Reserve one slot for EOS when the tokenizer has one.
    budget = MAX_INPUT_TOKENS if eos is None else MAX_INPUT_TOKENS - 1

    if len(head) + len(separator) + len(tail) > budget:
        tail = tail[: max(0, budget - len(head) - len(separator))]

    schema_room = max(0, budget - (len(head) + len(separator) + len(tail)))
    schema_part = schema_part[:schema_room]

    token_ids = [*head, *schema_part, *separator, *tail][:budget]
    if eos is not None:
        token_ids.append(eos)

    return torch.tensor(token_ids, dtype=torch.long).to(device)


def load_model_and_tokenizer():
    """Load the tokenizer and model to evaluate, trying checkpoints in order.

    1. ``RL_DIR`` as a TRL value-head checkpoint
       (``AutoModelForSeq2SeqLMWithValueHead``) with the tokenizer saved there.
    2. ``BASE_MODEL`` with ``RL_DIR`` applied as a PEFT (LoRA) adapter.
    3. ``BASE_MODEL`` with the SFT adapter in ``checkpoints/sft_adapter``.

    Every path taken, and the reason for every fallback, is printed. This keeps
    reported metrics from being silently attributed to the wrong checkpoint.
    The tokenizer's pad token falls back to its EOS token when unset.

    Returns:
        ``(tokenizer, model)`` with the model moved to ``device``.

    Raises:
        Whatever ``PeftModel.from_pretrained`` raises if the SFT adapter cannot
        be loaded either (or the base model/tokenizer loaders raise).
    """
    # 1) PPO-saved value-head model.
    try:
        tok = AutoTokenizer.from_pretrained(RL_DIR)
        mdl = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(RL_DIR).to(device)
    except Exception as err:  # deliberate best-effort: fall back below
        print(f"Value-head load from {RL_DIR} failed ({err!r}); trying it as a LoRA adapter.")
    else:
        print("Loaded value-head model from:", RL_DIR)
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token
        return tok, mdl

    # 2) Base model + RL_DIR as a LoRA adapter.
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
    try:
        model = PeftModel.from_pretrained(base, RL_DIR)
        print("Loaded base model + RL adapter from:", RL_DIR)
    except Exception as err:
        # 3) Final fallback: SFT adapter.
        sft_dir = os.path.join(PROJECT_ROOT, "checkpoints", "sft_adapter")
        print(f"RL adapter load from {RL_DIR} failed ({err!r}); falling back to SFT adapter: {sft_dir}")
        # A failed PeftModel load may already have injected adapter layers into
        # `base`, so start this fallback from a fresh copy of the base model.
        base = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL).to(device)
        model = PeftModel.from_pretrained(base, sft_dir)
    # Ensure any newly loaded adapter weights also live on `device`.
    return tok, model.to(device)


def main() -> None:
    """Evaluate the loaded model on ``SPLIT`` of Spider and print metrics.

    For each example: build the schema-aware prompt, greedy-decode up to
    ``MAX_NEW_TOKENS`` tokens, and score the prediction with
    ``execution_reward`` against the gold query on the example's database.
    Progress is printed every 25 examples. The final summary gives execution
    accuracy and the valid-SQL rate as fractions of the number of examples.
    """
    tokenizer, model = load_model_and_tokenizer()
    model.eval()

    ds = load_dataset("spider", split=SPLIT)
    n = len(ds)
    if n == 0:
        # Nothing to score; also avoids ZeroDivisionError in the rates below.
        print(f"No examples in split {SPLIT!r}; nothing to evaluate.")
        return

    correct = 0
    valid = 0

    for i, ex in enumerate(ds, start=1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_path = get_db_path(ex["db_id"])
        schema = get_db_schema_text(db_path)

        inp = encode_prompt(tokenizer, question, schema).unsqueeze(0)
        with torch.no_grad():
            out = model.generate(
                input_ids=inp,
                # Single unpadded prompt: every position is a real token.
                attention_mask=torch.ones_like(inp),
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                num_beams=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        pred_sql = tokenizer.decode(out[0], skip_special_tokens=True)
        r = execution_reward(pred_sql, db_path, gold_sql)
        # NOTE(review): assumes execution_reward returns -1.0 for SQL that fails
        # to execute and >= 1.0 for an execution match — confirm in
        # src/execution_reward.
        if r > -1.0:
            valid += 1
        if r >= 1.0:
            correct += 1

        if i % 25 == 0:
            print(f"Evaluated {i}/{n}")

    print("\nRESULTS")
    print(f"examples: {n}")
    print(f"execution_accuracy: {correct/n:.3f}")
    print(f"valid_sql_rate: {valid/n:.3f}")


if __name__ == "__main__":
    main()