txt2sql_poc_v1 / SQL_helper.py
gbarone77's picture
Create SQL_helper.py
a394cab
import Levenshtein
import re
from collections import Counter
def get_columns_name_in_query(query):
    """Return the distinct column names referenced by a query.

    Names from the SELECT list (``get_cols_name_for_select``) come first,
    followed by those from the WHERE clause (``get_cols_name_for_where``).
    Duplicates are removed keeping the first occurrence, so the order is the
    same on every run (iterating a ``set`` of strings depends on the
    per-process hash seed). Empty names are dropped: callers use each name as
    a replacement target, and replacing ``''`` would touch every position.

    Args:
        query: SQL query string.

    Returns:
        List of column-name strings, as upper-cased by the helpers.
    """
    names = get_cols_name_for_select(query) + get_cols_name_for_where(query)
    return [name for name in dict.fromkeys(names) if name]
def replace_nonalphanumeric_chars_with_us(l):
    """Sanitise names so they contain only ASCII letters, digits and '_'.

    Column names are sometimes ill-formed; every maximal run of other
    characters (spaces, punctuation, non-ASCII letters, ...) is collapsed
    into a single underscore, e.g. ``'Height (ft)'`` -> ``'Height_ft_'``.

    Args:
        l: iterable of name strings.

    Returns:
        A new list with the sanitised names, in the same order.
    """
    bad_run = re.compile('[^0-9a-zA-Z]+')
    return [bad_run.sub('_', name) for name in l]
def adjust_col_name(col_name, columns_available):
    """Map a column name from a generated query onto a real table column.

    Args:
        col_name: column name as it appears in the SQL query.
        columns_available: column names of the target table.

    Returns:
        ``col_name`` unchanged when it matches an available column
        case-insensitively, or when there are no columns to compare against.
        Otherwise the upper-cased available column with the smallest
        Levenshtein distance to ``col_name`` (the first one wins on ties).
    """
    candidates = [c.upper() for c in columns_available]
    target = col_name.upper()
    if not candidates or target in candidates:
        return col_name
    # Compare in one case so that casing differences do not count as edits.
    return min(candidates, key=lambda c: Levenshtein.distance(target, c))
def min_positive(a, b):
    """Return the smaller positive value of two ``str.find``-style positions.

    Non-positive values (``-1`` means "not found") are ignored: ``b`` is
    returned when it is positive and either smaller than ``a`` or ``a`` is
    not positive. Otherwise ``a`` is returned, so the result is non-positive
    only when neither argument is positive.
    """
    if b > 0 and (b < a or a <= 0):
        return b
    return a
def aggregator_parser(query, aggregators=('MAX', 'COUNT', 'MIN')):
    """Parenthesise the argument of a leading aggregate function.

    ``"SELECT MAX AGE FROM T"`` becomes ``"SELECT MAX(AGE) FROM T"``.

    The first entry of *aggregators* found as ``"SELECT <AGG>"`` is rewritten.
    The keyword must be followed by a word boundary, so the ``MIN`` in
    ``SELECT MINUTES`` is not matched. The argument is the text up to the
    next ``FROM`` or ``,``, stripped of surrounding spaces. Only the argument
    span is rewritten; identical text elsewhere in the query is untouched.
    Intended for WikiSQL-style queries.

    Args:
        query: SQL query string; it is upper-cased first.
        aggregators: aggregate keywords to look for, in priority order. The
            default is the original set. Pass e.g.
            ``('MAX', 'COUNT', 'MIN', 'SUM', 'AVG')`` to cover more.

    Returns:
        The upper-cased query, rewritten when an aggregate was found.
    """
    query = query.upper()
    for agg in aggregators:
        match = re.search('SELECT ' + re.escape(agg.upper()) + r'\b', query)
        if match is None:
            continue
        start = match.end()
        # The argument ends at whichever of FROM / ',' comes first.
        ends = [
            pos
            for pos in (query.find('FROM', start), query.find(',', start))
            if pos >= 0
        ]
        if not ends:
            return query  # malformed: nothing delimits the argument
        end = min(ends)
        argument = query[start:end].strip()
        # Splice instead of str.replace: the same text may appear again,
        # e.g. the column repeated in the WHERE clause.
        return query[:start] + '(' + argument + ') ' + query[end:]
    return query
