Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,23 +1,81 @@
|
|
| 1 |
import gradio
|
| 2 |
import torch
|
| 3 |
import transformers
|
| 4 |
-
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
| 5 |
-
|
| 6 |
-
tokenizer = AutoTokenizer.from_pretrained('gbarone77/t5flan-small-finetuned-wikisql-with-cols')
|
| 7 |
-
model = T5ForConditionalGeneration.from_pretrained('gbarone77/t5flan-small-finetuned-wikisql-with-cols')
|
| 8 |
|
|
|
|
| 9 |
def translate_to_sql(text):
|
| 10 |
-
inputs = tokenizer(text, padding='longest', max_length=64, return_tensors='pt')
|
| 11 |
input_ids = inputs.input_ids
|
| 12 |
attention_mask = inputs.attention_mask
|
| 13 |
output = model.generate(input_ids, attention_mask=attention_mask, max_length=64)
|
| 14 |
return tokenizer.decode(output[0], skip_special_tokens=True)
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
|
|
|
| 17 |
|
| 18 |
#Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
|
| 19 |
gradio_interface = gradio.Interface(
|
| 20 |
-
fn =
|
| 21 |
inputs=gradio.Textbox(lines=2, placeholder="translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees"),
|
| 22 |
outputs = "text")
|
| 23 |
|
|
|
|
| 1 |
import gradio
|
| 2 |
import torch
|
| 3 |
import transformers
|
| 4 |
+
from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
|
| 5 |
+
from SQL_helper import * #It's bad, just for protype
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
#Translate question to SQL query (input: string; output:string)
|
| 8 |
def translate_to_sql(text):
|
| 9 |
+
inputs = tokenizer(text, padding='longest', max_length=64, return_tensors='pt').to(device)
|
| 10 |
input_ids = inputs.input_ids
|
| 11 |
attention_mask = inputs.attention_mask
|
| 12 |
output = model.generate(input_ids, attention_mask=attention_mask, max_length=64)
|
| 13 |
return tokenizer.decode(output[0], skip_special_tokens=True)
|
| 14 |
|
| 15 |
+
#Get SQL from question with corrected syntax (input: string; output:string)
|
| 16 |
+
def get_SQL(question):
|
| 17 |
+
sql = translate_to_sql(question)
|
| 18 |
+
try:
|
| 19 |
+
text, cols = get_text_and_cols(question)
|
| 20 |
+
output = correct_query(sql, cols)
|
| 21 |
+
#for c in cols:
|
| 22 |
+
# output = output.replace(c.upper(), ' "' + c.strip() + '" ')
|
| 23 |
+
text = text.replace('translate to SQL: ','')
|
| 24 |
+
output = correct_mispelling(text, output)
|
| 25 |
+
return ' '.join(output.split())
|
| 26 |
+
except:
|
| 27 |
+
return ' '.join(sql.split())
|
| 28 |
+
|
| 29 |
+
#Get SQL from question using IT to EN trasnlation model for better quality (input: string, output:string or None)
|
| 30 |
+
def get_SQL_with_translation(question):
|
| 31 |
+
text, cols = get_text_and_cols(question)
|
| 32 |
+
t_text = translator(text.replace('translate to SQL: ',''))[0]['translation_text']
|
| 33 |
+
t_cols = translator(",".join(cols))[0]['translation_text'].split(',')
|
| 34 |
+
if len(t_cols) == len(cols):
|
| 35 |
+
s = 'translate to SQL: ' + t_text + ' table ID: ' + ",".join(t_cols)
|
| 36 |
+
t_sql = get_SQL(s)
|
| 37 |
+
mapping = zip([t.upper() for t in t_cols], [' "' + t.strip() + '" ' for t in cols])
|
| 38 |
+
#mapping = zip([t.upper() for t in t_cols], [t.strip() for t in cols])
|
| 39 |
+
for k,v in mapping:
|
| 40 |
+
t_sql = t_sql.replace(k, v)
|
| 41 |
+
return t_sql
|
| 42 |
+
else:
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
#align values traslated with original values from WHERE condition (input: string, string; output:string)
|
| 46 |
+
def align_translation(o_SQL, t_SQL):
|
| 47 |
+
o_values = get_values_for_query_filter(o_SQL)
|
| 48 |
+
t_values = get_values_for_query_filter(t_SQL)
|
| 49 |
+
o_columns = get_cols_name_for_where(o_SQL)
|
| 50 |
+
t_columns = get_cols_name_for_where(t_SQL.replace('"',''))
|
| 51 |
+
for i, t_col in enumerate(t_columns):
|
| 52 |
+
try:
|
| 53 |
+
idx = o_columns.index(t_col)
|
| 54 |
+
t_value = t_values[i]
|
| 55 |
+
o_value = o_values[idx]
|
| 56 |
+
if not(o_value.upper() == t_value.upper()):
|
| 57 |
+
t_SQL = t_SQL.replace(t_value, o_value)
|
| 58 |
+
except:
|
| 59 |
+
pass
|
| 60 |
+
return t_SQL
|
| 61 |
+
|
| 62 |
+
#Main function to convert question in natural language (Italian) in a SQL query (input: string; output: string)
|
| 63 |
+
def convert_question_to_SQL(question):
|
| 64 |
+
original_SQL = get_SQL(question)
|
| 65 |
+
translated_SQL = get_SQL_with_translation(question)
|
| 66 |
+
if translated_SQL is not None:
|
| 67 |
+
return align_translation(original_SQL, translated_SQL)
|
| 68 |
+
else:
|
| 69 |
+
return original_SQL
|
| 70 |
+
|
| 71 |
+
tokenizer = AutoTokenizer.from_pretrained('gbarone77/t5-large-finetuned-wikisql-with-cols')
|
| 72 |
+
model = T5ForConditionalGeneration.from_pretrained('gbarone77/t5-large-finetuned-wikisql-with-cols')
|
| 73 |
|
| 74 |
+
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-it-en")
|
| 75 |
|
| 76 |
#Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
|
| 77 |
gradio_interface = gradio.Interface(
|
| 78 |
+
fn = convert_question_to_SQL,
|
| 79 |
inputs=gradio.Textbox(lines=2, placeholder="translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees"),
|
| 80 |
outputs = "text")
|
| 81 |
|