File size: 3,578 Bytes
6fff48b
1704b25
b4f20c9
40220ef
 
6fff48b
40220ef
1704b25
40220ef
1704b25
 
 
 
 
40220ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8079fc
 
adc5de3
 
 
 
 
1704b25
40220ef
1704b25
1bed5cf
6fff48b
40220ef
f6f11e9
b171847
80a1529
6fff48b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio
import torch
import transformers
from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
from SQL_helper import * #It's bad, just for protype

#Translate question to SQL query (input: string; output:string)
def translate_to_sql(text, max_length=64):
  """Translate a prompt into a raw SQL string with the seq2seq model.

  Uses the module-level ``tokenizer``, ``model`` and ``device``.

  Args:
    text: model prompt, e.g.
      'translate to SQL: <question> table ID: <col1>, <col2>, ...'.
    max_length: maximum length (in tokens) of the generated SQL.

  Returns:
    The decoded SQL string with special tokens removed.
  """
  # The prompt is deliberately not truncated: with padding='longest' and no
  # truncation, a tokenizer max_length is ignored anyway, and cutting the
  # prompt would drop trailing column names from the "table ID:" list.
  inputs = tokenizer(text, padding='longest', return_tensors='pt').to(device)
  output = model.generate(inputs.input_ids,
                          attention_mask=inputs.attention_mask,
                          max_length=max_length)
  return tokenizer.decode(output[0], skip_special_tokens=True)

#Get SQL from question with corrected syntax (input: string; output:string)
def get_SQL(question):
  """Generate SQL for a prompt and post-correct it with the SQL_helper utilities.

  The raw model output is passed through ``correct_query`` (using the
  column names parsed from the prompt) and ``correct_mispelling`` (using
  the question text without the 'translate to SQL: ' prefix).

  Args:
    question: model prompt, e.g. 'translate to SQL: ... table ID: ...'.

  Returns:
    The corrected SQL with runs of whitespace collapsed to single spaces,
    or the whitespace-collapsed raw model output if any correction step
    fails.
  """
  sql = translate_to_sql(question)
  try:
    text, cols = get_text_and_cols(question)
    output = correct_query(sql, cols)
    text = text.replace('translate to SQL: ', '')
    output = correct_mispelling(text, output)
    return ' '.join(output.split())
  except Exception:
    # Corrections are best-effort: fall back to the uncorrected model output
    # (but do not swallow KeyboardInterrupt/SystemExit like a bare except).
    return ' '.join(sql.split())

# Get SQL from the question using the IT-to-EN translation model for better quality (input: string; output: string or None)
def _restore_original_columns(sql, t_cols, cols):
  """Replace upper-cased translated column names in ``sql`` with the quoted originals, in one pass."""
  import re
  replacements = {}
  for t_col, col in zip(t_cols, cols):
    key = t_col.upper()
    # Skip empty names (replacing '' would insert text between every
    # character); keep the first mapping if two translations collide.
    if key:
      replacements.setdefault(key, ' "' + col.strip() + '" ')
  if not replacements:
    return sql
  # Longest keys first so a short name cannot clobber part of a longer one;
  # a single regex pass also prevents inserted text from being re-matched.
  pattern = re.compile('|'.join(
      re.escape(key) for key in sorted(replacements, key=len, reverse=True)))
  return pattern.sub(lambda match: replacements[match.group(0)], sql)


def get_SQL_with_translation(question):
  """Generate SQL after translating the question and column names to English.

  The question text and the comma-joined column names are each passed
  through the module-level ``translator``. A new prompt is built from the
  translated text, SQL is generated with ``get_SQL``, and the translated
  (upper-cased) column names in the result are mapped back to the
  original column names, wrapped in double quotes.

  Args:
    question: model prompt, e.g. 'translate to SQL: ... table ID: ...'.

  Returns:
    The SQL string, or None if the translated column list does not have
    the same number of entries as the original one (so the names cannot
    be mapped back reliably).
  """
  text, cols = get_text_and_cols(question)
  t_text = translator(text.replace('translate to SQL: ', ''))[0]['translation_text']
  t_cols = translator(",".join(cols))[0]['translation_text'].split(',')
  if len(t_cols) != len(cols):
    return None
  prompt = 'translate to SQL: ' + t_text + ' table ID: ' + ",".join(t_cols)
  return _restore_original_columns(get_SQL(prompt), t_cols, cols)

# Align the translated WHERE-condition values with the original values (input: string, string; output: string)
def align_translation(o_SQL, t_SQL):
  """Restore original WHERE-condition values into the translation-based SQL.

  For each WHERE column of ``t_SQL`` (quotes removed) that also appears
  among the WHERE columns of ``o_SQL``, the value at the same position in
  ``t_SQL`` is replaced by the corresponding value from ``o_SQL``, unless
  the two already match case-insensitively.

  Args:
    o_SQL: SQL generated directly from the original question.
    t_SQL: SQL generated from the translated question.

  Returns:
    ``t_SQL`` with the aligned values substituted.
  """
  o_values = get_values_for_query_filter(o_SQL)
  t_values = get_values_for_query_filter(t_SQL)
  o_columns = get_cols_name_for_where(o_SQL)
  t_columns = get_cols_name_for_where(t_SQL.replace('"', ''))
  for i, t_col in enumerate(t_columns):
    try:
      idx = o_columns.index(t_col)
      t_value = t_values[i]
      o_value = o_values[idx]
    except (ValueError, IndexError):
      # Column absent from the original query, or fewer values than
      # columns were extracted: nothing to align for this column.
      continue
    # Skip empty values: str.replace('', x) would insert x everywhere.
    # NOTE(review): replace() substitutes every occurrence of t_value in the
    # query, not only inside the WHERE clause.
    if t_value and o_value.upper() != t_value.upper():
      t_SQL = t_SQL.replace(t_value, o_value)
  return t_SQL

# Main function: convert a natural-language (Italian) question into a SQL query (input: string; output: string)
def convert_question_to_SQL(question):
  """Convert an Italian natural-language prompt into a SQL query.

  SQL is generated directly from the prompt. If the translation-assisted
  path (``get_SQL_with_translation`` + ``align_translation``) produces a
  result, that result is returned instead.

  Args:
    question: prompt such as
      'translate to SQL: <question> table ID: <col1>, <col2>, ...'.

  Returns:
    The SQL string.
  """
  original_SQL = get_SQL(question)
  try:
    translated_SQL = get_SQL_with_translation(question)
    if translated_SQL is None:
      return original_SQL
    return align_translation(original_SQL, translated_SQL)
  except Exception:
    # The translation path is only a quality enhancement; if it fails (model,
    # translator or helper error), still answer with the direct result.
    import logging
    logging.getLogger(__name__).exception(
        'Translation-assisted SQL generation failed; returning direct SQL')
    return original_SQL

device = 'cpu'

# NOTE(review): 'model/' is presumably a local copy of the Hub checkpoint
# 'gbarone77/t5-large-finetuned-wikisql-with-cols' used previously — confirm.
tokenizer = AutoTokenizer.from_pretrained('model/', local_files_only=True)
# Put the model on the same device that translate_to_sql() sends its inputs
# to, so changing `device` does not cause a device-mismatch error.
model = T5ForConditionalGeneration.from_pretrained('model/', local_files_only=True).to(device)

# Italian -> English translator used by get_SQL_with_translation().
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-it-en")

# Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
gradio_interface = gradio.Interface(
  fn=convert_question_to_SQL,
  inputs=gradio.Textbox(lines=2, placeholder=""),
  outputs="text")

gradio_interface.launch()