| from dotenv import load_dotenv |
| import os |
| from sentence_transformers import SentenceTransformer |
| import gradio as gr |
| from sklearn.metrics.pairwise import cosine_similarity |
| from groq import Groq |
| import sqlite3 |
| import pandas as pd |
|
|
# Pull any variables from a local .env file (if one exists) into the process
# environment, then read the Groq API key; `api` is None when it is not set.
load_dotenv()
api = os.environ.get("groq_api_key")
|
|
| |
def setup_database(db_path="college.db"):
    """(Re)create the demo schema and seed one row into each table.

    The ``student``, ``employee`` and ``course_info`` tables are dropped first,
    so every call resets the database to the same known contents.

    Args:
        db_path: SQLite database file to (re)build; defaults to ``college.db``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Start from a clean slate so repeated runs don't accumulate rows.
        cursor.execute("DROP TABLE IF EXISTS student;")
        cursor.execute("DROP TABLE IF EXISTS employee;")
        cursor.execute("DROP TABLE IF EXISTS course_info;")

        cursor.execute("""
        CREATE TABLE student (
            student_id INTEGER,
            first_name TEXT,
            last_name TEXT,
            date_of_birth TEXT,
            email TEXT,
            phone_number TEXT,
            major TEXT,
            year_of_enrollment INTEGER
        );
        """)
        cursor.execute("INSERT INTO student VALUES (1, 'Alice', 'Smith', '2000-05-01', 'alice@example.com', '1234567890', 'Computer Science', 2019);")

        cursor.execute("""
        CREATE TABLE employee (
            employee_id INTEGER,
            first_name TEXT,
            last_name TEXT,
            email TEXT,
            department TEXT,
            position TEXT,
            salary REAL,
            date_of_joining TEXT
        );
        """)
        cursor.execute("INSERT INTO employee VALUES (101, 'John', 'Doe', 'john@college.edu', 'CSE', 'Professor', 80000, '2015-08-20');")

        cursor.execute("""
        CREATE TABLE course_info (
            course_id INTEGER,
            course_name TEXT,
            course_code TEXT,
            instructor_id INTEGER,
            department TEXT,
            credits INTEGER,
            semester TEXT
        );
        """)
        cursor.execute("INSERT INTO course_info VALUES (501, 'AI Basics', 'CS501', 101, 'CSE', 4, 'Fall');")

        conn.commit()
    finally:
        # Always release the file handle, even if a statement above failed.
        conn.close()
|
|
| |
# Runs at import time: setup_database drops and recreates the student,
# employee and course_info tables, so every start begins from the same seed rows.
setup_database()
|
|
| |
# Process-wide cache: (embeddings, model, student, employee, course), built on first use.
_metadata_cache = None


def create_metadata_embeddings():
    """Describe each table and embed the descriptions for table routing.

    The SentenceTransformer model is loaded and the descriptions are encoded
    only on the first call; later calls return the same cached tuple instead
    of reloading the model on every request.

    Returns:
        A tuple ``(embeddings, model, student, employee, course)`` where
        ``student``/``employee``/``course`` are the description strings (also
        sent to the LLM as table metadata), ``embeddings`` is
        ``model.encode([student, employee, course])`` (one row per
        description, in that order) and ``model`` is the encoder used.
    """
    global _metadata_cache
    if _metadata_cache is None:
        # Descriptions mirror the schema created in setup_database. The prose
        # line helps the embedding match questions; the column list gives the
        # LLM the exact names and types it must use.
        student = """Table: student
Description: students enrolled at the college - names, date of birth, contact details, major and year of enrollment.
Columns:
- student_id (INTEGER): student identifier
- first_name (TEXT)
- last_name (TEXT)
- date_of_birth (TEXT): date as 'YYYY-MM-DD', e.g. '2000-05-01'
- email (TEXT)
- phone_number (TEXT)
- major (TEXT): field of study, e.g. 'Computer Science'
- year_of_enrollment (INTEGER): e.g. 2019"""
        employee = """Table: employee
Description: college staff and faculty - names, email, department, position, salary and joining date.
Columns:
- employee_id (INTEGER): employee identifier
- first_name (TEXT)
- last_name (TEXT)
- email (TEXT)
- department (TEXT): e.g. 'CSE'
- position (TEXT): job title, e.g. 'Professor'
- salary (REAL)
- date_of_joining (TEXT): date as 'YYYY-MM-DD', e.g. '2015-08-20'"""
        course = """Table: course_info
Description: courses offered - course name and code, instructor, department, credits and semester.
Columns:
- course_id (INTEGER): course identifier
- course_name (TEXT): e.g. 'AI Basics'
- course_code (TEXT): e.g. 'CS501'
- instructor_id (INTEGER): matches employee.employee_id of the instructor
- department (TEXT): e.g. 'CSE'
- credits (INTEGER)
- semester (TEXT): e.g. 'Fall'"""
        model = SentenceTransformer('all-MiniLM-L6-v2')
        embeddings = model.encode([student, employee, course])
        _metadata_cache = (embeddings, model, student, employee, course)
    return _metadata_cache
|
|
def find_best_fit(embeddings, model, user_query, student, employee, course):
    """Return the table description whose embedding is closest to the query.

    Args:
        embeddings: Embeddings of the descriptions, in the order student,
            employee, course (as produced by ``create_metadata_embeddings``).
        model: The encoder that produced *embeddings*; used to encode the query.
        user_query: The user's natural-language question.
        student, employee, course: The description strings themselves.

    Returns:
        Whichever of *student*, *employee* or *course* has the highest cosine
        similarity to the query.
    """
    candidates = (student, employee, course)
    scores = cosine_similarity(model.encode([user_query]), embeddings)
    # A single query gives a single row of scores, so the flat argmax
    # indexes straight into the candidates.
    return candidates[scores.argmax()]
|
|
def create_prompt(user_query, table_metadata):
    """Build the system and user messages for the SQL-generation model.

    Args:
        user_query: The user's natural-language question.
        table_metadata: Description of the table chosen by ``find_best_fit``.

    Returns:
        ``(system_prompt, user_prompt)``. The user prompt has the form
        ``"User Query: <question>\\nTable Metadata: <metadata>"``.
    """
    # The rules restrict the model to the metadata it is actually given (the
    # pipeline supplies a single table), target SQLite because that is what
    # executes the query, and demand a bare single-line SELECT because
    # generate_sql only passes on replies that start with SELECT.
    system_prompt = """You are a SQL query generator. Convert the user's natural language request into one SQLite query, using only the table metadata provided.

