import re
from collections import Counter

try:
    import Levenshtein
except ImportError:  # optional C extension is not installed

    class Levenshtein:
        """Minimal pure-Python stand-in exposing only ``distance``.

        Keeps this module importable (and the edit-distance based helpers
        usable, just slower) when the ``Levenshtein`` package is missing.
        """

        @staticmethod
        def distance(a, b):
            """Return the unit-cost edit distance between strings *a* and *b*."""
            if len(a) < len(b):
                a, b = b, a
            previous = list(range(len(b) + 1))
            for i, char_a in enumerate(a, 1):
                current = [i]
                for j, char_b in enumerate(b, 1):
                    current.append(min(
                        previous[j] + 1,                      # deletion
                        current[j - 1] + 1,                   # insertion
                        previous[j - 1] + (char_a != char_b)  # substitution
                    ))
                previous = current
            return previous[-1]


# Tokens that separate columns, values and conditions inside a WHERE clause.
_CMP_OPERATORS = ('=', '>', '<', '>=', '<=', '<>', 'LIKE')
_CONNECTORS = ('AND', 'OR')

# Returned by adjust_col_name when there is no candidate column at all.
_PLACEHOLDER_COLUMN = 'column123456789011'

# Two-character operators come first so '>=' is not split into '>' and '='.
_CMP_SYMBOL_RE = re.compile(r'(<=|>=|<>|=|<|>)')
_NON_ALNUM_RE = re.compile(r'[^0-9a-zA-Z]+')
# The optional keyword needs a word boundary so that column names such as
# COUNTRY / MINUTES / MAXIMUM are not mistaken for COUNT / MIN / MAX.
_SELECT_RE = re.compile(r'\bSELECT\b\s*(?:(?:DISTINCT|MAX|MIN|COUNT)\b)?')
_AGGREGATE_RE = re.compile(r'\s*SELECT\s+(?:MAX|COUNT|MIN)\b')
_FROM_RE = re.compile(r'\bFROM\b')


def _tokenize(query):
    """Split *query* on single spaces and drop the empty tokens."""
    return [token for token in query.split(' ') if token]


def _condition_follows(tokens, start):
    """Return True if a comparison operator occurs before the next AND/OR."""
    for token in tokens[start:]:
        if token in _CMP_OPERATORS:
            return True
        if token in _CONNECTORS:
            return False
    return False


def _where_conditions(tokens):
    """Locate every ``<column> <operator> <value>`` condition after WHERE.

    Returns a list of ``(col_start, op_index, value_end)`` tuples: the column
    is ``tokens[col_start:op_index]`` and the value is
    ``tokens[op_index + 1:value_end]``. An AND/OR only ends a value when
    another comparison follows it, so ``TITLE = ROCK AND ROLL`` stays one
    condition. Returns ``[]`` when WHERE is not a standalone token.
    """
    if 'WHERE' not in tokens:
        return []
    conditions = []
    total = len(tokens)
    col_start = tokens.index('WHERE') + 1
    while col_start < total:
        op_index = next(
            (i for i in range(col_start, total) if tokens[i] in _CMP_OPERATORS),
            None,
        )
        if op_index is None:
            break
        value_end = total
        for i in range(op_index + 1, total):
            if tokens[i] in _CONNECTORS and _condition_follows(tokens, i + 1):
                value_end = i
                break
        conditions.append((col_start, op_index, value_end))
        col_start = value_end + 1
    return conditions


def _is_quoted(value):
    """Return True if *value* is already wrapped in single quotes."""
    return len(value) >= 2 and value[0] == "'" and value[-1] == "'"


# Get columns in query (input: string; output: list)
def get_columns_name_in_query(query):
    """Return the distinct column names used in the SELECT and WHERE parts.

    Names are upper-cased, listed in first-seen order (SELECT columns first)
    and never empty, so the result is deterministic and safe to feed into
    string replacement.
    """
    names = get_cols_name_for_select(query) + get_cols_name_for_where(query)
    return [name for name in dict.fromkeys(names) if name]


# Sometimes column names are ill-defined. This function replaces weird chars
# with underscores (input: list of strings; output: list of strings)
def replace_nonalphanumeric_chars_with_us(l):
    """Replace each run of non-alphanumeric characters with one underscore."""
    return [_NON_ALNUM_RE.sub('_', s) for s in l]


