Spaces:
Runtime error
Runtime error
| import Levenshtein | |
| import re | |
| from collections import Counter | |
def get_columns_name_in_query(query):
    """Return the column names referenced by the SELECT list and the WHERE clause.

    Args:
        query: SQL query string.

    Returns:
        list of column names (as produced by get_cols_name_for_select and
        get_cols_name_for_where), de-duplicated. The order is stable: SELECT
        columns first, then WHERE columns, each in order of first appearance.
    """
    cols_from_select = get_cols_name_for_select(query)
    cols_from_where = get_cols_name_for_where(query)
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would make
    # the order (and therefore the sequence of replacements done by callers)
    # depend on Python's randomized string hashing.
    return list(dict.fromkeys(cols_from_select + cols_from_where))
def replace_nonalphanumeric_chars_with_us(l):
    """Sanitize column names that contain awkward characters.

    Every maximal run of characters outside ``[0-9a-zA-Z]`` is collapsed into a
    single underscore, e.g. ``'No. of wins'`` -> ``'No_of_wins'``.

    Args:
        l: iterable of strings (column names).

    Returns:
        list of sanitized strings, in input order.
    """
    separator_run = re.compile('[^0-9a-zA-Z]+')
    return [separator_run.sub('_', name) for name in l]
def adjust_col_name(col_name, columns_available):
    """Map a column name taken from a SQL query onto the closest real table column.

    Matching is case-insensitive. If ``col_name`` already names one of the table
    columns it is returned unchanged. Otherwise the upper-cased table column with
    the smallest Levenshtein distance is returned; ties go to the earliest column.

    Args:
        col_name: column name as it appears in the SQL query.
        columns_available: column names of the target table.

    Returns:
        ``col_name`` itself, the closest upper-cased table column, or the literal
        ``'column123456789011'`` when ``columns_available`` is empty or no column
        is fewer than 100 edits away.
    """
    columns_upper = [x.upper() for x in columns_available]
    target = col_name.upper()
    if target in columns_upper:
        return col_name
    best_column = 'column123456789011'
    best_distance = 100  # only candidates strictly closer than this are accepted
    for col in columns_upper:
        # Compare upper-case with upper-case: scoring the raw col_name against
        # upper-cased columns would count every lower-case letter as an edit.
        distance = Levenshtein.distance(target, col)
        if distance < best_distance:
            best_column, best_distance = col, distance
    return best_column
def min_positive(a, b):
    """Return the smaller of ``a`` and ``b`` among the values that are > 0.

    Callers such as aggregator_parser pass ``str.find`` results, where -1 means
    "not found", so this picks the nearest delimiter that actually occurs.
    If neither value is positive, ``a`` is returned unchanged.
    """
    # Test both values independently: a == -1 must not hide a positive b.
    positives = [x for x in (a, b) if x > 0]
    return min(positives) if positives else a
def aggregator_parser(query):
    """Parenthesize the argument of a MAX / COUNT / MIN aggregator after SELECT.

    Generated queries spell aggregators without parentheses, e.g.
    ``SELECT MAX col FROM t``; this rewrites that to ``SELECT MAX(COL) FROM T``.
    The argument runs from the keyword up to the nearest comma or FROM after it
    (or the end of the query). Intended only for the WikiSQL dataset.

    Args:
        query: SQL query string.

    Returns:
        The upper-cased query with the aggregator argument wrapped in
        parentheses. The query is returned upper-cased but otherwise unchanged
        when the argument is empty or already parenthesized, or when the
        keyword is only the prefix of a longer identifier (``SELECT COUNTRY``,
        ``SELECT MAXIMUM``).
    """
    query = query.upper()
    for keyword in ('SELECT MAX', 'SELECT COUNT', 'SELECT MIN'):
        start = query.find(keyword)
        if start == -1:
            continue
        arg_start = start + len(keyword)
        # Require a word boundary so e.g. SELECT COUNTRY is not read as COUNT.
        if arg_start < len(query) and query[arg_start] not in ' (':
            continue
        stops = [
            pos
            for pos in (query.find('FROM', arg_start), query.find(',', arg_start))
            if pos != -1
        ]
        arg_end = min(stops) if stops else len(query)
        argument = query[arg_start:arg_end].strip()
        if not argument or argument.startswith('('):
            return query
        # Rebuild around the exact span instead of str.replace, which would also
        # rewrite identical text elsewhere in the query (e.g. in WHERE).
        wrapped = query[:arg_start] + '(' + argument + ')'
        rest = query[arg_end:]
        return wrapped + ' ' + rest if rest else wrapped
    return query
