| import torch |
| import sqlite3 |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
| |
| |
| |
# Directory that both the tokenizer and the fine-tuned weights are loaded from.
MODEL_PATH = "outputs/model"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

print("Loading fine-tuned model...")
# The script only runs inference, so put the module in evaluation mode right
# away (eval() returns the module itself, so the call can be chained).
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH).eval()
|
|
| |
| |
| |
# Location of the SQLite database that the generated queries are run against.
DB_PATH = "data/database/department_management/department_management.sqlite"

print("Connecting to database...")

# Open the file read-only through a URI. The SQL executed on this connection is
# produced by the model, so it must not be able to modify or drop anything, and
# a wrong path now fails here instead of silently creating an empty database
# file (which would make every later query fail with "no such table").
conn = sqlite3.connect(f"file:{DB_PATH}?mode=ro", uri=True)
cursor = conn.cursor()
print("Database connected ✔")
|
|
| |
| |
| |
def build_prompt(question):
    """Build the model input: task prefix, database schema, then the question.

    Args:
        question: Natural-language question to be translated into SQL.

    Returns:
        One prompt string of the form
        ``translate English to SQL: <schema> question: <question>``.
    """
    # NOTE(review): the whitespace inside the original triple-quoted schema
    # literal could not be recovered from the extracted source; each line is
    # reproduced here without leading indentation. Confirm this matches the
    # prompt format used when the model was fine-tuned.
    schema_lines = (
        "Table department columns = Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees.",
        "Table head columns = head_ID, name, born_state, age.",
        "Table management columns = department_ID, head_ID, temporary_acting.",
    )
    # The schema block starts and ends with a newline, one table per line.
    schema = "\n" + "\n".join(schema_lines) + "\n"
    return "translate English to SQL: {} question: {}".format(schema, question)
|
|
| |
| |
| |
def generate_sql(question):
    """Translate a natural-language question into SQL with the loaded model.

    The question is wrapped by ``build_prompt``, tokenized (truncated to at
    most 256 tokens) and decoded with 5-beam search under ``torch.no_grad()``.

    Args:
        question: Natural-language question about the database.

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed and surrounding whitespace stripped.
    """
    model_inputs = tokenizer(
        build_prompt(question),
        max_length=256,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )

    generation_args = {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "max_length": 256,
        "num_beams": 5,
        "early_stopping": True,
    }

    # Inference only: no gradients need to be tracked.
    with torch.no_grad():
        generated = model.generate(**generation_args)

    first_sequence = generated[0]
    decoded = tokenizer.decode(first_sequence, skip_special_tokens=True)
    return decoded.strip()
|
|
| |
| |
| |
def evaluate_sql(sql):
    """Execute *sql* on the module-level cursor and score the outcome.

    Args:
        sql: SQL statement to run.

    Returns:
        A ``(reward, payload)`` tuple:

        * ``(1.0, rows)``   - the statement ran and returned at least one row.
        * ``(-0.2, rows)``  - the statement ran but returned no rows
          (``rows`` is then an empty list).
        * ``(-1.0, message)`` - execution raised; ``message`` is the text of
          the exception.
    """
    try:
        cursor.execute(sql)
        fetched = cursor.fetchall()
    except Exception as exc:
        # Any failure to run the statement counts as the worst outcome; the
        # error text is handed back so the caller can display it.
        return -1.0, str(exc)

    if not fetched:
        return -0.2, fetched
    return 1.0, fetched
|
|
| |
| |
| |
# Interactive loop: read a question, show the SQL the model predicts for it,
# execute that SQL and report the outcome, until the user quits.
try:
    while True:
        try:
            q = input("\nAsk question (type exit to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt, or exhausted piped input: leave
            # the loop cleanly instead of dying with a traceback.
            print()
            break

        if q.lower() == "exit":
            break

        if not q:
            # Nothing was typed; do not spend a generation pass on it.
            continue

        sql = generate_sql(q)

        print("\nPredicted SQL:")
        print(sql)

        reward, output = evaluate_sql(sql)

        print("\nReward:", reward)

        # evaluate_sql returns -1.0 when execution failed (output is the error
        # text), -0.2 when no rows came back, and 1.0 with the rows otherwise.
        if reward == -1.0:
            print("SQL Error:", output)
        elif reward == -0.2:
            print("No results found")
        else:
            print("\nAnswer:")
            for r in output:
                print(r)
finally:
    # Release the database handle however the loop ends.
    conn.close()
|
|