################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
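################################
# Illustrative example (values assume the usual Spider-style idMap, where
# tables are keyed "__table__" and columns "__table.column__"): the query
#   SELECT count(*) FROM head WHERE age > 56
# parses roughly as:
# {
#   'select': (False, [(3, (0, (0, '__all__', False), None))]),  # agg_id 3 == "count"
#   'from': {'table_units': [('table_unit', '__head__')], 'conds': []},
#   'where': [(False, 3, (0, (0, '__head.age__', False), None), 56.0, None)],  # op_id 3 == ">"
#   'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#   'intersect': None, 'except': None, 'union': None,
# }
################################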
import os
import json
import sqlite3
import argparse

try:
    from .process_sql import get_schema, Schema, get_sql
    from .exec_eval import eval_exec_match
except ImportError:
    # fall back to absolute imports so this file can also be run as a script
    from process_sql import get_schema, Schema, get_sql
    from exec_eval import eval_exec_match

# "joint_all" tracks whole-session scores for multi-turn data
LEVELS = ["easy", "medium", "hard", "duckdb", "ddl", "all", "joint_all"]
TURNS = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"]
PARTIAL_TYPES = [
    "select",
    "select(no AGG)",
    "where",
    "where(no OP)",
    "group(no Having)",
    "group",
    "order",
    "and/or",
    "IUEN",
    "keywords",
]

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = (
    "select",
    "from",
    "where",
    "group",
    "order",
    "limit",
    "intersect",
    "union",
    "except",
)
JOIN_KEYWORDS = ("join", "on", "as")
WHERE_OPS = (
    "not",
    "between",
    "=",
    ">",
    "<",
    ">=",
    "<=",
    "!=",
    "in",
    "like",
    "is",
    "exists",
)
UNIT_OPS = ("none", "-", "+", "*", "/")
AGG_OPS = ("none", "max", "min", "count", "sum", "avg")
TABLE_TYPE = {
    "sql": "sql",
    "table_unit": "table_unit",
}
COND_OPS = ("and", "or")
SQL_OPS = ("intersect", "union", "except")
ORDER_OPS = ("desc", "asc")
HARDNESS = {
    "component1": ("where", "group", "order", "limit", "join", "or", "like"),
    "component2": ("except", "union", "intersect"),
}


def condition_has_or(conds):
    return "or" in conds[1::2]


def condition_has_like(conds):
    return WHERE_OPS.index("like") in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    return val_unit[0] != UNIT_OPS.index("none")


def has_agg(unit):
    return unit[0] != AGG_OPS.index("none")


def accuracy(count, total):
    if count == total:
        return 1
    return 0


def recall(count, total):
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    if (acc + rec) == 0:
        return 0
    return (2.0 * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0
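

# get_scores is all-or-nothing: full credit only when pred and gold have the
# same number of components and every predicted one matched, e.g.
#   get_scores(2, 2, 2) -> (1, 1, 1)
#   get_scores(1, 2, 2) -> (0, 0, 0)   # same sizes, but one mismatch
#   get_scores(1, 1, 2) -> (0, 0, 0)   # different sizes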


def eval_sel(pred, label):
    pred_sel = pred["select"][1]
    label_sel = label["select"][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0
    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])
    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    pred_conds = [unit for unit in pred["where"][::2]]
    label_conds = [unit for unit in label["where"][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0
    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])
    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    pred_cols = [unit[1] for unit in pred["groupBy"]]
    label_cols = [unit[1] for unit in label["groupBy"]]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    pred_cols = [col.split(".")[1] if "." in col else col for col in pred_cols]
    label_cols = [col.split(".")[1] if "." in col else col for col in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred["groupBy"]) > 0:
        pred_total = 1
    if len(label["groupBy"]) > 0:
        label_total = 1
    pred_cols = [unit[1] for unit in pred["groupBy"]]
    label_cols = [unit[1] for unit in label["groupBy"]]
    if (
        pred_total == label_total == 1
        and pred_cols == label_cols
        and pred["having"] == label["having"]
    ):
        cnt = 1
    return label_total, pred_total, cnt


def eval_order(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred["orderBy"]) > 0:
        pred_total = 1
    if len(label["orderBy"]) > 0:
        label_total = 1
    if (
        len(label["orderBy"]) > 0
        and pred["orderBy"] == label["orderBy"]
        and (
            (pred["limit"] is None and label["limit"] is None)
            or (pred["limit"] is not None and label["limit"] is not None)
        )
    ):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    pred_ao = pred["where"][1::2]
    label_ao = label["where"][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)
    if pred_ao == label_ao:
        return 1, 1, 1
    return len(pred_ao), len(label_ao), 0


def get_nestedSQL(sql):
    nested = []
    for cond_unit in sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql["intersect"] is not None:
        nested.append(sql["intersect"])
    if sql["except"] is not None:
        nested.append(sql["except"])
    if sql["union"] is not None:
        nested.append(sql["union"])
    return nested


def eval_nested(pred, label):
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        partial_scores = Evaluator.eval_partial_match(pred, label)
        cnt += Evaluator.eval_exact_match(pred, label, partial_scores)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    lt1, pt1, cnt1 = eval_nested(pred["intersect"], label["intersect"])
    lt2, pt2, cnt2 = eval_nested(pred["except"], label["except"])
    lt3, pt3, cnt3 = eval_nested(pred["union"], label["union"])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    res = set()
    if len(sql["where"]) > 0:
        res.add("where")
    if len(sql["groupBy"]) > 0:
        res.add("group")
    if len(sql["having"]) > 0:
        res.add("having")
    if len(sql["orderBy"]) > 0:
        res.add(sql["orderBy"][0])
        res.add("order")
    if sql["limit"] is not None:
        res.add("limit")
    if sql["except"] is not None:
        res.add("except")
    if sql["union"] is not None:
        res.add("union")
    if sql["intersect"] is not None:
        res.add("intersect")
    # or keyword
    ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2]
    if "or" in ao:
        res.add("or")
    cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]
    # not keyword
    if any(cond_unit[0] for cond_unit in cond_units):
        res.add("not")
    # in keyword
    if any(cond_unit[1] == WHERE_OPS.index("in") for cond_unit in cond_units):
        res.add("in")
    # like keyword
    if any(cond_unit[1] == WHERE_OPS.index("like") for cond_unit in cond_units):
        res.add("like")
    return res
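

# Example (assuming the parser marks "NOT IN" via the not_op flag): for
#   SELECT name FROM t WHERE a NOT IN (...) OR b LIKE '%x%' ORDER BY c DESC LIMIT 1
# get_keywords returns roughly
#   {"where", "not", "in", "or", "like", "order", "desc", "limit"}.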


def eval_keywords(pred, label):
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0
    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    count = 0
    if len(sql["where"]) > 0:
        count += 1
    if len(sql["groupBy"]) > 0:
        count += 1
    if len(sql["orderBy"]) > 0:
        count += 1
    if sql["limit"] is not None:
        count += 1
    if len(sql["from"]["table_units"]) > 0:  # JOIN
        count += len(sql["from"]["table_units"]) - 1
    ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2]
    count += len([token for token in ao if token == "or"])
    cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]
    count += len(
        [
            cond_unit
            for cond_unit in cond_units
            if cond_unit[1] == WHERE_OPS.index("like")
        ]
    )
    return count


def count_component2(sql):
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    count = 0
    # number of aggregations
    agg_count = count_agg(sql["select"][1])
    agg_count += count_agg(sql["where"][::2])
    agg_count += count_agg(sql["groupBy"])
    if len(sql["orderBy"]) > 0:
        agg_count += count_agg(
            [unit[1] for unit in sql["orderBy"][1] if unit[1]]
            + [unit[2] for unit in sql["orderBy"][1] if unit[2]]
        )
    agg_count += count_agg(sql["having"])
    if agg_count > 1:
        count += 1
    # number of select columns
    if len(sql["select"][1]) > 1:
        count += 1
    # number of where conditions
    if len(sql["where"]) > 1:
        count += 1
    # number of group by clauses
    if len(sql["groupBy"]) > 1:
        count += 1
    return count


class Evaluator:
    """A simple evaluator"""

    def __init__(
        self,
        db_dir,
        kmaps,
        etype,
        plug_value,
        keep_distinct,
        progress_bar_for_each_datapoint,
    ):
        self.db_dir = db_dir
        self.kmaps = kmaps
        self.etype = etype
        self.plug_value = plug_value
        self.keep_distinct = keep_distinct
        self.progress_bar_for_each_datapoint = progress_bar_for_each_datapoint
        self.db_paths = {}
        self.schemas = {}
        self.scores = {}
        for turn in TURNS:
            self.scores[turn] = {"count": 0, "exact": 0.0}
            self.scores[turn]["exec"] = 0
        for level in LEVELS:
            self.scores[level] = {"count": 0, "partial": {}, "exact": 0.0}
            self.scores[level]["exec"] = 0
            for type_ in PARTIAL_TYPES:
                self.scores[level]["partial"][type_] = {
                    "acc": 0.0,
                    "rec": 0.0,
                    "f1": 0.0,
                    "acc_count": 0,
                    "rec_count": 0,
                }

    def eval_hardness(self, sql):
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)
        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or (
            count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0
        ):
            return "medium"
        elif (
            (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0)
            or (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0)
            or (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1)
        ):
            return "hard"
        else:
            return "extra"

    @classmethod
    def eval_exact_match(cls, pred, label, partial_scores):
        for key, score in partial_scores.items():
            if score["f1"] != 1:
                return 0
        if len(label["from"]["table_units"]) > 0:
            label_tables = sorted(label["from"]["table_units"])
            pred_tables = sorted(pred["from"]["table_units"])
            return label_tables == pred_tables
        return 1

    @classmethod
    def eval_partial_match(cls, pred, label):
        res = {}
        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["select"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res["select(no AGG)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["where"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res["where(no OP)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["group(no Having)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["group"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["order"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["and/or"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["IUEN"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["keywords"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        return res

    def evaluate_one(self, db_name, gold, predicted, setup_sql,
                     validate_sql, turn_scores, idx, category):
        if db_name not in self.db_paths:
            db_path = os.path.join(self.db_dir, db_name, db_name + ".duckdb")
            self.db_paths[db_name] = db_path
            self.schemas[db_name] = Schema(get_schema(db_path))
        if idx > 3:
            idx = "> 4"
        else:
            idx += 1
        turn_id = "turn " + str(idx)
        hardness = category
        if hardness not in self.scores:
            # Defensive fallback (an assumption of this edit): gold files without
            # a category column get an ad-hoc bucket that is counted but never
            # normalized or printed, since only the levels in LEVELS are reported.
            self.scores[hardness] = {
                "count": 0,
                "exact": 0.0,
                "exec": 0,
                "partial": {
                    type_: {
                        "acc": 0.0,
                        "rec": 0.0,
                        "f1": 0.0,
                        "acc_count": 0,
                        "rec_count": 0,
                    }
                    for type_ in PARTIAL_TYPES
                },
            }
        self.scores[turn_id]["count"] += 1
        self.scores[hardness]["count"] += 1
        self.scores["all"]["count"] += 1
        if self.etype in ["all", "match"]:
            schema = self.schemas[db_name]
            g_sql = get_sql(schema, gold)
            try:
                p_sql = get_sql(schema, predicted)
            except Exception:
                # If p_sql is not valid, use an empty sql to evaluate against
                # the correct sql
                p_sql = {
                    "except": None,
                    "from": {"conds": [], "table_units": []},
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": [],
                }
        if self.etype in ["all", "exec"]:
            exec_score = eval_exec_match(
                db=self.db_paths[db_name],
                p_str=predicted,
                g_str=gold,
                setup_sql=setup_sql,
                validate_sql=validate_sql,
                plug_value=self.plug_value,
                keep_distinct=self.keep_distinct,
                progress_bar_for_each_datapoint=self.progress_bar_for_each_datapoint,
            )
            if exec_score:
                self.scores[hardness]["exec"] += 1
                self.scores[turn_id]["exec"] += 1
                self.scores["all"]["exec"] += 1
                turn_scores["exec"].append(1)
            else:
                turn_scores["exec"].append(0)
        if self.etype in ["all", "match"]:
            # rebuild sql for value evaluation
            kmap = self.kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql["from"]["table_units"], schema
            )
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql["from"]["table_units"], schema
            )
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
            partial_scores = self.eval_partial_match(p_sql, g_sql)
            exact_score = self.eval_exact_match(p_sql, g_sql, partial_scores)
            if exact_score == 0:
                turn_scores["exact"].append(0)
                print("{} pred: {}".format(hardness, predicted))
                print("{} gold: {}".format(hardness, gold))
                print("")
            else:
                turn_scores["exact"].append(1)
            self.scores[turn_id]["exact"] += exact_score
            self.scores[hardness]["exact"] += exact_score
            self.scores["all"]["exact"] += exact_score
            for type_ in PARTIAL_TYPES:
                if partial_scores[type_]["pred_total"] > 0:
                    self.scores[hardness]["partial"][type_]["acc"] += partial_scores[
                        type_
                    ]["acc"]
                    self.scores[hardness]["partial"][type_]["acc_count"] += 1
                if partial_scores[type_]["label_total"] > 0:
                    self.scores[hardness]["partial"][type_]["rec"] += partial_scores[
                        type_
                    ]["rec"]
                    self.scores[hardness]["partial"][type_]["rec_count"] += 1
                self.scores[hardness]["partial"][type_]["f1"] += partial_scores[type_][
                    "f1"
                ]
                if partial_scores[type_]["pred_total"] > 0:
                    self.scores["all"]["partial"][type_]["acc"] += partial_scores[
                        type_
                    ]["acc"]
                    self.scores["all"]["partial"][type_]["acc_count"] += 1
                if partial_scores[type_]["label_total"] > 0:
                    self.scores["all"]["partial"][type_]["rec"] += partial_scores[
                        type_
                    ]["rec"]
                    self.scores["all"]["partial"][type_]["rec_count"] += 1
                self.scores["all"]["partial"][type_]["f1"] += partial_scores[type_]["f1"]
        result = {
            "predictSQL": predicted,
            "goldSQL": gold,
        }
        if self.etype in ["all", "match"]:
            result.update({
                "hardness": hardness,
                "exact": exact_score,
                "partial": partial_scores,
            })
        if self.etype in ["all", "exec"]:
            result["exec"] = exec_score
        return result

    def finalize(self):
        scores = self.scores
        for turn in TURNS:
            if scores[turn]["count"] == 0:
                continue
            if self.etype in ["all", "exec"]:
                scores[turn]["exec"] /= scores[turn]["count"]
            if self.etype in ["all", "match"]:
                scores[turn]["exact"] /= scores[turn]["count"]
        for level in LEVELS:
            if scores[level]["count"] == 0:
                continue
            if self.etype in ["all", "exec"]:
                scores[level]["exec"] /= scores[level]["count"]
            if self.etype in ["all", "match"]:
                scores[level]["exact"] /= scores[level]["count"]
                for type_ in PARTIAL_TYPES:
                    if scores[level]["partial"][type_]["acc_count"] == 0:
                        scores[level]["partial"][type_]["acc"] = 0
                    else:
                        scores[level]["partial"][type_]["acc"] = (
                            scores[level]["partial"][type_]["acc"]
                            / scores[level]["partial"][type_]["acc_count"]
                            * 1.0
                        )
                    if scores[level]["partial"][type_]["rec_count"] == 0:
                        scores[level]["partial"][type_]["rec"] = 0
                    else:
                        scores[level]["partial"][type_]["rec"] = (
                            scores[level]["partial"][type_]["rec"]
                            / scores[level]["partial"][type_]["rec_count"]
                            * 1.0
                        )
                    if (
                        scores[level]["partial"][type_]["acc"] == 0
                        and scores[level]["partial"][type_]["rec"] == 0
                    ):
                        scores[level]["partial"][type_]["f1"] = 1
                    else:
                        scores[level]["partial"][type_]["f1"] = (
                            2.0
                            * scores[level]["partial"][type_]["acc"]
                            * scores[level]["partial"][type_]["rec"]
                            / (
                                scores[level]["partial"][type_]["rec"]
                                + scores[level]["partial"][type_]["acc"]
                            )
                        )


def isValidSQL(sql, db):
    # Legacy helper from the SQLite-based Spider script; not called anywhere
    # in this module.
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
    except Exception:
        return False
    finally:
        conn.close()
    return True


def print_formated_s(row_name, l, element_format):
    template = "{:20} " + " ".join([element_format] * len(l))
    print(template.format(row_name, *l))


def print_scores(scores, etype, include_turn_acc=True):
    turns = TURNS
    levels = ["easy", "medium", "hard", "duckdb", "ddl", "all"]
    if include_turn_acc:
        levels.append("joint_all")
    partial_types = PARTIAL_TYPES
    print_formated_s("", levels, "{:20}")
    counts = [scores[level]["count"] for level in levels]
    print_formated_s("count", counts, "{:<20d}")
    if etype in ["all", "exec"]:
        print("===================== EXECUTION ACCURACY =====================")
        exec_scores = [scores[level]["exec"] for level in levels]
        print_formated_s("execution", exec_scores, "{:<20.3f}")
    if etype in ["all", "match"]:
        print("\n====================== EXACT MATCHING ACCURACY =====================")
        exact_scores = [scores[level]["exact"] for level in levels]
        print_formated_s("exact match", exact_scores, "{:<20.3f}")
        print("\n---------------------PARTIAL MATCHING ACCURACY----------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["acc"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")
        print("---------------------- PARTIAL MATCHING RECALL ----------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["rec"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")
        print("---------------------- PARTIAL MATCHING F1 --------------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["f1"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")
    if include_turn_acc:
        print()
        print()
        print_formated_s("", turns, "{:20}")
        counts = [scores[turn]["count"] for turn in turns]
        print_formated_s("count", counts, "{:<20d}")
        if etype in ["all", "exec"]:
            print("===================== TURN EXECUTION ACCURACY =====================")
            exec_scores = [scores[turn]["exec"] for turn in turns]
            print_formated_s("execution", exec_scores, "{:<20.3f}")
        if etype in ["all", "match"]:
            print("\n====================== TURN EXACT MATCHING ACCURACY =====================")
            exact_scores = [scores[turn]["exact"] for turn in turns]
            print_formated_s("exact match", exact_scores, "{:<20.3f}")


def evaluate(
    gold,
    predict,
    db_dir,
    etype,
    kmaps,
    plug_value,
    keep_distinct,
    progress_bar_for_each_datapoint,
):
    with open(gold) as f:
        glist = []
        gseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                gseq_one.append(l.strip().split("\t"))
        # include the last session
        # this was previously ignored in the SParC evaluation script
        # which might lead to slight differences in scores
        if len(gseq_one) != 0:
            glist.append(gseq_one)
    # spider formatting indicates that there is only one "single turn"
    # do not report "turn accuracy" for SPIDER
    include_turn_acc = len(glist) > 1
    with open(predict) as f:
        plist = []
        pseq_one = []
        for l in f.readlines():
            if len(l.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(l.strip().split("\t"))
        if len(pseq_one) != 0:
            plist.append(pseq_one)
    assert len(plist) == len(glist), \
        "the number of predicted sessions must match the number of gold sessions"
    evaluator = Evaluator(
        db_dir, kmaps, etype, plug_value, keep_distinct, progress_bar_for_each_datapoint
    )
    results = []
    for i, (p, g) in enumerate(zip(plist, glist)):
        if (i + 1) % 10 == 0:
            print("Evaluating %dth prediction" % (i + 1))
        evaluator.scores["joint_all"]["count"] += 1
        turn_scores = {"exec": [], "exact": []}
        for idx, (p_turn, g_turn) in enumerate(zip(p, g)):
            p_str = p_turn[0]
            # plug the placeholder token "value" with a literal so the query parses
            p_str = p_str.replace("value", "1")
            g_str, db_name = g_turn[0], g_turn[1]
            # an optional third column may carry the question category
            # (e.g. "easy", "duckdb", "ddl"); default to an empty category otherwise
            category = g_turn[2] if len(g_turn) > 2 else ""
            results.append(
                evaluator.evaluate_one(
                    db_name, g_str, p_str, "", "", turn_scores, idx, category
                )
            )
        if all(v == 1 for v in turn_scores["exec"]):
            evaluator.scores["joint_all"]["exec"] += 1
        if all(v == 1 for v in turn_scores["exact"]):
            evaluator.scores["joint_all"]["exact"] += 1
    evaluator.finalize()
    print_scores(evaluator.scores, etype, include_turn_acc=include_turn_acc)
    return {
        "per_item": results,
        "total_scores": evaluator.scores,
    }


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit
    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    if condition is None or not DISABLE_VALUE:
        return condition
    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    if sql is None or not DISABLE_VALUE:
        return sql
    sql["from"]["conds"] = rebuild_condition_val(sql["from"]["conds"])
    sql["having"] = rebuild_condition_val(sql["having"])
    sql["where"] = rebuild_condition_val(sql["where"])
    sql["intersect"] = rebuild_sql_val(sql["intersect"])
    sql["except"] = rebuild_sql_val(sql["except"])
    sql["union"] = rebuild_sql_val(sql["union"])
    return sql


# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    col_ids = [
        table_unit[1]
        for table_unit in table_units
        if table_unit[0] == TABLE_TYPE["table_unit"]
    ]
    prefixs = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if "." in value and value[: value.index(".")] in prefixs:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    if col_unit is None:
        return col_unit
    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    if val_unit is None:
        return val_unit
    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    if table_unit is None:
        return table_unit
    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    if cond_unit is None:
        return cond_unit
    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(
                valid_col_units, condition[idx], kmap
            )
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    if from_ is None:
        return from_
    from_["table_units"] = [
        rebuild_table_unit_col(valid_col_units, table_unit, kmap)
        for table_unit in from_["table_units"]
    ]
    from_["conds"] = rebuild_condition_col(valid_col_units, from_["conds"], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    if group_by is None:
        return group_by
    return [
        rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by
    ]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    if order_by is None or len(order_by) == 0:
        return order_by
    direction, val_units = order_by
    new_val_units = [
        rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units
    ]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    if sql is None:
        return sql
    sql["select"] = rebuild_select_col(valid_col_units, sql["select"], kmap)
    sql["from"] = rebuild_from_col(valid_col_units, sql["from"], kmap)
    sql["where"] = rebuild_condition_col(valid_col_units, sql["where"], kmap)
    sql["groupBy"] = rebuild_group_by_col(valid_col_units, sql["groupBy"], kmap)
    sql["orderBy"] = rebuild_order_by_col(valid_col_units, sql["orderBy"], kmap)
    sql["having"] = rebuild_condition_col(valid_col_units, sql["having"], kmap)
    sql["intersect"] = rebuild_sql_col(valid_col_units, sql["intersect"], kmap)
    sql["except"] = rebuild_sql_col(valid_col_units, sql["except"], kmap)
    sql["union"] = rebuild_sql_col(valid_col_units, sql["union"], kmap)
    return sql


def build_foreign_key_map(entry):
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]
    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)
    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]
    return foreign_key_map
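

# Illustrative sketch: with entry["foreign_keys"] == [[3, 7], [7, 9]] the
# transitive key set {3, 7, 9} collapses onto its smallest column index, so
# cols[7] and cols[9] both map to cols[3]'s "__table.col__" name in the
# returned foreign_key_map (and cols[3] maps to itself).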


def build_foreign_key_map_from_json(table):
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry["db_id"]] = build_foreign_key_map(entry)
    return tables


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gold", dest="gold", type=str, help="the path to the gold queries"
    )
    parser.add_argument(
        "--pred", dest="pred", type=str, help="the path to the predicted queries"
    )
    parser.add_argument(
        "--db",
        dest="db",
        type=str,
        help="the directory that contains all the databases and test suites",
    )
    parser.add_argument(
        "--table", dest="table", type=str, help="the tables.json schema file"
    )
    parser.add_argument(
        "--etype",
        dest="etype",
        type=str,
        default="exec",
        help="evaluation type: exec for test suite accuracy, match for the original exact set match accuracy, all for both",
        choices=("all", "exec", "match"),
    )
    parser.add_argument(
        "--plug_value",
        default=False,
        action="store_true",
        help="whether to plug the gold values into the predicted query; suitable if your model does not predict values",
    )
    parser.add_argument(
        "--keep_distinct",
        default=False,
        action="store_true",
        help="whether to keep the distinct keyword during evaluation; default is false",
    )
    parser.add_argument(
        "--progress_bar_for_each_datapoint",
        default=False,
        action="store_true",
        help="whether to print a progress bar while running the test inputs for each datapoint",
    )
    args = parser.parse_args()

    # only evaluating exact match needs this argument
    kmaps = None
    if args.etype in ["all", "match"]:
        assert (
            args.table is not None
        ), "table argument must be non-None if exact set match is evaluated"
        kmaps = build_foreign_key_map_from_json(args.table)

    evaluate(
        args.gold,
        args.pred,
        args.db,
        args.etype,
        kmaps,
        args.plug_value,
        args.keep_distinct,
        args.progress_bar_for_each_datapoint,
    )
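
# Example invocation (file name and paths are illustrative; the flags match
# the argparse definitions above):
#   python evaluation.py --gold gold.txt --pred predictions.txt \
#       --db databases/ --table tables.json --etype all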