import warnings

warnings.filterwarnings("ignore")

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Limit PyTorch to a single CPU thread (small shared CPU instance).
torch.set_num_threads(1)

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,  # float32 is the safe choice for CPU inference
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()  # inference only; disables dropout
print("Model ready")

# ─────────────────────────
# SQL FILTER
# ─────────────────────────
SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert",
    "update", "delete", "join", "group by",
    "postgres", "mysql", "sqlite", "query",
]

def is_sql_related(text):
    """Naive gate: substring match against SQL-ish keywords."""
    text = text.lower()
    return any(k in text for k in SQL_KEYWORDS)
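
# Illustrative checks (plain substring match, so e.g. "updated" would also
# trip the "update" keyword):
#   is_sql_related("How do I JOIN two tables?")  -> True
#   is_sql_related("Write a joke about cats")    -> False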

# ─────────────────────────
# GENERATION
# ─────────────────────────
SYSTEM_PROMPT = """
You are an expert SQL generator.
Rules:
- Only respond to SQL or database-related questions.
- If the question is not about SQL or databases, refuse.
- Output ONLY the SQL query.
- Do not explain.
"""

def generate_sql(user_input):
    """Turn a natural-language request into a SQL query, or refuse."""
    if not user_input.strip():
        return "Enter a SQL question."
    if not is_sql_related(user_input):
        return ("I only respond to SQL and database-related questions. "
                "If you want, I can craft helpful database queries for you.")

    prompt = f"""
{SYSTEM_PROMPT}
User request: {user_input}
SQL:
"""
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=False,  # greedy decoding; temperature has no effect without sampling
            pad_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text after the last "SQL:" marker, up to the first blank line.
    result = text.split("SQL:")[-1].strip()
    result = result.split("\n\n")[0]
    return result
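
# Quick sanity check (illustrative only; actual model output will vary):
#   generate_sql("Find duplicate emails in users table")
#   might yield something like:
#   SELECT email, COUNT(*) FROM users GROUP BY email HAVING COUNT(*) > 1;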

# ─────────────────────────
# UI
# ─────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="SQL Question",
        placeholder="Find duplicate emails in users table",
    ),
    outputs=gr.Textbox(
        lines=8,
        label="Generated SQL",
    ),
    title="AI SQL Generator (Portfolio Project)",
    description="This model ONLY responds to SQL/database queries.",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        ["Write a joke about cats"],
    ],
)

demo.launch()