def get_cols_name_for_select(query):
    """Return the column names listed in a query's SELECT clause.

    The clause starts after ``SELECT DISTINCT``/``MAX``/``MIN``/``COUNT``.
    The keyword must be followed by a word boundary, so ``SELECT COUNTRY`` is
    a plain column. Otherwise the clause starts after ``SELECT``. It runs up
    to the next ``FROM``, or to the end of the query when there is none. The
    clause is split on commas, and spaces and parentheses are removed from
    each name.

    Args:
        query: SQL query string; it is upper-cased first.

    Returns:
        List of upper-cased column names, or ``['']`` when the query has no
        ``SELECT``.
    """
    query = query.upper()
    match = re.search(r'SELECT (?:DISTINCT|MAX|MIN|COUNT)\b', query)
    if match is not None:
        start = match.end()
    else:
        pos = query.find('SELECT')
        if pos < 0:
            return ['']
        start = pos + len('SELECT')
    end = query.find('FROM', start)
    if end < 0:
        end = len(query)
    cols = query[start:end].split(',')
    return [c.replace(' ', '').replace(')', '').replace('(', '') for c in cols]
def get_indexes(l):
    """Return the positions of operator and conjunction tokens in a list.

    Recognised tokens (exact, case-sensitive): ``=``, ``>``, ``<``, ``>=``,
    ``<=``, ``<>``, ``LIKE``, ``AND``, ``OR``.

    Args:
        l: list of query tokens.

    Returns:
        Ascending list of the indexes of recognised tokens in ``l``.
    """
    operators = ('=', '>', '<', '>=', '<=', '<>', 'LIKE', 'AND', 'OR')
    return [i for i, token in enumerate(l) if token in operators]
def add_spaces_cmp_operators(string):
    """Surround comparison operators with spaces and collapse whitespace.

    Two-character operators (``>=``, ``<=``, ``<>``) are matched before
    their one-character prefixes, so they stay intact. Replacing the
    operators one after another would turn ``>=`` into ``> =``. The result
    has single spaces between tokens and no leading or trailing whitespace,
    so applying the function twice gives the same result.
    """
    spaced = re.sub(r'(<>|<=|>=|=|<|>)', r' \1 ', string)
    return ' '.join(spaced.split())
# Matches plain decimal literals such as 7, -2, 3.5 or .25.
_NUMERIC_LITERAL = re.compile(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)')


def _quote_sql_value(value):
    """Return *value* as an SQL literal: numbers bare, anything else quoted."""
    if value.isnumeric() or _NUMERIC_LITERAL.fullmatch(value):
        return value
    if len(value) >= 2 and value.startswith("'") and value.endswith("'"):
        return value  # already quoted by the model; don't quote it twice
    # Double embedded quotes so that e.g. O'NEIL becomes a valid literal.
    return "'" + value.replace("'", "''") + "'"


