| """Evaluate text2sql spider model predictions.""" | |
import json
import os
import re
import signal
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from pathlib import Path
from typing import Any

import click
import pandas as pd
from rich.console import Console
from tqdm.auto import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), "."))

# from metrics.spider import evaluation as spider_evaluation  # type: ignore # noqa: E402
from metrics.test_suite_sql_eval import (  # type: ignore # noqa: E402
    evaluation as test_suite_evaluation,
)
from data_utils import read_tables_json  # type: ignore # noqa: E402
from metric_utils import (  # type: ignore # noqa: E402
    correct_casing,
    edit_distance,
)
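# NOTE (assumption): `metrics.test_suite_sql_eval` is taken here to be a vendored
# copy of the Spider test-suite execution-evaluation code, and `data_utils` /
# `metric_utils` to be local helper modules next to this script; the sys.path
# tweak above is what lets those flat imports resolve.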
console = Console(soft_wrap=True)

LEVELS = ["easy", "medium", "hard", "duckdb", "ddl", "all"]
PARTIAL_TYPES = [
    "select",
    "select(no AGG)",
    "where",
    "where(no OP)",
    "group(no Having)",
    "group",
    "order",
    "and/or",
    "IUEN",
    "keywords",
]
TIMEOUT_SECONDS = 30


def timeout_handler(signum: int, frame: Any) -> None:
    raise TimeoutError("Function execution timed out.")
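# Note: `timeout_handler` is a signal-based timeout hook that is never registered
# in this script; per-row timeouts are enforced with a ThreadPoolExecutor below.
# If SIGALRM-based timeouts were wanted instead, the wiring would look roughly
# like this (a sketch, not used here):
#
#     signal.signal(signal.SIGALRM, timeout_handler)
#     signal.alarm(TIMEOUT_SECONDS)  # fires the handler after the deadline
#     ...
#     signal.alarm(0)                # cancel the alarm once the call returns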


def print_scores(scores: dict, model_name: str, metric_type: str = "exec") -> None:
    """Print scores."""

    def print_formatted_s(
        row_name: str, values: list, element_format: str = "{}", sep: str = "\t"
    ) -> None:
        template = "{}" + sep + sep.join([element_format] * len(values))
        console.print(template.format(row_name, *values))

    # Add empty scores for each level if not present
    for level in LEVELS:
        if level not in scores:
            scores[level] = {}
            scores[level]["count"] = 0
            scores[level]["exec"] = 0
            scores[level]["exact"] = 0

    print_formatted_s("", LEVELS)
    counts = [scores[level]["count"] for level in LEVELS]
    print_formatted_s("count", counts)
    console.print(f">====================== {model_name} =====================")
    if metric_type == "exec":
        console.print(
            ">===================== EXECUTION ACCURACY ====================="
        )
        exec_scores = [scores[level]["exec"] for level in LEVELS]
        print_formatted_s("execution", exec_scores, element_format="{:.3f}")
    elif metric_type == "exact":
        console.print(
            "\n>====================== EXACT MATCHING ACCURACY ====================="
        )
        exact_scores = [scores[level]["exact"] for level in LEVELS]
        print_formatted_s("exact match", exact_scores, element_format="{:.3f}")


def compute_exact_match_metric(
    predictions: list,
    references: list,
    gold_dbs: list,
    kmaps: dict,
    db_dir: str,
    categories: list,
) -> dict:
    """Compute exact match metric."""
    exact_match = {}
    exact_match["all"] = {}
    exact_match["all"]["count"] = 0
    exact_match["all"]["exact"] = 0
    for prediction, reference, gold_db, category in tqdm(
        zip(predictions, references, gold_dbs, categories), total=len(predictions)
    ):
        if category not in exact_match:
            exact_match[category] = {}
            exact_match[category]["count"] = 0
            exact_match[category]["exact"] = 0
        exact_match["all"]["count"] += 1
        exact_match[category]["count"] += 1
        try:
            match = int(prediction.strip() == reference.strip())
            exact_match[category]["exact"] += match
            exact_match["all"]["exact"] += match
        except Exception:
            pass
    return exact_match
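# For reference, the returned structure mirrors what print_scores consumes:
# per-category buckets plus an "all" bucket of raw counts, e.g.
#
#     {"all": {"count": 120, "exact": 87}, "duckdb": {"count": 40, "exact": 25}, ...}
#
# (the numbers above are made up for illustration).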


def evaluate_with_timeout(
    evaluator, gold_db, reference, prediction, setup_sql, validate_sql,
    turn_scores, idx, category, timeout
):
    """Run evaluator.evaluate_one in a worker thread; return None on timeout."""
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(
            evaluator.evaluate_one, gold_db, reference, prediction,
            setup_sql, validate_sql, turn_scores, idx=idx, category=category,
        )
        try:
            result = future.result(timeout=timeout)
        except TimeoutError:
            result = None
    return result
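# Caveat: concurrent.futures cannot kill a running worker thread. future.result()
# raises TimeoutError after `timeout` seconds and the row is recorded as None, but
# the `with` block still waits for the underlying evaluate_one call to finish when
# the executor shuts down, so a pathological query can still stall the loop.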


def compute_test_suite_metric(
    predictions: list,
    references: list,
    gold_dbs: list,
    setup_sqls: list,
    validate_sqls: list,
    kmaps: dict,
    db_dir: str,
    categories: list[str] | None = None,
) -> tuple[Any, list[int | None]]:
    """Compute test suite execution metric."""
    evaluator = test_suite_evaluation.Evaluator(
        db_dir=db_dir,
        kmaps=kmaps,
        etype="exec",
        plug_value=False,
        keep_distinct=False,
        progress_bar_for_each_datapoint=False,
    )
    # Only used for Sparc/CoSQL
    turn_scores: dict[str, list] = {"exec": [], "exact": []}
    # 1 = execution match, 0 = mismatch, None = evaluation timed out
    by_row_metrics: list[int | None] = []
    for prediction, reference, gold_db, setup_sql, validate_sql, category in tqdm(
        zip(predictions, references, gold_dbs, setup_sqls, validate_sqls, categories),
        total=len(predictions),
    ):
        # Turn index: this is single-turn evaluation (turns only matter for
        # Sparc/CoSQL), so it is always 0.
        turn_idx = 0
        # Evaluate with a per-row timeout
        ex_metrics = evaluate_with_timeout(
            evaluator, gold_db, reference, prediction, setup_sql, validate_sql,
            turn_scores, idx=turn_idx, category=category, timeout=TIMEOUT_SECONDS,
        )
        if ex_metrics:
            by_row_metrics.append(int(ex_metrics["exec"]))
        else:
            by_row_metrics.append(None)
    evaluator.finalize()
    return evaluator.scores, by_row_metrics


def compute_metrics(
    gold_sqls: list[str],
    pred_sqls: list[str],
    gold_dbs: list[str],
    setup_sqls: list[str],
    validate_sqls: list[str],
    kmaps: dict,
    db_schemas: dict,
    database_dir: str,
    lowercase_schema_match: bool,
    model_name: str,
    categories: list[str] | None = None,
) -> dict[str, Any]:
    """Compute all metrics for data slice."""
    if len(gold_sqls) != len(pred_sqls):
        raise ValueError(
            f"Gold {len(gold_sqls)} and pred {len(pred_sqls)} have different number of lines!"
        )
    all_metrics: dict[str, Any] = {}
    # Execution Accuracy
    metrics, by_row_metrics = compute_test_suite_metric(
        pred_sqls,
        gold_sqls,
        gold_dbs,
        setup_sqls,
        validate_sqls,
        kmaps,
        database_dir,
        categories,
    )
    all_metrics["exec"] = metrics
    all_metrics["by_row_exec"] = by_row_metrics
    print_scores(metrics, model_name, "exec")
    # Exact Match Accuracy
    metrics = compute_exact_match_metric(
        pred_sqls, gold_sqls, gold_dbs, kmaps, database_dir, categories
    )
    all_metrics["exact"] = metrics
    print_scores(metrics, model_name, "exact")
    # Equality Accuracy
    per_row_match = [
        int(gold.lower() == pred.lower()) for gold, pred in zip(gold_sqls, pred_sqls)
    ]
    all_metrics["equality"] = {"equality": sum(per_row_match) / len(gold_sqls)}
    all_metrics["by_row_equality"] = per_row_match
    # Edit Distance
    per_row_edit_dist = [
        edit_distance(gold, pred) for gold, pred in zip(gold_sqls, pred_sqls)
    ]
    edit_dist = sum(per_row_edit_dist) / len(gold_sqls)
    all_metrics["edit_distance"] = {"edit_distance": edit_dist}
    all_metrics["by_row_edit_distance"] = per_row_edit_dist
    return all_metrics
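# The dict returned above bundles aggregate and per-row results under fixed keys:
# "exec" and "exact" carry the level-keyed score dicts, "equality" and
# "edit_distance" carry single aggregate numbers, and each "by_row_*" entry is a
# list aligned with the input order (used later to annotate full_results).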


def get_to_print(metrics: dict, key: str, model_name: str, num_rows: int) -> dict:
    """Get pretty print dictionary of metrics."""
    return {
        "slice": key,
        "model": model_name,
        "support": num_rows,
        "exec": f"{metrics[key]['exec']['all']['exec']:.3f}",
        "exact": f"{metrics[key]['exact']['all']['exact']:.3f}",
        "equality": f"{metrics[key]['equality']['equality']:.3f}",
        "edit_distance": f"{metrics[key]['edit_distance']['edit_distance']:.3f}",
    }
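# Each row fed to the summary DataFrame therefore looks roughly like this
# (all numbers are placeholders):
#
#     {"slice": "all", "model": "my_model", "support": 1034,
#      "exec": "0.712", "exact": "0.645", "equality": "0.301",
#      "edit_distance": "12.400"}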


def cli() -> None:
    """Entrypoint."""
    pass
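# NOTE (assumption): in the original project `cli` is presumably a click group and
# `evaluate` one of its commands, with @click.option flags mirroring the evaluate()
# parameters below; those decorators are missing here, so calling cli() on its own
# does nothing. A minimal sketch of the presumed wiring (flag names are guesses
# derived from the parameter names):
#
#     @click.group()
#     def cli() -> None: ...
#
#     @cli.command()
#     @click.option("--gold", type=str, required=True)
#     @click.option("--pred", type=str, required=True)
#     ...
#     def evaluate(...) -> None: ...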


def evaluate(
    gold: str,
    pred: str,
    tables: str,
    db: str,
    slice_attribute: str,
    output_dir: str,
    output_filename: str,
    correct_sql_casing: bool,
    lowercase_schema_match: bool,
) -> None:
    """Evaluate SQL.

    Args:
        gold: path to gold sql file.
        pred: path to predicted json lines file.
        tables: the json path of the table metadata.
        db: path to database dir.
        slice_attribute: json attribute in gold data to slice on.
        output_dir: the prediction output directory.
        output_filename: the prediction output filename.
        correct_sql_casing: whether to correct casing of SQL keywords.
        lowercase_schema_match: whether to lowercase schema match.
    """
    gold_path = Path(gold)
    pred_path = Path(pred)
    model_name = pred_path.stem
    if not output_filename:
        output_filename = pred_path.stem + "_eval.json"
    console.print(f"Saving to {Path(output_dir) / output_filename}")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    kmaps = test_suite_evaluation.build_foreign_key_map_from_json(tables)
    db_schemas = read_tables_json(tables)
    gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
    pred_sqls_dict = [json.loads(line) for line in pred_path.open("r").readlines()]

    # Data validation
    assert len(gold_sqls_dict) == len(
        pred_sqls_dict
    ), "Sample size doesn't match between pred and gold file"

    # Keep track of everything
    full_results = []
    for gold_sql, pred_sql in zip(gold_sqls_dict, pred_sqls_dict):
        merged_res = {**pred_sql, **gold_sql}
        full_results.append(merged_res)

    gold_sqls = [
        re.sub(r"[\s\t\n]+", " ", p.get("gold", p.get("query", p.get("sql", ""))))
        for p in gold_sqls_dict
    ]
    setup_sqls = [re.sub(r"[\s\t\n]+", " ", p["setup_sql"]) for p in gold_sqls_dict]
    validate_sqls = [
        re.sub(r"[\s\t\n]+", " ", p["validation_sql"]) for p in gold_sqls_dict
    ]
    gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
    pred_sqls = [re.sub(r"[\s\t\n]+", " ", p["pred"]) for p in pred_sqls_dict]
    categories = [p.get("category", "") for p in gold_sqls_dict]
    if correct_sql_casing:
        # Normalize the casing of SQL keywords
        gold_sqls = [correct_casing(sql) for sql in gold_sqls]
        pred_sqls = [correct_casing(sql) for sql in pred_sqls]
    final_metrics: dict[str, dict[str, Any]] = {}
    to_print = []
    final_metrics["all"] = compute_metrics(
        gold_sqls=gold_sqls,
        pred_sqls=pred_sqls,
        gold_dbs=gold_dbs,
        setup_sqls=setup_sqls,
        validate_sqls=validate_sqls,
        kmaps=kmaps,
        db_schemas=db_schemas,
        database_dir=db,
        lowercase_schema_match=lowercase_schema_match,
        model_name=model_name + "(all)",
        categories=categories,
    )
    for k, v in final_metrics["all"].items():
        if k.startswith("by_row"):
            assert len(v) == len(gold_sqls)
            for dct, val in zip(full_results, v):
                dct[k[len("by_row_") :]] = val
    to_print.append(get_to_print(final_metrics, "all", model_name, len(gold_sqls)))

    # TODO: could be way more efficient if we subsliced the results but...whatever
    if slice_attribute:
        for unq_value in sorted(set([g[slice_attribute] for g in gold_sqls_dict])):
            idx_set = [
                i
                for i, g in enumerate(gold_sqls_dict)
                if g[slice_attribute] == unq_value
            ]
            print(f"Processing {unq_value} with {len(idx_set)} samples")
            final_metrics[unq_value] = compute_metrics(
                gold_sqls=[gold_sqls[i] for i in idx_set],
                pred_sqls=[pred_sqls[i] for i in idx_set],
                gold_dbs=[gold_dbs[i] for i in idx_set],
                setup_sqls=[setup_sqls[i] for i in idx_set],
                validate_sqls=[validate_sqls[i] for i in idx_set],
                kmaps=kmaps,
                db_schemas=db_schemas,
                database_dir=db,
                lowercase_schema_match=lowercase_schema_match,
                model_name=model_name + f"({unq_value})",
                categories=[categories[i] for i in idx_set],
            )
            to_print.append(
                get_to_print(final_metrics, unq_value, model_name, len(idx_set))
            )

    df = pd.DataFrame(to_print)
    console.print(df.to_csv(sep=",", index=False))
    console.print("******")
    console.print(f"Saved metrics to {Path(output_dir) / output_filename}")
    json.dump(final_metrics, open(Path(output_dir) / output_filename, "w"), indent=4)

    output_filename = str(output_filename).replace("_eval.json", "_fd.jsonl")
    console.print(f"Saved dump to {Path(output_dir) / output_filename}")
    with open(Path(output_dir) / output_filename, "w") as f:
        for dct in full_results:
            f.write(json.dumps(dct) + "\n")


if __name__ == "__main__":
    cli()