Spaces:
Runtime error
Runtime error
| import gradio | |
| import torch | |
| import transformers | |
| from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline | |
| from SQL_helper import * #It's bad, just for protype | |
def translate_to_sql(text):
    """Translate a prompt into a raw SQL string with the T5 model.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` globals.

    Args:
        text: Prompt string. The example in this file has the form
            'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens removed.
    """
    # NOTE(review): as far as I can tell from the transformers tokenizer docs,
    # passing max_length without truncation=True while padding='longest' does
    # not truncate the input; the generate() max_length=64 is what bounds the
    # output length. Confirm before relying on input truncation.
    encoded = tokenizer(
        text,
        padding='longest',
        max_length=64,
        return_tensors='pt',
    ).to(device)
    generated = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=64,
    )
    best = generated[0]
    return tokenizer.decode(best, skip_special_tokens=True)
def get_SQL(question):
    """Get a SQL query for ``question`` with post-processed syntax.

    The raw model output of ``translate_to_sql`` is passed through the
    SQL_helper functions ``correct_query`` (with the column names extracted
    by ``get_text_and_cols``) and ``correct_mispelling``. If any step of
    that post-processing raises, the raw model output is returned instead.

    Args:
        question: Prompt string, e.g.
            'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

    Returns:
        The SQL query with every run of whitespace collapsed to one space.
    """
    import logging

    sql = translate_to_sql(question)
    try:
        text, cols = get_text_and_cols(question)
        output = correct_query(sql, cols)
        text = text.replace('translate to SQL: ', '')
        output = correct_mispelling(text, output)
        return ' '.join(output.split())
    except Exception:
        # Deliberate best-effort fallback. This was a bare ``except:``, which
        # also swallowed KeyboardInterrupt/SystemExit. Log the failure so
        # post-processing problems are visible instead of silently hidden.
        logging.getLogger(__name__).warning(
            "SQL post-processing failed for %r; returning raw model output",
            question,
            exc_info=True,
        )
        return ' '.join(sql.split())
def get_SQL_with_translation(question):
    """Get SQL for an Italian question by first translating it to English.

    The question text and the comma-joined column names are translated
    separately with the module-level ``translator`` pipeline. An English
    prompt is built from them and passed to ``get_SQL``. In the result, each
    upper-cased translated column name is then replaced by the corresponding
    original column name wrapped in double quotes (with surrounding spaces).

    Args:
        question: Prompt string, e.g.
            'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

    Returns:
        The SQL string, or None when the translated column list does not
        split into the same number of entries as the original list.
    """
    text, cols = get_text_and_cols(question)
    question_en = translator(text.replace('translate to SQL: ', ''))[0]['translation_text']
    cols_en = translator(",".join(cols))[0]['translation_text'].split(',')

    # Presumably the translation can merge or split entries (commas added or
    # lost); then translated and original names cannot be paired reliably.
    if len(cols_en) != len(cols):
        return None

    prompt_en = 'translate to SQL: ' + question_en + ' table ID: ' + ",".join(cols_en)
    sql_en = get_SQL(prompt_en)
    # Keys keep any whitespace left by split(','); only the original names
    # are stripped before quoting.
    for name_en, name_orig in zip(cols_en, cols):
        sql_en = sql_en.replace(name_en.upper(), ' "' + name_orig.strip() + '" ')
    return sql_en
def align_translation(o_SQL, t_SQL):
    """Put original-language WHERE values back into a translated query.

    For every WHERE column of the translated query that also appears among
    the WHERE columns of the original query, the value paired with it in
    the translated query is replaced by the value paired with it in the
    original query. The comparison is case-insensitive, and nothing is
    replaced when the two values already match.

    Args:
        o_SQL: Query produced directly from the original (Italian) question.
        t_SQL: Query produced via the translation route.

    Returns:
        ``t_SQL`` with the aligned values substituted.
    """
    o_values = get_values_for_query_filter(o_SQL)
    t_values = get_values_for_query_filter(t_SQL)
    o_columns = get_cols_name_for_where(o_SQL)
    # Column names in the translated query carry double quotes; strip them
    # so they can be compared with the original query's column names.
    t_columns = get_cols_name_for_where(t_SQL.replace('"', ''))
    for i, t_col in enumerate(t_columns):
        try:
            idx = o_columns.index(t_col)
            t_value = t_values[i]
            o_value = o_values[idx]
            # Skip empty values: str.replace('', x) would insert x between
            # every character of the query and garble it.
            if t_value and o_value.upper() != t_value.upper():
                t_SQL = t_SQL.replace(t_value, o_value)
        except Exception:
            # Best effort: column missing from the original query, value
            # lists shorter than the column lists, or helper results of an
            # unexpected type. Leave this column's value as translated.
            # (Narrowed from a bare ``except:`` so interrupts propagate.)
            continue
    return t_SQL
def convert_question_to_SQL(question):
    """Convert a natural-language (Italian) question into a SQL query.

    This is the function wired to the Gradio interface. The query is
    computed directly with ``get_SQL``, and also via Italian->English
    translation with ``get_SQL_with_translation``. When the translated
    route succeeds, its WHERE values are aligned back to the original ones
    with ``align_translation``. Otherwise the direct query is returned.

    Args:
        question: Prompt string, e.g. 'translate to SQL: When was Olympics
            held in Rome? table ID: ID, city, year, cost, attendees'.

    Returns:
        The SQL query string.
    """
    import logging

    original_SQL = get_SQL(question)
    try:
        translated_SQL = get_SQL_with_translation(question)
        if translated_SQL is None:
            return original_SQL
        return align_translation(original_SQL, translated_SQL)
    except Exception:
        # get_SQL already degrades gracefully, but the translation route did
        # not. A failure there (presumably e.g. a question get_text_and_cols
        # cannot parse) would reach the UI as an error although a usable
        # original_SQL exists, so fall back to it.
        logging.getLogger(__name__).warning(
            "Translation route failed for %r; returning direct SQL",
            question,
            exc_info=True,
        )
        return original_SQL
device = 'cpu'

# Weights are loaded from the local 'model/' directory only.
# NOTE(review): presumably a local copy of the Hub checkpoint
# 'gbarone77/t5-large-finetuned-wikisql-with-cols' (pass that id to
# from_pretrained instead if 'model/' is absent) — confirm.
tokenizer = AutoTokenizer.from_pretrained('model/', local_files_only=True)
# translate_to_sql moves its inputs to `device`; move the model there too so
# that changing `device` cannot cause a device mismatch inside generate().
model = T5ForConditionalGeneration.from_pretrained('model/', local_files_only=True).to(device)
# Italian -> English translation pipeline used for the translated SQL route.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-it-en")

# Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
gradio_interface = gradio.Interface(
    fn=convert_question_to_SQL,
    inputs=gradio.Textbox(lines=2, placeholder=""),
    outputs="text",
)
# NOTE(review): launching at import time is kept as-is; consider an
# `if __name__ == "__main__":` guard if this module is ever imported.
gradio_interface.launch()