def add_quotes_to_string(query):
    """Quote the non-numeric literal values in a query's WHERE clause.

    The query is upper-cased and split on single spaces. A filter value is
    the run of tokens between a comparison operator and the next AND/OR, or
    the end of the query; ``get_indexes`` locates these tokens. Each value is
    joined into a single token. Values that are not numeric literals are
    wrapped in single quotes, with embedded quotes doubled. Values that are
    already quoted are left as they are. Operators with no value after them,
    and a trailing segment with no closing pair, are left untouched instead
    of raising. Intended for WikiSQL-style queries.

    Args:
        query: SQL query string.

    Returns:
        The upper-cased query unchanged when it has no WHERE clause (a
        ``WHERE`` at position 0 does not count). Otherwise the rebuilt query,
        with tokens joined by single spaces.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    tokens = [t for t in query.split(' ') if t]
    bounds = get_indexes(tokens) + [len(tokens)]
    out = []
    prev = 0
    # Operators are assumed to alternate comparison / conjunction, so each
    # even entry of bounds opens a value and the following entry closes it.
    for i in range(0, len(bounds) - 1, 2):
        start, stop = bounds[i] + 1, bounds[i + 1]
        if start >= stop:
            continue  # operator with nothing after it
        out.extend(tokens[prev:start])
        out.append(_quote_sql_value(' '.join(tokens[start:stop])))
        prev = stop
    out.extend(tokens[prev:])
    return ' '.join(out)
def get_values_for_query_filter(query):
    """Return the literal values used in a query's WHERE clause.

    The query is upper-cased and split on single spaces. Each value is the
    run of tokens between a comparison operator and the next AND/OR, or the
    end of the query; ``get_indexes`` locates these tokens. The tokens are
    joined with single spaces and every single quote is removed. Operators
    with nothing after them contribute no value, because an empty value would
    match everywhere when callers use it with ``str.replace``.

    Args:
        query: SQL query string.

    Returns:
        List of value strings in query order. The list is empty when there
        is no WHERE clause; a ``WHERE`` at position 0 does not count.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return []
    tokens = [t for t in query.split(' ') if t]
    bounds = get_indexes(tokens) + [len(tokens)]
    values = []
    # Even entries of bounds are comparison operators; the next entry (a
    # conjunction or the end of the token list) closes the value. A final
    # unpaired segment is skipped instead of raising IndexError.
    for i in range(0, len(bounds) - 1, 2):
        value = ' '.join(tokens[bounds[i] + 1:bounds[i + 1]]).replace("'", "")
        if value:
            values.append(value)
    return values
def get_cols_name_for_where(query):
    """Return the column names filtered on in a query's WHERE clause.

    The query is upper-cased and split on single spaces. A column is the
    run of tokens between ``WHERE`` (or an AND/OR) and the following
    comparison operator; ``get_indexes`` locates these tokens. Multi-word
    names are joined with single spaces. Operators that appear before
    ``WHERE`` are ignored, and empty names are not returned.

    Args:
        query: SQL query string.

    Returns:
        List of column names in query order. The list is empty when there is
        no WHERE clause or when ``WHERE`` is not a standalone token.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return []
    tokens = [t for t in query.split(' ') if t]
    if 'WHERE' not in tokens:
        return []  # e.g. "WHERE(" glued to the next token
    where = tokens.index('WHERE')
    # After WHERE the boundaries are assumed to alternate
    # (WHERE|AND|OR) -> comparison operator, so each even entry opens a
    # column name and the following entry closes it.
    bounds = [where] + [i for i in get_indexes(tokens) if i > where]
    cols = []
    for i in range(0, len(bounds) - 1, 2):
        col = ' '.join(tokens[bounds[i] + 1:bounds[i + 1]])
        if col:
            cols.append(col)
    return cols
def check_if_number(s):
    """Return True if ``float(s)`` succeeds, i.e. *s* parses as a number.

    ``float`` also accepts strings such as ``'nan'``, ``'inf'`` and
    ``'1_000'``, so those count as numbers too. Anything ``float`` rejects
    (wrong type, unparsable text, an int too large) gives False.
    """
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
def check_if_correct_cmp_operators(query):
    """Insert a comparison operator when the WHERE clause has none.

    T5 output sometimes drops the operator, with ``<`` dropped most often.
    When no comparison operator (``= > < >= <= <> LIKE``) follows ``WHERE``,
    one is inserted before the last token. The inserted operator is ``<``
    when that token parses as a number (``check_if_number``) and ``=``
    otherwise. Nothing is inserted unless at least two tokens (a column and
    a value) follow ``WHERE``.

    Args:
        query: SQL query string.

    Returns:
        The upper-cased query. When a WHERE clause is present (a ``WHERE`` at
        position 0 does not count), the query is also normalised by
        ``add_spaces_cmp_operators`` and may have the operator inserted.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    query = add_spaces_cmp_operators(query)
    tokens = query.split(' ')
    if 'WHERE' not in tokens:
        return query  # WHERE is glued to another token; leave it alone
    where = tokens.index('WHERE')
    cmp_operators = ('=', '>', '<', '>=', '<=', '<>', 'LIKE')
    if any(token in cmp_operators for token in tokens[where + 1:]):
        return query
    if len(tokens) - where < 3:
        return query  # need both a column and a value after WHERE
    operator = '<' if check_if_number(tokens[-1]) else '='
    tokens.insert(len(tokens) - 1, operator)
    return ' '.join(tokens)