# Adjust column name using the column names of the original table
def adjust_col_name(col_name, columns_available):
    """Map *col_name* onto the closest name in *columns_available*.

    The comparison is case-insensitive. If *col_name* matches a column it is
    returned unchanged; otherwise the upper-cased column with the smallest
    edit distance is returned (first one wins on ties). With no candidate
    columns the placeholder ``'column123456789011'`` is returned.
    """
    candidates = [column.upper() for column in columns_available]
    target = col_name.upper()
    if target in candidates:
        return col_name
    if not candidates:
        return _PLACEHOLDER_COLUMN
    # Compare upper-case against upper-case: comparing the raw name would
    # count every lower-case letter as a substitution.
    return min(candidates, key=lambda column: Levenshtein.distance(target, column))


def min_positive(a, b):
    """Return *b* when ``0 < b < a``, otherwise *a*.

    Used with ``str.find`` results, where -1 means "not found".
    """
    return b if 0 < b < a else a


# Return corrected syntax for aggregator operators (input: string; output: string)
# USE only for wikisql dataset
def aggregator_parser(query):
    """Wrap the argument of a leading MAX / COUNT / MIN in parentheses.

    ``SELECT MAX AGE FROM T`` becomes ``SELECT MAX(AGE) FROM T``. The query is
    upper-cased. It is returned otherwise unchanged when it does not start
    with one of the aggregators, has no FROM, has no argument, or the
    argument is already parenthesised. SELECT DISTINCT is left untouched.
    """
    query = query.upper()
    head = _AGGREGATE_RE.match(query)
    if head is None:
        return query
    start = head.end()
    from_match = _FROM_RE.search(query, start)
    if from_match is None:
        return query
    # The argument ends at the first comma of the select list, else at FROM.
    end = min_positive(from_match.start(), query.find(',', start))
    argument = query[start:end].strip()
    if not argument or argument.startswith('('):
        return query
    # Splice by position: str.replace would also rewrite later occurrences
    # of the same text, e.g. the column reused in the WHERE clause.
    return query[:start] + '(' + argument + ') ' + query[end:]


# Return columns name from SELECT operator (input: string; output: list)
def get_cols_name_for_select(query):
    """Return the upper-cased column names listed between SELECT and FROM.

    A leading DISTINCT / MAX / MIN / COUNT keyword is skipped, and spaces and
    parentheses are removed from each name. Without FROM the rest of the
    query is used; without SELECT the result is ``['']``.
    """
    query = query.upper()
    head = _SELECT_RE.search(query)
    if head is None:
        return ['']
    from_match = _FROM_RE.search(query, head.end())
    stop = from_match.start() if from_match else len(query)
    columns = query[head.end():stop].split(',')
    return [c.replace(' ', '').replace(')', '').replace('(', '') for c in columns]


def get_indexes(l):
    """Return the positions in *l* of comparison operators and AND/OR."""
    markers = _CMP_OPERATORS + _CONNECTORS
    return [i for i, token in enumerate(l) if token in markers]


def add_spaces_cmp_operators(string):
    """Surround comparison symbols with single spaces and collapse whitespace.

    Two-character operators (``>=``, ``<=``, ``<>``) are kept intact.
    """
    return ' '.join(_CMP_SYMBOL_RE.sub(r' \1 ', string).split())


# Check if string and add quotes (input: string; output: string)
# USE only for wikisql dataset
def add_quotes_to_string(query):
    """Single-quote every non-numeric value in the WHERE clause.

    The query is upper-cased. Without a WHERE after position 0 it is returned
    as is; otherwise it is rebuilt with single spaces between tokens.
    Multi-word values become one quoted literal and already-quoted values are
    left alone. Only ``str.isnumeric`` values stay unquoted, so decimals and
    negative numbers are quoted.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    tokens = _tokenize(query)
    # Work right-to-left so earlier indexes stay valid while spans shrink.
    for _, op_index, value_end in reversed(_where_conditions(tokens)):
        value_start = op_index + 1
        if value_start >= value_end:
            continue
        value = ' '.join(tokens[value_start:value_end])
        if not value.isnumeric() and not _is_quoted(value):
            value = "'" + value + "'"
        tokens[value_start:value_end] = [value]
    return ' '.join(tokens)


# Get values from where clause (input: string; output: list)
def get_values_for_query_filter(query):
    """Return the upper-cased value of each WHERE condition, quotes removed.

    There is one entry per condition (an empty string if the value is
    missing), in the same order as ``get_cols_name_for_where``. Returns ``[]``
    when the query has no WHERE after position 0.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return []
    tokens = _tokenize(query)
    return [
        ' '.join(tokens[op_index + 1:value_end]).replace("'", "")
        for _, op_index, value_end in _where_conditions(tokens)
    ]


