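"""Gradio Space: generate and run data-quality tests for DuckDB tables.

The app downloads the Chinook DuckDB sample database, turns a plain-English
business question into SQL via an LLM pipeline, wraps the result in a SQLMesh
model, generates a Pandera validation schema from a plain-English quality
instruction, and runs the generated tests with pytest, streaming the logs
back to the UI.
"""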
import io
import logging
import os
import shutil
import sys
import tempfile
import uuid
from pathlib import Path
from typing import List, Tuple

import duckdb
import gradio as gr
import pandas as pd
import pytest
import requests
from dotenv import load_dotenv

from src.client import LLMChain
from src.pipelines import Query2Schema

load_dotenv()

LEVEL = "WARNING" if os.getenv("ENV") == "PROD" else "INFO"
logging.basicConfig(
    level=getattr(logging, LEVEL, logging.INFO),
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

if not Path("/tmp").exists():
    os.mkdir("/tmp")
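
# --- Database bootstrap ---
# Fetch the Chinook sample database once (if it is not already on disk) and
# share a single module-level DuckDB connection across all Gradio callbacks.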
def download_file(url: str, save_path: str) -> duckdb.DuckDBPyConnection:
    if Path(save_path).exists():
        logger.info(f"File already exists at {save_path}. Skipping download.")
        return duckdb.connect(database=save_path)
    # Make sure the target directory exists before writing the file.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(save_path, "wb") as out_file:
            shutil.copyfileobj(response.raw, out_file)
        return duckdb.connect(database=save_path)
    except Exception as e:
        logger.error(f"Error downloading Chinook DB: {e}")
        raise


conn = download_file(
    url="https://raw.githubusercontent.com/RandomFractals/duckdb-sql-tools/main/data/chinook/duckdb/chinook.duckdb",
    save_path="database/chinook.duckdb",
)
pipe = Query2Schema(duckdb=conn, chain=LLMChain())
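
# --- Schema/table helpers backing the dropdowns ---
# These queries all run against the single `conn` opened above; the selected
# schema name is not used to filter the table list, which always comes from
# the connected Chinook database.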
def get_test_databases() -> List[str]:
    """Return the database names offered in the UI (currently a fixed list)."""
    return ["All", "chinook", "Northwind"]


def get_tables_names(schema_name):
    tables = conn.execute(
        "SELECT table_name FROM information_schema.tables"
    ).fetchall()
    return [table[0] for table in tables]


def update_table_names(schema_name):
    tables = get_tables_names(schema_name)
    return gr.update(choices=tables, value=tables[0] if tables else None)


def update_column_names(table_name):
    columns = conn.execute(
        "SELECT column_name FROM information_schema.columns WHERE table_name = ?",
        [table_name],
    ).fetchall()
    columns = [column[0] for column in columns]
    df = pd.DataFrame(columns, columns=["Column Names"])
    # return gr.update(
    #     choices=columns,
    #     value=columns[0] if columns else None
    # )
    return df


def get_ddl(table: str) -> str:
    result = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?;",
        [table],
    ).df()
    ddl_create = result.iloc[0, 0]
    parent_database = result.iloc[0, 1]
    schema_name = result.iloc[0, 2]
    full_path = f"{parent_database}.{schema_name}.{table}"
    if schema_name != "main":
        old_path = f"{schema_name}.{table}"
    else:
        old_path = table
    # Rewrite the table reference in the CREATE statement to its fully qualified name.
    ddl_create = ddl_create.replace(old_path, full_path)
    return ddl_create
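
# --- Natural language -> SQL ---
# The table's CREATE statement (from get_ddl) is passed to the LLM chain as
# context so it can translate the user's business question into a SQL query
# over that table.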
def run_pipeline(table: str, query_input: str) -> Tuple[str, pd.DataFrame]:
    try:
        schema = get_ddl(table=table)
    except Exception as e:
        logger.error(f"Failed to fetch DDL for table {table}: {e}")
        raise
    try:
        sql, df = pipe.try_sql_with_retries(
            user_question=query_input,
            context=schema,
        )
        sql = sql.get("sql_query") if isinstance(sql, dict) else sql
        if not sql:
            raise ValueError("SQL generation returned None")
        return sql, df
    except Exception as e:
        logger.error(f"Error generating SQL for table {table}: {e}")
        raise
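
# --- SQLMesh model generation ---
# Each run writes a throwaway FULL model under models/, qualified with the
# selected database name. Roughly, a generated file looks like (illustrative):
#
#   MODEL (
#       name chinook.model_1a2b3c4d,
#       kind FULL
#   );
#
#   SELECT ...   -- the generated query, with "chinook.main." prefixes stripped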
def create_mesh_model(sql: str, db_name: str = "chinook") -> Tuple[str, str, str]:
    model_name = f"model_{uuid.uuid4().hex[:8]}"
    # Qualify the model name with the database (catalog) name.
    full_model_name = f"{db_name}.{model_name}"
    MODEL_HEADER = f"""MODEL (
    name {full_model_name},
    kind FULL
);
"""
    try:
        model_dir = Path("models/")
        model_dir.mkdir(parents=True, exist_ok=True)
        model_path = model_dir / f"{model_name}.sql"
        model_text = MODEL_HEADER + "\n" + sql.replace("chinook.main.", "")
        model_path.write_text(model_text)
        return model_text, str(model_path), full_model_name
    except Exception as e:
        logger.error(f"Error creating SQL Mesh model: {e}")
        raise
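
# --- Pandera schema + pytest module generation ---
# The generated test module is written in four parts, in order: the import
# header, the LLM-generated Pandera DataFrameModel, the SQLMesh fixtures plus
# a backfill test, and a test that evaluates the model and validates the
# resulting DataFrame against the generated schema class.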
def create_pandera_schema(
    sql: str, user_instruction: str, model_name: str
) -> Tuple[str, str]:
    SCRIPT_HEADER = """
import pandas as pd
import pandera.pandas as pa
from pandera.typing import *
import pytest
from sqlmesh import Context
from datetime import date
from pathlib import Path
import shutil
import duckdb
"""
    MESH_STR = f"""
@pytest.fixture(scope="session")
def mesh_context():
    context = Context(paths=".", gateway="duckdb", load=True)
    yield context


@pytest.fixture
def today_str():
    return date.today().isoformat()


def test_back_fill(mesh_context, today_str):
    mesh_context.plan(skip_backfill=False, auto_apply=True)
    mesh_context.run(start=today_str, end=today_str)
    # df = mesh_context.fetchdf("SELECT * FROM {model_name} LIMIT 10")
    # assert not df.empty
"""
    try:
        schema = pipe.generate_pandera_schema(
            sql_query=sql, user_instruction=user_instruction
        )
        # schema.split()[1].split("(")[0] pulls the class name out of the
        # generated "class <Name>(...):" definition.
        test_schema = f"""
def test_schema(mesh_context, today_str):
    df = mesh_context.evaluate(
        "{model_name}",
        start=today_str,
        end=today_str,
        execution_time=today_str,
    )
    {schema.split()[1].split("(")[0].strip()}.validate(df)
"""
        print(schema)
        with tempfile.NamedTemporaryFile(
            mode="w",
            prefix="test_",
            suffix=".py",
            delete=False,
            encoding="utf-8",
        ) as f:
            f.write(SCRIPT_HEADER)
            f.write("\n\n")
            f.write(schema)
            f.write("\n\n")
            f.write(MESH_STR)
            f.write("\n\n")
            f.write(test_schema)
        file_path = Path(f.name)
        return schema, str(file_path)
    except Exception as e:
        logger.error(f"Error creating Pandera schema: {e}")
        raise
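
# --- Orchestration ---
# create_test_file chains the steps above: generate the SQL, write the SQLMesh
# model, then write the Pandera-based pytest module.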
def create_test_file(
    table_name: str, db_name: str, sql_instruction: str, user_instruction: str
) -> Tuple[str, str, pd.DataFrame, str, str]:
    try:
        sql, df = run_pipeline(table=table_name, query_input=sql_instruction)
        model_text, model_file, model_name = create_mesh_model(sql=sql, db_name=db_name)
        schema, test_file = create_pandera_schema(
            sql=sql,
            user_instruction=user_instruction,
            model_name=model_name,
        )
        return test_file, model_file, df, model_text, schema
    except Exception as e:
        logger.error(f"Error creating test file for table {table_name}: {e}")
        raise
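
# --- Test execution ---
# pytest runs in-process via pytest.main(); stdout/stderr are temporarily
# redirected to StringIO buffers so the run's output can be shown in the UI,
# and the generated test and model files are removed afterwards.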
def run_tests(
    table_name: str, db_name: str, sql_instruction: str, user_instruction: str
):
    test_file, model_file, df, model_text, schema = create_test_file(
        table_name=table_name,
        db_name=db_name,
        sql_instruction=sql_instruction,
        user_instruction=user_instruction,
    )
    capture_out = io.StringIO()
    capture_err = io.StringIO()
    old_out = sys.stdout
    old_err = sys.stderr
    sys.stdout = capture_out
    sys.stderr = capture_err
    try:
        retcode = pytest.main(
            [
                test_file,
                "-s",
                "--tb=short",
                "--disable-warnings",
                "-o",
                "cache_dir=/tmp",
            ]
        )
    except Exception as e:
        sys.stdout = old_out
        sys.stderr = old_err
        # Return one value per Gradio output component.
        return f"Error running tests: {e}", df, model_text, schema
    sys.stdout = old_out
    sys.stderr = old_err
    output = capture_out.getvalue() + "\n" + capture_err.getvalue()
    for f in [test_file, model_file]:
        try:
            os.remove(f)
        except FileNotFoundError:
            pass
    return output, df, model_text, schema
| custom_css = """ | |
| /* --- Overall container --- */ | |
| .gradio-container { | |
| background-color: #f0f4f8; /* light background */ | |
| font-family: 'Arial', sans-serif; | |
| } | |
| /* --- Logo --- */ | |
| .logo { | |
| max-width: 200px; | |
| margin: 20px auto; | |
| display: block; | |
| } | |
| /* --- Buttons --- */ | |
| .gr-button { | |
| background-color: #4a90e2 !important; /* primary color */ | |
| font-size: 14px; /* fixed font size */ | |
| padding: 6px 12px !important; /* fixed padding */ | |
| height: 36px !important; /* fixed height */ | |
| min-width: 120px !important; /* fixed width */ | |
| } | |
| .gr-button:hover { | |
| background-color: #3a7bc8 !important; | |
| } | |
| /* --- Logs Textbox --- */ | |
| #logs textarea { | |
| overflow-y: scroll; | |
| resize: none; | |
| height: 400px; | |
| width: 100%; | |
| font-family: monospace; | |
| font-size: 13px; | |
| line-height: 1.4; | |
| } | |
| /* Optional: small spacing between rows */ | |
| .gr-row { | |
| gap: 10px; | |
| } | |
| """ | |
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown(
        """
        <div style='text-align: center;'>
            <strong style='font-size: 36px;'>SQL Test Suite</strong>
            <br>
            <span style='font-size: 20px;'>Automated testing and schema validation for SQL models with an LLM.</span>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(
                choices=["chinook", "northwind"],
                value="chinook",
                label="Select Schema",
                interactive=True,
            )
            tables_dropdown = gr.Dropdown(
                choices=[], label="Available Tables", value=None, interactive=True
            )
            # columns_dropdown = gr.Dropdown(choices=[], label="Available Columns", value=None, interactive=True)
            columns_df = gr.DataFrame(label="Columns", value=[], interactive=False)
            # with gr.Row():
            #     generate_result = gr.Button("Run Tests", variant="primary")
        with gr.Column(scale=3):
            with gr.Row():
                sql_instruction = gr.Textbox(
                    lines=3,
                    label="Business Metric Query (Plain English)",
                    placeholder=(
                        "Describe the business question you want to answer.\n"
                        "Example: 'Show me the average sales per month.'\n"
                        "Example: 'Total revenue by product category for last year.'"
                    ),
                )
            with gr.Row():
                user_instruction = gr.Textbox(
                    lines=5,
                    label="Define Data Quality Level",
                    placeholder=(
                        "Describe the validation rule and how strict it should be.\n"
                        "Example: Validate that the incident_zip column contains valid 5-digit ZIP codes.\n"
                    ),
                )
            with gr.Row():
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    run_tests_btn = gr.Button("▶ Run Tests", variant="primary")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Test Logs"):
                    with gr.Row():
                        with gr.Column():
                            test_logs = gr.Textbox(
                                label="Test Logs",
                                lines=20,
                                max_lines=20,
                                interactive=False,
                                elem_id="logs",
                            )
                with gr.Tab("SQL Model"):
                    with gr.Row():
                        with gr.Column():
                            sql_model = gr.Textbox(
                                label="SQL Model",
                                lines=20,
                                max_lines=20,
                                interactive=False,
                                elem_id="sql_model",
                            )
                with gr.Tab("Schema"):
                    with gr.Row():
                        with gr.Column():
                            result_schema = gr.Textbox(
                                label="Validation Schema",
                                lines=20,
                                max_lines=20,
                                interactive=False,
                            )
                with gr.Tab("Data"):
                    with gr.Row():
                        with gr.Column():
                            result_data = gr.DataFrame(
                                label="Query Result",
                                value=[],
                                interactive=False,
                            )
    schema_dropdown.change(
        update_table_names, inputs=schema_dropdown, outputs=tables_dropdown
    )
    tables_dropdown.change(
        update_column_names, inputs=tables_dropdown, outputs=columns_df
    )
    demo.load(
        fn=update_table_names, inputs=schema_dropdown, outputs=tables_dropdown
    )
    run_tests_btn.click(
        run_tests,
        inputs=[
            tables_dropdown,
            schema_dropdown,
            sql_instruction,
            user_instruction,
        ],
        outputs=[test_logs, result_data, sql_model, result_schema],
    )

if __name__ == "__main__":
    demo.launch(debug=True)