gbarone77 commited on
Commit
a394cab
·
1 Parent(s): 45caa26

Create SQL_helper.py

Browse files
Files changed (1) hide show
  1. SQL_helper.py +263 -0
SQL_helper.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Levenshtein
2
+ import re
3
+ from collections import Counter
4
+
5
def get_columns_name_in_query(query):
    """Return the distinct column names referenced by a SQL query.

    Concatenates the names reported by ``get_cols_name_for_select`` and
    ``get_cols_name_for_where`` and removes duplicates.

    Args:
        query: SQL query string.

    Returns:
        A list of column names. The names pass through a ``set``, so the
        order of the result is not specified.
    """
    referenced = get_cols_name_for_select(query) + get_cols_name_for_where(query)
    return list(set(referenced))
10
+
11
def replace_nonalphanumeric_chars_with_us(l):
    """Replace every run of non-alphanumeric characters with one underscore.

    Column names may contain characters outside ``0-9a-zA-Z``; each
    maximal run of such characters is collapsed to a single ``_``.

    Args:
        l: Iterable of strings (e.g. column names).

    Returns:
        A new list with the cleaned strings, in the same order.
    """
    non_alnum_run = re.compile('[^0-9a-zA-Z]+')
    return [non_alnum_run.sub('_', name) for name in l]
15
+
16
def adjust_col_name(col_name, columns_available):
    """Map a column name from a query onto a column of the real table.

    The comparison is case-insensitive. When ``col_name`` matches one of
    the available columns it is returned unchanged; otherwise the
    upper-cased available column with the smallest Levenshtein distance
    to the upper-cased ``col_name`` is returned (the first one wins on
    ties).

    Args:
        col_name: Column name as it appears in the SQL query.
        columns_available: Column names of the table.

    Returns:
        ``col_name`` itself on a case-insensitive match, the closest
        upper-cased table column otherwise, or the placeholder
        ``'column123456789011'`` when ``columns_available`` is empty.
    """
    candidates = [name.upper() for name in columns_available]
    target = col_name.upper()
    if target in candidates:
        return col_name
    if not candidates:
        # Nothing to compare against: keep the placeholder callers already get.
        return 'column123456789011'
    # Compare like with like: the candidates are upper-cased, so the query
    # name must be too, otherwise case differences inflate the distance.
    return min(candidates, key=lambda name: Levenshtein.distance(target, name))
31
+
32
def min_positive(a, b):
    """Return ``b`` when it is strictly positive and smaller than ``a``.

    In every other case (``b`` is zero, negative, or not below ``a``)
    the first argument ``a`` is returned unchanged.
    """
    return b if 0 < b < a else a
35
+
36
def aggregator_parser(query):
    """Wrap the argument of MAX / COUNT / MIN in parentheses.

    Intended for the WikiSQL dataset only. The query is upper-cased
    first. If it contains ``SELECT MAX``, ``SELECT COUNT`` or
    ``SELECT MIN`` (checked in that order), the text between the end of
    that keyword and the cut-off position is replaced by the same text
    in parentheses. The cut-off is chosen by ``min_positive`` from the
    positions of ``FROM`` and of the first comma. The slice offsets are
    computed from the start of the string, i.e. the keyword is expected
    at position 0. ``SELECT DISTINCT`` is not rewritten.

    Args:
        query: SQL query string.

    Returns:
        The upper-cased query, rewritten when an aggregator was found.
    """
    upper_query = query.upper()
    for keyword in ('SELECT MAX', 'SELECT COUNT', 'SELECT MIN'):
        if keyword not in upper_query:
            continue
        start = len(keyword)
        end = min_positive(upper_query.find('FROM'), upper_query.find(','))
        # Drop the separator on each side of the argument and wrap the rest.
        wrapped = '(' + upper_query[start + 1:end - 1] + ') '
        return upper_query.replace(upper_query[start:end], wrapped)
    return upper_query
58
+
59
def get_cols_name_for_select(query):
    """Return the column names listed in the SELECT part of a query.

    The query is upper-cased, the text between the SELECT keyword
    (optionally followed by DISTINCT / MAX / MIN / COUNT) and the
    character before ``FROM`` is split on commas, and spaces and
    parentheses are removed from every piece. Offsets are measured from
    the start of the string.

    Args:
        query: SQL query string.

    Returns:
        List of upper-cased column names, or ``['']`` when the query
        contains no ``SELECT``.
    """
    normalized = query.upper()
    # (keyword, slice start) pairs; the first keyword found in the query wins.
    prefixes = (
        ('SELECT DISTINCT', 15),
        ('SELECT MAX', 10),
        ('SELECT MIN', 10),
        ('SELECT COUNT', 13),
        ('SELECT', 7),
    )
    raw_cols = ['']
    for keyword, start in prefixes:
        if keyword in normalized:
            raw_cols = normalized[start:normalized.find('FROM') - 1].split(',')
            break
    drop_chars = str.maketrans('', '', ' ()')
    return [col.translate(drop_chars).upper() for col in raw_cols]
