import gradio
import torch
import transformers
from typing import Optional

from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline

# NOTE: wildcard import kept from the prototype. This file relies on it for
# get_text_and_cols, correct_query, correct_mispelling,
# get_values_for_query_filter and get_cols_name_for_where.
from SQL_helper import *


def translate_to_sql(text: str) -> str:
    """Run the T5 model on *text* and return the decoded output string.

    The module-level ``tokenizer``, ``model`` and ``device`` are used.
    Generation is capped at ``max_length=64`` and special tokens are
    stripped from the decoded result.
    """
    inputs = tokenizer(
        text, padding='longest', max_length=64, return_tensors='pt'
    ).to(device)
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=64,
    )
    # Only the first generated sequence is used.
    return tokenizer.decode(output[0], skip_special_tokens=True)


def get_SQL(question: str) -> str:
    """Return the SQL generated for *question*, post-processed when possible.

    The raw model output is passed through ``correct_query`` and
    ``correct_mispelling`` from ``SQL_helper``. If any post-processing step
    fails, the raw model output is returned instead. In both cases runs of
    whitespace are collapsed to single spaces.
    """
    sql = translate_to_sql(question)
    try:
        text, cols = get_text_and_cols(question)
        output = correct_query(sql, cols)
        text = text.replace('translate to SQL: ', '')
        output = correct_mispelling(text, output)
        return ' '.join(output.split())
    except Exception:
        # Best effort: post-processing is optional, so fall back to the raw
        # model output. ``Exception`` (not a bare ``except``) lets
        # KeyboardInterrupt / SystemExit propagate.
        return ' '.join(sql.split())


def get_SQL_with_translation(question: str) -> Optional[str]:
    """Return SQL generated from a translated copy of *question*.

    The question text and the comma-joined column names are each sent through
    the module-level ``translator`` pipeline (Helsinki-NLP/opus-mt-it-en). A
    new prompt is built from the translated parts and passed to ``get_SQL``.
    Translated column names are then replaced by the original ones wrapped in
    double quotes.

    Returns ``None`` when splitting the translated column list on ``','`` does
    not yield as many names as the original list, since the columns can then
    no longer be mapped back one-to-one.
    """
    text, cols = get_text_and_cols(question)
    t_text = translator(text.replace('translate to SQL: ', ''))[0]['translation_text']
    t_cols = translator(",".join(cols))[0]['translation_text'].split(',')
    if len(t_cols) != len(cols):
        return None

    prompt = 'translate to SQL: ' + t_text + ' table ID: ' + ",".join(t_cols)
    t_sql = get_SQL(prompt)
    # Translated names are matched in upper case.
    # NOTE(review): presumably get_SQL emits column names upper-cased - confirm.
    for t_col, col in zip(t_cols, cols):
        key = t_col.upper()
        if not key:
            # str.replace('') would insert the replacement between every
            # character of the query; skip empty translated names.
            continue
        t_sql = t_sql.replace(key, ' "' + col.strip() + '" ')
    return t_sql


def align_translation(o_SQL: str, t_SQL: str) -> str:
    """Copy original WHERE values into the query built from translated text.

    For every WHERE column of *t_SQL* that also appears among the WHERE
    columns of *o_SQL*, the value used in *t_SQL* is replaced by the value
    used in *o_SQL* when the two differ ignoring case. Columns that cannot be
    matched, or whose values are missing, are left untouched.

    Returns the (possibly modified) *t_SQL*.
    """
    o_values = get_values_for_query_filter(o_SQL)
    t_values = get_values_for_query_filter(t_SQL)
    o_columns = get_cols_name_for_where(o_SQL)
    # Quotes are stripped only for column-name extraction, not from the
    # query that is returned.
    t_columns = get_cols_name_for_where(t_SQL.replace('"', ''))
    for i, t_col in enumerate(t_columns):
        try:
            idx = o_columns.index(t_col)
            t_value = t_values[i]
            o_value = o_values[idx]
            # An empty needle would make str.replace insert o_value between
            # every character, so empty translated values are skipped.
            if t_value and o_value.upper() != t_value.upper():
                t_SQL = t_SQL.replace(t_value, o_value)
        except Exception:
            # Best effort: an unmatched column or missing value just means
            # this condition is not aligned.
            continue
    return t_SQL


def convert_question_to_SQL(question: str) -> str:
    """Convert a natural-language (Italian) question into a SQL query.

    Two candidates are produced: one directly from *question* and one from
    its translation. When the translated candidate is available, it is
    returned after its WHERE values have been aligned with the direct
    candidate; otherwise the direct candidate is returned.
    """
    original_SQL = get_SQL(question)
    translated_SQL = get_SQL_with_translation(question)
    if translated_SQL is None:
        return original_SQL
    return align_translation(original_SQL, translated_SQL)


device = 'cpu'

# The tokenizer and model are loaded from the local 'model/' directory.
# The prototype alternatively loaded them from the Hub id
# 'gbarone77/t5-large-finetuned-wikisql-with-cols'.
tokenizer = AutoTokenizer.from_pretrained('model/', local_files_only=True)
model = T5ForConditionalGeneration.from_pretrained('model/', local_files_only=True)
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-it-en")

# Input example:
# 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
gradio_interface = gradio.Interface(
    fn=convert_question_to_SQL,
    inputs=gradio.Textbox(lines=2, placeholder=""),
    outputs="text",
)
gradio_interface.launch()