def get_cols_name_for_select(query):
    """Return the column names listed between SELECT and FROM.

    A leading DISTINCT / MAX / MIN / COUNT keyword is dropped, parentheses are
    removed and whitespace inside each name is normalized to single spaces, so
    ``SELECT MAX(school team) FROM t`` yields ``['SCHOOL TEAM']``.

    Args:
        query: SQL query string.

    Returns:
        list of upper-cased column names (the select list split on commas);
        ``['']`` when the query contains no SELECT.
    """
    query = query.upper()
    select_pos = query.find('SELECT')
    if select_pos == -1:
        return ['']
    list_start = select_pos + len('SELECT')
    list_end = query.find('FROM', list_start)
    if list_end == -1:  # no FROM: the select list runs to the end of the query
        list_end = len(query)
    select_list = query[list_start:list_end].strip()
    # \b keeps identifiers such as COUNTRY or MAXIMUM intact.
    select_list = re.sub(r'^(?:DISTINCT|MAX|MIN|COUNT)\b', '', select_list, count=1)
    return [
        ' '.join(col.replace('(', '').replace(')', '').split())
        for col in select_list.split(',')
    ]
# Tokens get_indexes treats as separators in a WHERE clause: the comparison
# operators plus the AND / OR connectors.
_WHERE_SEPARATOR_TOKENS = frozenset(['=', '>', '<', '>=', '<=', '<>', 'LIKE', 'AND', 'OR'])


def get_indexes(l):
    """Return the positions of comparison operators and AND/OR connectors.

    Matching is exact and case-sensitive, so tokens are expected upper-cased.

    Args:
        l: list of query tokens.

    Returns:
        list of ascending indexes into ``l``.
    """
    return [i for i, token in enumerate(l) if token in _WHERE_SEPARATOR_TOKENS]
def add_spaces_cmp_operators(string):
    """Surround comparison operators with single spaces and normalize whitespace.

    Two-character operators are kept intact: ``'a>=5'`` -> ``'a >= 5'``,
    ``'a<>b'`` -> ``'a <> b'``.

    Args:
        string: SQL query text.

    Returns:
        The text with each operator spaced out and every run of whitespace
        collapsed to one space (leading/trailing whitespace removed).
    """
    # Alternatives are ordered longest-first so '>=' is matched before '>' / '='.
    spaced = re.sub(r'(<>|<=|>=|=|<|>)', r' \1 ', string)
    return ' '.join(spaced.split())