80
+
81
def get_indexes(l):
    """Return the positions of comparison operators and AND / OR tokens.

    Args:
        l: Sequence of query tokens.

    Returns:
        Ascending list of indexes ``i`` for which ``l[i]`` is one of
        ``=``, ``>``, ``<``, ``>=``, ``<=``, ``<>``, ``LIKE``, ``AND``
        or ``OR``.
    """
    separators = ('=', '>', '<', '>=', '<=', '<>', 'LIKE', 'AND', 'OR')
    return [pos for pos, token in enumerate(l) if token in separators]
88
+
89
def add_spaces_cmp_operators(string):
    """Surround comparison operators with single spaces.

    Recognised operators are ``<>``, ``>=``, ``<=``, ``=``, ``>`` and
    ``<``. The two-character operators are matched before the
    single-character ones in one regex pass, so ``>=`` stays one token
    instead of being split into ``> =``. All whitespace runs in the
    result are collapsed to a single space and the ends are trimmed.

    Args:
        string: Text (typically a SQL query) to normalise.

    Returns:
        The normalised string.
    """
    # Alternation order matters: longer operators must be tried first.
    spaced = re.sub(r'<>|>=|<=|[=<>]', r' \g<0> ', string)
    return ' '.join(spaced.split())
94
+
95
def add_quotes_to_string(query):
    """Put single quotes around non-numeric values in the WHERE clause.

    Intended for the WikiSQL dataset only. The query is upper-cased and
    split on single spaces. The positions returned by ``get_indexes``
    (comparison operators and AND / OR), followed by the token count,
    are read as pairs: the tokens after the first position of a pair and
    before the second one form a value. Each value is joined into one
    token and wrapped in single quotes unless ``str.isnumeric`` accepts
    it.

    Args:
        query: SQL query string.

    Returns:
        The upper-cased query with its values quoted, or just the
        upper-cased query when ``WHERE`` is absent (or at position 0).

    Raises:
        IndexError: If the boundary positions cannot be paired up, i.e.
            operators and AND / OR tokens do not alternate.
    """
    upper_query = query.upper()
    if upper_query.find('WHERE') <= 0:
        return upper_query

    tokens = [tok for tok in upper_query.split(' ') if tok]
    boundaries = get_indexes(tokens)
    boundaries.append(len(tokens))

    spans = []
    for pos in range(0, len(boundaries), 2):
        first = boundaries[pos] + 1
        last = boundaries[pos + 1] - 1
        literal = ' '.join(tokens[first:last + 1])
        if not literal.isnumeric():
            literal = "'" + literal + "'"
        spans.append((first, last, literal))

    # Rewrite from the end so that earlier positions stay valid.
    for first, last, literal in reversed(spans):
        tokens[first] = literal
        del tokens[first + 1:last + 1]
    return ' '.join(tokens)
137
+
138
def get_values_for_query_filter(query):
    """Return the values compared against in the WHERE clause.

    The query is upper-cased and split on single spaces. The positions
    returned by ``get_indexes`` (comparison operators and AND / OR),
    followed by the token count, are read as pairs: the tokens after the
    first position of a pair and before the second one form a value.
    Single quotes are removed from every value.

    Args:
        query: SQL query string.

    Returns:
        List of value strings in query order. The list is empty when the
        query has no ``WHERE`` (or has it at position 0), and a trailing
        boundary without a partner is ignored instead of raising.
    """
    upper_query = query.upper()
    values = []
    if upper_query.find('WHERE') <= 0:
        return values

    tokens = [tok for tok in upper_query.split(' ') if tok]
    boundaries = get_indexes(tokens)
    boundaries.append(len(tokens))

    # Stop one short so that an unpaired last boundary is skipped.
    for pos in range(0, len(boundaries) - 1, 2):
        first = boundaries[pos] + 1
        last = boundaries[pos + 1] - 1
        values.append(' '.join(tokens[first:last + 1]).replace("'", ""))
    return values
161
+
162
+
163
def get_cols_name_for_where(query):
    """Return the column names used in the WHERE clause of a query.

    The query is upper-cased and split on single spaces. The position of
    the ``WHERE`` token followed by the positions from ``get_indexes``
    (comparison operators and AND / OR) are read as pairs: the tokens
    after the first position of a pair and before the second one are
    joined with spaces and reported as a column name.

    Args:
        query: SQL query string.

    Returns:
        List of column names in query order; empty when the query has no
        ``WHERE`` (or has it at position 0).

    Raises:
        ValueError: If ``WHERE`` occurs in the text but not as a
            standalone token.
    """
    upper_query = query.upper()
    if upper_query.find('WHERE') <= 0:
        return []

    tokens = [tok for tok in upper_query.split(' ') if tok]
    boundaries = get_indexes(tokens)
    boundaries.insert(0, tokens.index('WHERE'))

    columns = []
    for pos in range(0, len(boundaries) - 1, 2):
        first = boundaries[pos] + 1
        last = boundaries[pos + 1] - 1
        columns.append(' '.join(tokens[first:last + 1]))
    return columns
