import io import logging import os import shutil import sys import tempfile import uuid from pathlib import Path from typing import List, Tuple import duckdb import gradio as gr import pandas as pd import pytest import requests from dotenv import load_dotenv from src.client import LLMChain from src.pipelines import Query2Schema load_dotenv() LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING" logging.basicConfig( level=getattr(logging, LEVEL, logging.INFO), format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) logger = logging.getLogger(__name__) if not Path("/tmp").exists(): os.mkdir("/tmp") def download_file(url: str, save_path: str): if Path(save_path).exists(): print(f"File already exists at {save_path}. Skipping download.") return duckdb.connect(database=save_path) try: response = requests.get(url, stream=True) response.raise_for_status() with open(save_path, "wb") as out_file: shutil.copyfileobj(response.raw, out_file) return duckdb.connect(database=save_path) except Exception as e: logger.info(f"Error Downloding Chinook DB: {e}") raise conn = download_file( url="https://raw.githubusercontent.com/RandomFractals/duckdb-sql-tools/main/data/chinook/duckdb/chinook.duckdb", save_path="database/chinook.duckdb", ) pipe = Query2Schema(duckdb=conn, chain=LLMChain()) def get_test_databases() -> List[str]: """Scans the 'tests' directory for subdirectories (representing databases).""" return ["All", "chinook", "Northwind"] def get_tables_names(schema_name): tables = conn.execute("SELECT table_name FROM information_schema.tables").fetchall() return [table[0] for table in tables] def update_table_names(schema_name): tables = get_tables_names(schema_name) return gr.update(choices=tables, value=tables[0] if tables else None) def update_column_names(table_name): columns = conn.execute( f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}' " ).fetchall() columns = [column[0] for column in columns] df = pd.DataFrame(columns, columns=["Column Names"]) # return gr.update( # choices=columns, # value=columns[0] if columns else None # ) return df def get_ddl(table: str) -> str: result = conn.sql( f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';" ).df() ddl_create = result.iloc[0, 0] parent_database = result.iloc[0, 1] schema_name = result.iloc[0, 2] full_path = f"{parent_database}.{schema_name}.{table}" if schema_name != "main": old_path = f"{schema_name}.{table}" else: old_path = table ddl_create = ddl_create.replace(old_path, full_path) return ddl_create def run_pipeline(table: str, query_input: str) -> Tuple[str, pd.DataFrame]: try: schema = get_ddl(table=table) except Exception as e: logger.error(f"Failed to fetch DDL for table {table}: {e}") raise try: sql, df = pipe.try_sql_with_retries( user_question=query_input, context=schema, ) sql = sql.get("sql_query") if isinstance(sql, dict) else sql if not sql: raise ValueError("SQL generation returned None") return sql, df except Exception as e: logger.error(f"Error generating SQL for table {table}: {e}") raise def create_mesh_model(sql: str, db_name: str = "chinook") -> Tuple[str, str, str]: model_name = f"model_{uuid.uuid4().hex[:8]}" # Use catalog.schema.model_name format full_model_name = f"{db_name}.{model_name}" MODEL_HEADER = f"""MODEL ( name {full_model_name}, kind FULL ); """ try: model_dir = Path("models/") model_dir.mkdir(parents=True, exist_ok=True) model_path = model_dir / f"{model_name}.sql" model_text = MODEL_HEADER + "\n" + sql.replace("chinook.main.", "") model_path.write_text(model_text) return model_text, str(model_path), full_model_name except Exception as e: logger.error(f"Error creating SQL Mesh model: {e}") raise def create_pandera_schema( sql: str, user_instruction: str, model_name: str ) -> Tuple[str, str]: SCRIPT_HEADER = """ import pandas as pd import pandera.pandas as pa from pandera.typing import * import pytest from sqlmesh import Context from datetime import date from pathlib import Path import shutil import duckdb """ MESH_STR = f""" @pytest.fixture(scope="session") def mesh_context(): context = Context(paths=".", gateway="duckdb", load=True) yield context @pytest.fixture def today_str(): return date.today().isoformat() def test_back_fill(mesh_context, today_str): mesh_context.plan(skip_backfill=False, auto_apply=True) mesh_context.run(start=today_str, end=today_str) # df = mesh_context.fetchdf("SELECT * FROM {model_name} LIMIT 10") # assert not df.empty """ try: schema = pipe.generate_pandera_schema( sql_query=sql, user_instruction=user_instruction ) test_schema = f""" def test_schema(mesh_context, today_str): df = mesh_context.evaluate( "{model_name}", start=today_str, end=today_str, execution_time=today_str, ) {schema.split()[1].split("(")[0].strip()}.validate(df) """ print(schema) with tempfile.NamedTemporaryFile( mode="w", prefix="test_", suffix=".py", delete=False, encoding="utf-8", ) as f: f.write(SCRIPT_HEADER) f.write("\n\n") f.write(schema) f.write("\n\n") f.write(MESH_STR) f.write("\n\n") f.write(test_schema) file_path = Path(f.name) return schema, str(file_path) except Exception as e: logger.error(f"Error creating Pandera schema: {e}") raise def create_test_file( table_name: str, db_name: str, sql_instruction: str, user_instruction: str ) -> Tuple[str, str, pd.DataFrame, str, str]: try: sql, df = run_pipeline(table=table_name, query_input=sql_instruction) model_text, model_file, model_name = create_mesh_model(sql=sql, db_name=db_name) schema, test_file = create_pandera_schema( sql=sql, user_instruction=user_instruction, model_name=model_name, ) return test_file, model_file, df, model_text, schema except Exception as e: logger.error(f"Error creating test file for table {table_name}: {e}") raise def run_tests( table_name: str, db_name: str, sql_instruction: str, user_instruction: str ): test_file, model_file, df, model_text, schema = create_test_file( table_name=table_name, db_name=db_name, sql_instruction=sql_instruction, user_instruction=user_instruction, ) capture_out = io.StringIO() capture_err = io.StringIO() old_out = sys.stdout old_err = sys.stderr sys.stdout = capture_out sys.stderr = capture_err try: retcode = pytest.main( [ test_file, "-s", "--tb=short", "--disable-warnings", "-o", "cache_dir=/tmp", ] ) except Exception as e: sys.stdout = old_out sys.stderr = old_err return f"Error running tests: {str(e)}", "" sys.stdout = old_out sys.stderr = old_err output = capture_out.getvalue() + "\n" + capture_err.getvalue() for f in [test_file, model_file]: try: os.remove(f) except FileNotFoundError: pass return output, df, model_text, schema custom_css = """ /* --- Overall container --- */ .gradio-container { background-color: #f0f4f8; /* light background */ font-family: 'Arial', sans-serif; } /* --- Logo --- */ .logo { max-width: 200px; margin: 20px auto; display: block; } /* --- Buttons --- */ .gr-button { background-color: #4a90e2 !important; /* primary color */ font-size: 14px; /* fixed font size */ padding: 6px 12px !important; /* fixed padding */ height: 36px !important; /* fixed height */ min-width: 120px !important; /* fixed width */ } .gr-button:hover { background-color: #3a7bc8 !important; } /* --- Logs Textbox --- */ #logs textarea { overflow-y: scroll; resize: none; height: 400px; width: 100%; font-family: monospace; font-size: 13px; line-height: 1.4; } /* Optional: small spacing between rows */ .gr-row { gap: 10px; } """ with gr.Blocks( theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css ) as demo: gr.Image("logo.png", label=None, show_label=False, container=False, height=100) gr.Markdown( """