Rules:
- Tables and Columns: Use only the table names and column names that appear in the provided metadata. Never invent tables or columns.
- Joins: Only if the provided metadata describes more than one table, join them on the relationships it states (for example, foreign-key columns such as instructor_id). Otherwise, query only the single table described.
- User Intent: Accurately capture the user's requirements such as filters, sorting, aggregations, and selected columns.
- SQL Dialect: The query is executed by SQLite, so use SQLite-compatible syntax and functions.
- Read-Only: Produce exactly one SELECT statement. Never produce INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, or any other statement that changes data or schema.

Input Format:
User Query: The user's natural language request.
Table Metadata: The structure of the relevant table, including the table name, column names, and data types.

Output Format:
Output only the SQL query, on a single line, with no explanation, no label, and no Markdown code fences. Do not output even a single extra word."""
    user_prompt = f"User Query: {user_query}\nTable Metadata: {table_metadata}"
    return system_prompt, user_prompt
|
|
def generate_sql(system_prompt, user_prompt, model="llama3-70b-8192"):
    """Ask the Groq chat model for a SQL query and return it if it is a SELECT.

    Args:
        system_prompt: Instructions for the model (see ``create_prompt``).
        user_prompt: The question plus table metadata (see ``create_prompt``).
        model: Groq model ID. The default is the ID this app was written
            against. NOTE(review): Groq retires older model IDs, so confirm
            this one is still served and override it if not.

    Returns:
        The query text if it starts with ``SELECT`` (case-insensitive), after
        removing any Markdown code fence or leading ``SQL Query:`` label the
        model added; otherwise ``None``. Only SELECT statements are passed on
        because the caller executes the result against the database.
    """
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model=model,
    )
    return _extract_select(chat_completion.choices[0].message.content)


def _extract_select(text):
    """Reduce a model reply to a bare SELECT statement, or return None."""
    if not text:
        # Message content may be empty or missing; there is no query to use.
        return None
    res = text.strip()
    # Despite the prompt, chat models often wrap the answer in ```sql ... ```.
    if res.startswith("```"):
        res = res[3:]
        if res.endswith("```"):
            res = res[:-3]
        res = res.strip()
        if res[:3].lower() == "sql":
            res = res[3:].strip()
    # The prompt's "SQL Query:" wording is sometimes echoed back as a label.
    label = "sql query:"
    if res.lower().startswith(label):
        res = res[len(label):].strip()
    return res if res.lower().startswith("select") else None
|
|
| |
def execute_sql(sql_query, db_path="college.db"):
    """Run *sql_query* against the SQLite database and return the result set.

    Args:
        sql_query: SQL text to execute.
        db_path: SQLite database file; defaults to ``college.db``.

    Returns:
        A pandas DataFrame with the rows on success, or the error message as a
        string if connecting or executing fails (callers get a message rather
        than an exception).
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        return pd.read_sql_query(sql_query, conn)
    except Exception as e:
        return str(e)
    finally:
        # Close on both paths; previously a failing query leaked the connection.
        if conn is not None:
            conn.close()
|
|
| |
|
|
def response(user_query):
    """Answer a natural-language question with the generated SQL and its result.

    Pipeline: pick the best-matching table description, build the prompt, ask
    the model for a query, then run it against ``college.db``.

    Args:
        user_query: The question typed into the Gradio textbox.

    Returns:
        Text with the generated query followed by either the fetched rows or
        an error message.
    """
    embeddings, model, student, employee, course = create_metadata_embeddings()
    table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
    system_prompt, user_prompt = create_prompt(user_query, table_metadata)
    sql_query = generate_sql(system_prompt, user_prompt)

    if sql_query is None:
        # generate_sql returns None when the model's reply is not a SELECT.
        return "SQL Query:\nNone\n\nQuery Result:\nError: the model did not return a SELECT query."

    conn = None
    try:
        # Open read-only: the query is model output driven by user text, so it
        # must never be able to modify the database.
        conn = sqlite3.connect("file:college.db?mode=ro", uri=True)
        result = conn.execute(sql_query).fetchall()
        return f"SQL Query:\n{sql_query}\n\nQuery Result:\n{result}"
    except Exception as e:
        return f"SQL Query:\n{sql_query}\n\nQuery Result:\nError: {str(e)}"
    finally:
        if conn is not None:
            conn.close()
|
|
|
|
| |
desc = """Ask a natural language question about students, employees, or courses. I'll generate and run a SQL query for you."""

# Single-textbox UI: the question goes to `response`, whose text output
# (generated SQL plus result) is shown in the output box.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Your Question"),
    outputs=gr.Textbox(label="SQL + Result"),
    title="Natural Language to SQL + Result",
    description=desc,
)

# Launch only when run as a script, so importing this module (or Gradio's
# reload mode, which finds `demo` itself) does not start a blocking server.
if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to the local server.
    demo.launch(share=True)
|
|