# txt2sql_poc_v1 / app.py
# Author: gbarone77
# Last commit: adc5de3 ("Update app.py")
import gradio
import torch
import transformers
from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
from SQL_helper import * #It's bad, just for protype
def translate_to_sql(text):
    """Generate a SQL query string from a prompt using the module-level T5 model.

    The prompt is passed unchanged to the tokenizer; the app's example input is
    'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

    Args:
        text: prompt string for the fine-tuned model.

    Returns:
        The first generated sequence decoded to text, with special tokens removed.
    """
    # NOTE(review): truncation is not enabled, so max_length here probably does
    # not shorten long prompts (padding='longest' on one string is a no-op) — confirm.
    encoded = tokenizer(text, padding='longest', max_length=64, return_tensors='pt').to(device)
    # Generation is capped at 64 tokens.
    generated = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=64,
    )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def get_SQL(question):
    """Translate a question to SQL and post-process it with the table's columns.

    The raw model output from ``translate_to_sql`` is corrected with
    ``correct_query`` (using the columns extracted by ``get_text_and_cols``) and
    then with ``correct_mispelling``. If any post-processing step fails, the raw
    model output is returned instead.

    Args:
        question: prompt such as
            'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

    Returns:
        The SQL query string with runs of whitespace collapsed to single spaces.
    """
    sql = translate_to_sql(question)
    try:
        text, cols = get_text_and_cols(question)
        output = correct_query(sql, cols)
        # Drop the task prefix so only the question text is used for spelling correction.
        text = text.replace('translate to SQL: ', '')
        output = correct_mispelling(text, output)
        return ' '.join(output.split())
    except Exception:
        # Best-effort post-processing: fall back to the uncorrected model output.
        # (A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.)
        return ' '.join(sql.split())
def get_SQL_with_translation(question):
    """Build SQL by first translating the Italian prompt to English.

    The question text and the column names are translated IT->EN with the
    module-level ``translator`` pipeline. The English prompt is run through
    ``get_SQL``, and the upper-cased English column names in the result are then
    replaced by the original column names wrapped in double quotes.

    Args:
        question: prompt such as
            'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

    Returns:
        The SQL string, or None when the translated column list does not have the
        same number of entries as the original one (no reliable name mapping).
    """
    text, cols = get_text_and_cols(question)
    t_text = translator(text.replace('translate to SQL: ', ''))[0]['translation_text']
    # The columns are translated in a single call and re-split on commas, so the
    # count can change if the translation alters the commas; rejected below.
    t_cols = translator(",".join(cols))[0]['translation_text'].split(',')
    if len(t_cols) != len(cols):
        return None

    s = 'translate to SQL: ' + t_text + ' table ID: ' + ",".join(t_cols)
    t_sql = get_SQL(s)

    # Pair each translated (upper-cased) name with its quoted original. Skip
    # empty/whitespace-only names: str.replace('', v) would insert v between
    # every character, and ' ' would hit every space in the query.
    mapping = [
        (t.upper(), ' "' + c.strip() + '" ')
        for t, c in zip(t_cols, cols)
        if t.strip()
    ]
    # Replace longer names first so a name that is a substring of another
    # (e.g. YEAR inside YEARS) does not corrupt the longer one.
    mapping.sort(key=lambda pair: len(pair[0]), reverse=True)
    for k, v in mapping:
        t_sql = t_sql.replace(k, v)
    return t_sql
def align_translation(o_SQL, t_SQL):
    """Put original WHERE-filter values back into a query built from translated text.

    For each WHERE column of ``t_SQL`` that also appears among the WHERE columns
    of ``o_SQL``, the filter value found in ``t_SQL`` is compared
    case-insensitively with the matching value from ``o_SQL``. When they differ,
    every occurrence of the translated value in ``t_SQL`` is replaced by the
    original value.

    Args:
        o_SQL: query generated from the original (untranslated) question.
        t_SQL: query generated from the translated question.

    Returns:
        ``t_SQL`` with its filter values aligned to ``o_SQL``.
    """
    o_values = get_values_for_query_filter(o_SQL)
    t_values = get_values_for_query_filter(t_SQL)
    o_columns = get_cols_name_for_where(o_SQL)
    # get_SQL_with_translation wraps column names in double quotes; strip them,
    # presumably so the names compare equal to those in o_columns — confirm.
    t_columns = get_cols_name_for_where(t_SQL.replace('"', ''))
    for i, t_col in enumerate(t_columns):
        try:
            idx = o_columns.index(t_col)
            t_value = t_values[i]
            o_value = o_values[idx]
        except (ValueError, IndexError):
            # Column absent from the original query, or fewer values than
            # columns were extracted: nothing to align for this column.
            continue
        # An empty translated value would make str.replace insert o_value
        # between every character of the query.
        if not t_value:
            continue
        if o_value.upper() != t_value.upper():
            t_SQL = t_SQL.replace(t_value, o_value)
    return t_SQL
def convert_question_to_SQL(question):
    """Convert an Italian natural-language question into a SQL query.

    The SQL produced directly from the question (``get_SQL``) is the baseline.
    When the translation-assisted path (``get_SQL_with_translation``) produces a
    query, its filter values are aligned with the baseline and that query is
    returned. Otherwise, including when the translation path raises, the baseline
    is returned.

    Args:
        question: prompt such as
            'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

    Returns:
        The SQL query string.
    """
    import logging

    original_SQL = get_SQL(question)
    try:
        translated_SQL = get_SQL_with_translation(question)
        if translated_SQL is None:
            return original_SQL
        return align_translation(original_SQL, translated_SQL)
    except Exception:
        # Top-level boundary for the UI: get_SQL already fell back gracefully, so
        # a failure in the optional translation path must not discard its result.
        logging.getLogger(__name__).exception(
            "Translation-assisted SQL generation failed; returning direct result"
        )
        return original_SQL
device = 'cpu'
# Alternative source for the same fine-tuned model (instead of the local
# 'model/' directory): 'gbarone77/t5-large-finetuned-wikisql-with-cols'.
tokenizer = AutoTokenizer.from_pretrained('model/', local_files_only=True)
# Put the model on the same device that translate_to_sql sends its inputs to,
# so changing `device` above does not cause a device mismatch.
model = T5ForConditionalGeneration.from_pretrained('model/', local_files_only=True).to(device)
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-it-en")

# Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
gradio_interface = gradio.Interface(
    fn=convert_question_to_SQL,
    inputs=gradio.Textbox(lines=2, placeholder=""),
    outputs="text",
)

if __name__ == "__main__":
    gradio_interface.launch()