# Get columns name after where (input: string, output: list)
def get_cols_name_for_where(query):
    """Return the upper-cased column of each WHERE condition.

    There is one entry per condition (an empty string if the column is
    missing). Returns ``[]`` when the query has no WHERE after position 0.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return []
    tokens = _tokenize(query)
    return [
        ' '.join(tokens[col_start:op_index])
        for col_start, op_index, _ in _where_conditions(tokens)
    ]


def check_if_number(s):
    """Return True if ``float(s)`` succeeds, False otherwise."""
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


# Correct missing compare operator (input: string; output: string)
# T5 seems to have problem with '<' operator so if there is none this is used.
def check_if_correct_cmp_operators(query):
    """Insert a comparison operator when the WHERE clause has none.

    The query is upper-cased. When it has a WHERE after position 0 but
    contains no comparison operator token at all, an operator is inserted
    before the last token: ``<`` if that token parses as a number, ``=``
    otherwise. Queries that already have an operator come back with their
    operators space-separated; queries without WHERE come back upper-cased
    only.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    query = add_spaces_cmp_operators(query)
    tokens = query.split(' ')
    comparison_ops = ('=', '>', '<', '>=', '<=', '<>', 'LIKE')
    if any(token in comparison_ops for token in tokens):
        return query
    # Nothing sensible to repair if WHERE is only part of another word or is
    # the very last token (there is no value to put an operator in front of).
    if 'WHERE' not in tokens or tokens[-1] == 'WHERE':
        return query
    operator = '<' if check_if_number(tokens[-1]) else '='
    tokens.insert(len(tokens) - 1, operator)
    return ' '.join(tokens)


# Correct SQL syntax using info from table (input: string, list; output: string)
# Use only for wikisql dataset
def correct_query(query, table_columns):
    """Repair a generated query and align its column names with the table.

    Runs, in order: missing-operator repair, operator spacing, aggregator
    parenthesising and value quoting. Every column name found in the query
    is then replaced by the closest name from *table_columns*.
    """
    query = check_if_correct_cmp_operators(query)
    query = add_spaces_cmp_operators(query)
    query = aggregator_parser(query)
    query = add_quotes_to_string(query)
    for col in get_columns_name_in_query(query):
        # str.replace('', x) would insert x between every character.
        if not col:
            continue
        corrected_col = adjust_col_name(col, table_columns)
        if corrected_col == col:
            continue
        # Replace whole names only, so fixing YEAR does not also rewrite
        # the YEAR inside YEARS.
        pattern = r'(?<!\w)' + re.escape(col) + r'(?!\w)'
        query = re.sub(pattern, lambda _m, repl=corrected_col: repl, query)
    return query


# Correct misspelling using info from question (input: string, string; output: string)
def correct_mispelling(question, query):
    """Replace WHERE values with the closest word sequence of the question.

    For each value, every run of question words with the same word count is
    stripped of ``.,?`` and compared by edit distance; the closest run
    (first one on ties) replaces the value. Both texts are upper-cased.
    Without a WHERE after position 0 the upper-cased query is returned.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    question_tokens = question.upper().split()
    corrections = []
    for value in get_values_for_query_filter(query) or []:
        if not value:
            continue
        width = len(value.split(' '))
        best_score = None
        best_phrase = None
        for i in range(len(question_tokens) - width + 1):
            phrase = ' '.join(question_tokens[i:i + width]).strip('.,?')
            score = -Levenshtein.distance(value, phrase)
            if best_score is None or score > best_score:
                best_score = score
                best_phrase = phrase
        # No candidate at all (question shorter than the value) or an empty
        # candidate: keep the value rather than erase it or reuse a stale
        # match from a previous value.
        if best_phrase and best_phrase != value:
            corrections.append((value, best_phrase))
    for wrong, right in corrections:
        pattern = r'(?<!\w)' + re.escape(wrong) + r'(?!\w)'
        query = re.sub(pattern, lambda _m, repl=right: repl, query)
    return query


# Get both text and tables in separate string (input: string; output: string, list)
def get_text_and_cols(text):
    """Split ``'<text> table ID: <col1>,<col2>,...'`` into text and columns.

    Returns ``(text, [])`` when the separator is absent.
    """
    parts = text.split(' table ID: ')
    if len(parts) < 2:
        return text, []
    return parts[0], parts[1].split(',')