Spaces:
Runtime error
Runtime error
File size: 3,578 Bytes
6fff48b 1704b25 b4f20c9 40220ef 6fff48b 40220ef 1704b25 40220ef 1704b25 40220ef d8079fc adc5de3 1704b25 40220ef 1704b25 1bed5cf 6fff48b 40220ef f6f11e9 b171847 80a1529 6fff48b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio
import torch
import transformers
from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
from SQL_helper import * #It's bad, just for protype
#Translate question to SQL query (input: string; output:string)
def translate_to_sql(text):
    """Run the module-level T5 model on a prompt and return the decoded text.

    Args:
        text: Prompt string handed to the tokenizer unchanged.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # NOTE(review): max_length is passed without truncation=True, so long
    # prompts are presumably not cut to 64 tokens -- confirm this is intended.
    encoded = tokenizer(
        text,
        padding='longest',
        max_length=64,
        return_tensors='pt',
    ).to(device)
    generated = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=64,
    )
    # A single prompt is encoded, so only the first sequence is decoded.
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
#Get SQL from question with corrected syntax (input: string; output:string)
def get_SQL(question):
    """Translate a prompt to SQL and try to clean up the generated query.

    The raw model output is post-processed with the ``SQL_helper`` functions
    ``correct_query`` and ``correct_mispelling``. If any post-processing step
    fails, the raw model output is returned instead, so a query string is
    produced whenever the model call itself succeeds.

    Args:
        question: Full prompt string; the example in this file has the form
            ``'translate to SQL: <question> table ID: <col>, <col>, ...'``.

    Returns:
        The SQL string with every run of whitespace collapsed to one space.
    """
    sql = translate_to_sql(question)
    try:
        text, cols = get_text_and_cols(question)
        output = correct_query(sql, cols)
        # Drop the task prefix so only the question text reaches the
        # misspelling correction.
        text = text.replace('translate to SQL: ', '')
        output = correct_mispelling(text, output)
        return ' '.join(output.split())
    except Exception:
        # Deliberate best-effort fallback to the uncorrected model output.
        # ``Exception`` (rather than a bare ``except``) lets KeyboardInterrupt
        # and SystemExit propagate.
        return ' '.join(sql.split())
#Get SQL from question using an IT-to-EN translation model for better quality (input: string; output: string or None)
def get_SQL_with_translation(question):
    """Build the SQL from a machine-translated copy of the prompt.

    The question text and the comma-joined column names are each passed
    through the module-level ``translator`` pipeline. A new prompt is built
    from the translated pieces and sent to ``get_SQL``. Finally, the
    upper-cased translated column names found in the result are replaced by
    the original column names wrapped in double quotes.

    Args:
        question: Full prompt string, as accepted by ``get_text_and_cols``.

    Returns:
        The SQL string with original column names restored, or ``None`` when
        the translated column list does not have the same number of entries
        as the original one.
    """
    text, cols = get_text_and_cols(question)

    question_only = text.replace('translate to SQL: ', '')
    t_text = translator(question_only)[0]['translation_text']

    joined_cols = ",".join(cols)
    t_cols = translator(joined_cols)[0]['translation_text'].split(',')

    # Without a one-to-one column correspondence the names cannot be mapped back.
    if len(t_cols) != len(cols):
        return None

    prompt = 'translate to SQL: ' + t_text + ' table ID: ' + ",".join(t_cols)
    t_sql = get_SQL(prompt)

    for translated_col, original_col in zip(t_cols, cols):
        quoted_original = ' "' + original_col.strip() + '" '
        t_sql = t_sql.replace(translated_col.upper(), quoted_original)
    return t_sql
#Align translated values with the original values from the WHERE condition (input: string, string; output: string)
def align_translation(o_SQL, t_SQL):
    """Replace translated filter values in ``t_SQL`` with those from ``o_SQL``.

    For each WHERE column found in ``t_SQL`` (looked up with double quotes
    removed), the column is searched for among the WHERE columns of ``o_SQL``.
    When it is found and the two filter values differ ignoring case, the
    translated value is replaced by the original one throughout ``t_SQL``.

    Args:
        o_SQL: SQL generated from the original question.
        t_SQL: SQL generated from the translated question.

    Returns:
        ``t_SQL`` with the aligned filter values; unchanged when nothing
        could be matched.
    """
    o_values = get_values_for_query_filter(o_SQL)
    t_values = get_values_for_query_filter(t_SQL)
    o_columns = get_cols_name_for_where(o_SQL)
    t_columns = get_cols_name_for_where(t_SQL.replace('"', ''))
    for i, t_col in enumerate(t_columns):
        try:
            idx = o_columns.index(t_col)
            t_value = t_values[i]
            o_value = o_values[idx]
            if o_value.upper() != t_value.upper():
                t_SQL = t_SQL.replace(t_value, o_value)
        except Exception:
            # Best-effort alignment: a column missing from the original query
            # or a missing/non-string value just skips this column.
            # ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt and SystemExit propagate.
            continue
    return t_SQL
#Main function to convert a natural-language (Italian) question into a SQL query (input: string; output: string)
def convert_question_to_SQL(question):
    """Convert a prompt to SQL, preferring the translation-assisted result.

    Both the direct and the translation-based queries are always computed.
    When the translation-based one is available it is returned after its
    filter values are aligned with the direct query; otherwise the direct
    query is returned.

    Args:
        question: Full prompt string entered in the Gradio textbox.

    Returns:
        The resulting SQL query string.
    """
    direct_sql = get_SQL(question)
    via_translation_sql = get_SQL_with_translation(question)
    if via_translation_sql is None:
        return direct_sql
    return align_translation(direct_sql, via_translation_sql)
# Module-level setup: load the models once at import time, then start the UI.
# Device that translate_to_sql moves the tokenized inputs to.
device = 'cpu'
# Alternative kept for reference: load the same tokenizer/model by Hub id
# instead of from the local 'model/' directory.
#tokenizer = AutoTokenizer.from_pretrained('gbarone77/t5-large-finetuned-wikisql-with-cols')
#model = T5ForConditionalGeneration.from_pretrained('gbarone77/t5-large-finetuned-wikisql-with-cols')
# Tokenizer and T5 model used by translate_to_sql, read from the 'model/' directory.
tokenizer = AutoTokenizer.from_pretrained('model/', local_files_only=True)
model = T5ForConditionalGeneration.from_pretrained('model/', local_files_only=True)
# Translation pipeline used by get_SQL_with_translation (Italian to English).
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-it-en")
#Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
# Text-in/text-out web UI that calls convert_question_to_SQL on the textbox content.
gradio_interface = gradio.Interface(
fn = convert_question_to_SQL,
inputs=gradio.Textbox(lines=2, placeholder=""),
outputs = "text")
# Start serving the interface as soon as the module is executed.
gradio_interface.launch()
|