187
+
188
def check_if_number(s):
    """Tell whether ``s`` can be converted with ``float()``.

    Args:
        s: Value to test (usually a string token).

    Returns:
        ``True`` when ``float(s)`` succeeds, ``False`` when it raises
        ``TypeError``, ``ValueError`` or ``OverflowError``.
    """
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures mean "not a number"; anything else
        # (e.g. KeyboardInterrupt) must propagate.
        return False
    return True
194
+
195
def check_if_correct_cmp_operators(query):
    """Insert a comparison operator when the WHERE clause has none.

    T5 output sometimes lacks the operator (notably ``<``). The query is
    upper-cased and, when it contains ``WHERE``, normalised with
    ``add_spaces_cmp_operators``. If no token is one of ``=``, ``>``,
    ``<``, ``>=``, ``<=``, ``<>`` or ``LIKE``, an operator is inserted
    just before the last token: ``<`` when ``check_if_number`` accepts
    that token, ``=`` otherwise.

    Args:
        query: SQL query string.

    Returns:
        The upper-cased query, with an operator added when one was
        missing. Without ``WHERE`` (or with it at position 0) the
        upper-cased query is returned as is.

    Raises:
        ValueError: If ``WHERE`` occurs in the text but not as a
            standalone token.
    """
    upper_query = query.upper()
    if upper_query.find('WHERE') <= 0:
        return upper_query

    spaced = add_spaces_cmp_operators(upper_query)
    tokens = spaced.split(' ')
    # Raises ValueError unless WHERE is present as a token of its own.
    tokens.index('WHERE')

    operators = ('=', '>', '<', '>=', '<=', '<>', 'LIKE')
    if any(tok in operators for tok in tokens):
        return spaced

    missing_operator = '<' if check_if_number(tokens[-1]) else '='
    tokens.insert(-1, missing_operator)
    return ' '.join(tokens)
216
+
217
def correct_query(query, table_columns):
    """Repair a generated SQL query using the columns of its table.

    Intended for the WikiSQL dataset only. The steps run in this order:
    ``check_if_correct_cmp_operators``, ``add_spaces_cmp_operators``,
    ``aggregator_parser``, ``add_quotes_to_string``. Afterwards every
    column name found by ``get_columns_name_in_query`` is replaced in
    the query text by the result of ``adjust_col_name``.

    Args:
        query: SQL query string.
        table_columns: Column names of the table the query targets.

    Returns:
        The corrected query string.
    """
    query = check_if_correct_cmp_operators(query)
    query = add_spaces_cmp_operators(query)
    query = aggregator_parser(query)
    query = add_quotes_to_string(query)
    for col in get_columns_name_in_query(query):
        if not col:
            # An empty name would make str.replace insert the replacement
            # between every character of the query.
            continue
        query = query.replace(col, adjust_col_name(col, table_columns))
    return query
235
+
236
def correct_mispelling(question, query):
    """Fix misspelled WHERE values using the wording of the question.

    The query and the question are upper-cased. For every value returned
    by ``get_values_for_query_filter`` the question is scanned with a
    window of as many words as the value has; each window is joined with
    spaces and stripped of leading/trailing ``.``, ``,`` and ``?``. The
    window with the smallest Levenshtein distance to the value (the
    first one on ties) replaces the value in the query.

    Empty values, and values with more words than the question, are
    left untouched. The question is split on any whitespace so repeated
    spaces do not produce empty words.

    Args:
        question: Natural-language question the query was generated from.
        query: SQL query string.

    Returns:
        The upper-cased query with its values corrected; without
        ``WHERE`` (or with it at position 0) the upper-cased query.
    """
    query = query.upper()
    if query.find('WHERE') <= 0:
        return query

    words = question.upper().split()
    corrections = []
    for value in get_values_for_query_filter(query):
        if not value:
            # Replacing '' would insert text between every character.
            continue
        width = len(value.split(' '))
        candidates = [
            ' '.join(words[start:start + width]).strip('.,?')
            for start in range(len(words) - width + 1)
        ]
        if not candidates:
            continue
        closest = min(
            candidates,
            key=lambda candidate: Levenshtein.distance(value, candidate),
        )
        corrections.append((value, closest))

    for wrong, right in corrections:
        query = query.replace(wrong, right)
    return query
259
+
260
def get_text_and_cols(text):
    """Split an input string into its text part and its column list.

    The input is split on the marker ``' table ID: '``. The piece before
    the first marker is the text; the piece right after it is split on
    commas to give the column names.

    Args:
        text: String of the form ``'<text> table ID: <col>,<col>,...'``.

    Returns:
        A ``(text, columns)`` tuple where ``columns`` is a list of
        strings.

    Raises:
        IndexError: If the marker does not occur in ``text``.
    """
    pieces = text.split(' table ID: ')
    question_text = pieces[0]
    column_names = pieces[1].split(',')
    return question_text, column_names