def add_quotes_to_string(query):
    """Quote the non-numeric comparison values of a WHERE clause.

    The query is split on whitespace. For each comparison operator located by
    get_indexes, the value is the run of tokens up to the next AND/OR (or the
    end of the query). Numeric values stay bare. Any other value becomes a
    single-quoted SQL literal with embedded quotes doubled. A value that is
    already single-quoted is left as it is. Intended only for the WikiSQL
    dataset.

    Example: ``SELECT A FROM T WHERE B = NEW YORK AND C > 3`` ->
    ``SELECT A FROM T WHERE B = 'NEW YORK' AND C > 3``.

    Args:
        query: SQL query string; operators are expected to be space-separated.

    Returns:
        The upper-cased query. When it contains WHERE, its tokens are re-joined
        with single spaces.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    tokens = query.split()
    boundaries = get_indexes(tokens) + [len(tokens)]
    # Even entries are comparison operators and odd entries the following AND/OR
    # (or the end). zip() drops a trailing unpaired operator instead of raising.
    spans = [
        (op_pos + 1, stop)
        for op_pos, stop in zip(boundaries[0::2], boundaries[1::2])
        if op_pos + 1 < stop  # skip operators with nothing after them
    ]
    # Rewrite right-to-left so earlier indexes stay valid while tokens are merged.
    for begin, stop in reversed(spans):
        tokens[begin:stop] = [_quote_sql_value(' '.join(tokens[begin:stop]))]
    return ' '.join(tokens)


def _quote_sql_value(value):
    """Return ``value`` as a SQL literal: numbers bare, anything else single-quoted."""
    if re.fullmatch(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', value):
        return value
    if len(value) >= 2 and value[0] == value[-1] == "'":
        return value  # already a quoted literal
    return "'" + value.replace("'", "''") + "'"
def get_values_for_query_filter(query):
    """Return the comparison values of the WHERE clause, stripped of quotes.

    For each comparison operator located by get_indexes, the value is the run
    of tokens up to the next AND/OR (or the end of the query).

    Args:
        query: SQL query string; operators are expected to be space-separated.

    Returns:
        list of upper-cased values with every ``'`` removed, in query order.
        Values that come out empty are skipped. The list is empty when the
        query has no WHERE.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return []
    tokens = query.split()
    boundaries = get_indexes(tokens) + [len(tokens)]
    # (operator, next connector/end) pairs; zip() drops an unpaired trailing entry.
    values = (
        ' '.join(tokens[op_pos + 1:stop]).replace("'", '')
        for op_pos, stop in zip(boundaries[0::2], boundaries[1::2])
    )
    # An empty value would later match (and be replaced) everywhere in the query.
    return [value for value in values if value]
def get_cols_name_for_where(query):
    """Return the column names compared in the WHERE clause.

    A column is the run of tokens between WHERE (or an AND/OR connector) and
    the following comparison operator. For example,
    ``... WHERE SCHOOL TEAM = X AND AGE > 3`` yields ``['SCHOOL TEAM', 'AGE']``.

    Args:
        query: SQL query string; operators are expected to be space-separated.

    Returns:
        list of upper-cased column names in query order. The list is empty
        when WHERE does not appear as a standalone token.
    """
    tokens = query.upper().split()
    if 'WHERE' not in tokens:
        return []
    clause = tokens[tokens.index('WHERE') + 1:]
    # -1 stands for the WHERE keyword itself. After it, separators alternate
    # operator, connector, operator, ... so (even, odd) pairs bracket a column.
    boundaries = [-1] + get_indexes(clause)
    columns = []
    for start, op_pos in zip(boundaries[0::2], boundaries[1::2]):
        name = ' '.join(clause[start + 1:op_pos])
        if name:  # an empty name would break callers that str.replace it
            columns.append(name)
    return columns