def correct_query(query, table_columns):
    """Repair a generated WikiSQL-style query using the table's column names.

    Steps, in order:
    1. Insert a missing comparison operator.
    2. Normalise spacing around comparison operators.
    3. Parenthesise a leading aggregate argument.
    4. Quote non-numeric WHERE values.
    5. Replace each referenced column name with its closest match from
       ``table_columns`` (see ``adjust_col_name``).

    Column names are replaced only as whole names, so a correction for
    ``AGE`` does not rewrite part of ``AVERAGE``. Use only for the WikiSQL
    dataset.

    Args:
        query: SQL query text, e.g. model output.
        table_columns: column names of the table being queried.

    Returns:
        The corrected, upper-cased query string.
    """
    query = check_if_correct_cmp_operators(query)
    query = add_spaces_cmp_operators(query)
    query = aggregator_parser(query)
    query = add_quotes_to_string(query)
    for col in get_columns_name_in_query(query):
        if not col:
            continue  # replacing '' would insert text between every character
        corrected_col = adjust_col_name(col, table_columns)
        if corrected_col == col:
            continue
        pattern = r'(?<![0-9A-Za-z_])' + re.escape(col) + r'(?![0-9A-Za-z_])'
        # A function replacement keeps backslashes in names literal.
        query = re.sub(pattern, lambda _m, new=corrected_col: new, query)
    return query
def correct_mispelling(question, query):
    """Replace WHERE-clause values with their closest spelling in the question.

    For each value returned by ``get_values_for_query_filter``:
    - The question is upper-cased and split on whitespace.
    - Candidates are its n-grams with the same number of words as the value,
      with '.', ',' and '?' stripped from the ends. Empty candidates are
      ignored.
    - The candidate with the smallest Levenshtein distance (the first one on
      ties) replaces every occurrence of the value in the query.

    Values with no candidate, e.g. when the question is shorter than the
    value, are left unchanged.

    Args:
        question: natural-language question the query was generated from.
        query: SQL query to fix.

    Returns:
        The upper-cased query with corrected values.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query
    tokens = question.upper().split()
    corrections = []
    for value in get_values_for_query_filter(query):
        if not value:
            continue  # replacing '' would insert text everywhere
        n = len(value.split(' '))
        candidates = [' '.join(tokens[i:i + n]).strip('.,?')
                      for i in range(len(tokens) - n + 1)]
        candidates = [c for c in candidates if c]
        if not candidates:
            continue
        best = min(candidates, key=lambda c: Levenshtein.distance(value, c))
        corrections.append((value, best))
    for wrong, right in corrections:
        query = query.replace(wrong, right)
    return query
def get_text_and_cols(text):
    """Split ``"<question> table ID: <col1>,<col2>,..."`` into its parts.

    Args:
        text: combined input string.

    Returns:
        A tuple ``(question, columns)``. ``columns`` holds the comma-separated
        names after the first ``' table ID: '`` separator, with surrounding
        whitespace stripped and empty names dropped. When the separator is
        missing the result is ``(text, [])`` rather than an IndexError.
    """
    parts = text.split(' table ID: ')
    if len(parts) < 2:
        return text, []
    columns = [c.strip() for c in parts[1].split(',')]
    return parts[0], [c for c in columns if c]