def check_if_number(s):
    """Return True when ``float(s)`` succeeds (e.g. '3', '-2.5', '1e3'), else False."""
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
def check_if_correct_cmp_operators(query):
    """Insert a comparison operator into a WHERE clause that has none.

    T5 seems to have trouble generating the '<' operator. So when the WHERE
    condition contains no comparison operator at all, one is inserted before
    the last token: '<' when that token is a number, '=' otherwise.

    Args:
        query: SQL query string.

    Returns:
        The upper-cased query. When it contains WHERE, operators are also
        re-spaced by add_spaces_cmp_operators, and the missing operator (if any)
        is inserted.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    query = add_spaces_cmp_operators(query)
    tokens = query.split(' ')
    if 'WHERE' not in tokens:
        return query  # WHERE only occurs inside another word
    condition = tokens[tokens.index('WHERE') + 1:]
    cmp_operators = {'=', '>', '<', '>=', '<=', '<>', 'LIKE'}
    # Only the condition itself is inspected; with nothing after WHERE there is
    # no value to compare against.
    if not condition or any(token in cmp_operators for token in condition):
        return query
    operator = '<' if check_if_number(tokens[-1]) else '='
    tokens.insert(len(tokens) - 1, operator)
    return ' '.join(tokens)
def correct_query(query, table_columns):
    """Repair a generated SQL query using the target table's column names.

    The following steps run in order: check_if_correct_cmp_operators,
    add_spaces_cmp_operators, aggregator_parser and add_quotes_to_string.
    Then every column name found in the query is replaced by its closest match
    in ``table_columns`` (see adjust_col_name). Intended only for the WikiSQL
    dataset.

    Args:
        query: generated SQL query string.
        table_columns: column names of the table the query targets.

    Returns:
        The corrected, upper-cased query.
    """
    query = check_if_correct_cmp_operators(query)
    query = add_spaces_cmp_operators(query)
    query = aggregator_parser(query)
    query = add_quotes_to_string(query)
    for col in get_columns_name_in_query(query):
        if not col:
            # Replacing '' would insert the new name between every character.
            continue
        corrected_col = adjust_col_name(col, table_columns)
        if corrected_col != col:
            query = _replace_column_outside_literals(query, col, corrected_col)
    return query


def _replace_column_outside_literals(query, old, new):
    """Replace whole-word occurrences of ``old`` by ``new`` outside '...' literals."""
    word = re.compile(r'(?<![0-9A-Za-z_])' + re.escape(old) + r'(?![0-9A-Za-z_])')
    # With a capturing group, re.split keeps each quoted literal at an odd index.
    pieces = re.split(r"('(?:[^']|'')*')", query)
    return ''.join(
        piece if i % 2 else word.sub(lambda _match: new, piece)
        for i, piece in enumerate(pieces)
    )
def correct_mispelling(question, query):
    """Replace WHERE-clause values with the closest matching words of the question.

    Each value is returned by get_values_for_query_filter. For each value, every
    window of consecutive question words with the same word count is a
    candidate; '.', ',' and '?' are stripped from the window's ends. Candidates
    are scored by Levenshtein distance. The closest candidate wins (earliest on
    ties, and only if fewer than 100 edits away) and replaces the value wherever
    it appears as a whole word in the query.

    Args:
        question: natural-language question the query was generated from.
        query: SQL query string.

    Returns:
        The upper-cased query with corrected filter values.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    # split() without an argument also collapses repeated whitespace.
    words = question.upper().split()
    corrections = []
    for value in get_values_for_query_filter(query) or []:
        if not value:
            continue  # an empty value would match (and be replaced) everywhere
        width = len(value.split(' '))
        best_match = None
        best_distance = 100
        for i in range(len(words) - width + 1):
            candidate = ' '.join(words[i:i + width]).strip('.,?')
            distance = Levenshtein.distance(value, candidate)
            if distance < best_distance:
                best_match, best_distance = candidate, distance
        # With no candidate window (question shorter than the value) the value is
        # left alone instead of reusing a previous value's match.
        if best_match and best_match != value:
            corrections.append((value, best_match))
    for wrong, right in corrections:
        pattern = r'(?<![0-9A-Za-z_])' + re.escape(wrong) + r'(?![0-9A-Za-z_])'
        query = re.sub(pattern, lambda _match: right, query)
    return query
def get_text_and_cols(text):
    """Split ``"<question> table ID: <col1>,<col2>,..."`` into its two parts.

    Args:
        text: combined string containing the ``' table ID: '`` separator;
            IndexError is raised when the separator is missing.

    Returns:
        tuple ``(question, columns)``. ``question`` is the text before the first
        separator. ``columns`` is the text after it (up to any second
        separator), split on commas.
    """
    question, *tails = text.split(' table ID: ')
    columns = tails[0].split(',')
    return question, columns