gbarone77 committed on
Commit
40220ef
·
1 Parent(s): a394cab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -6
app.py CHANGED
@@ -1,23 +1,81 @@
1
  import gradio
2
  import torch
3
  import transformers
4
- from transformers import AutoTokenizer, T5ForConditionalGeneration
5
-
6
# Both the tokenizer and the seq2seq weights come from the same Hub checkpoint.
_CHECKPOINT = 'gbarone77/t5flan-small-finetuned-wikisql-with-cols'

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
8
 
 
9
def translate_to_sql(text):
    """Generate SQL text for a natural-language prompt.

    The prompt is tokenized with the module-level ``tokenizer``, fed to the
    module-level ``model`` for generation (at most 64 tokens), and the first
    generated sequence is decoded with special tokens stripped.

    Args:
        text: Prompt string, e.g. ``'translate to SQL: ... table ID: ...'``.

    Returns:
        The decoded string produced by the model.
    """
    encoded = tokenizer(text, padding='longest', max_length=64, return_tensors='pt')
    generated = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=64,
    )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
17
 
18
  #Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
19
# Two-line text box whose placeholder shows the expected prompt format.
_question_box = gradio.Textbox(
    lines=2,
    placeholder="translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees",
)

# Plain text in, plain text out; the callback is translate_to_sql.
gradio_interface = gradio.Interface(
    fn=translate_to_sql,
    inputs=_question_box,
    outputs="text",
)
23
 
 
1
  import gradio
2
  import torch
3
  import transformers
4
+ from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
5
+ from SQL_helper import * #It's bad, just for protype
 
 
6
 
7
+ #Translate question to SQL query (input: string; output:string)
8
def translate_to_sql(text):
    """Generate SQL text for a natural-language prompt.

    The prompt is tokenized with the module-level ``tokenizer``, fed to the
    module-level ``model`` for generation (at most 64 tokens), and the first
    generated sequence is decoded with special tokens stripped.

    Args:
        text: Prompt string, e.g. ``'translate to SQL: ... table ID: ...'``.

    Returns:
        The decoded string produced by the model.
    """
    # Send the encoded batch to wherever the model's weights actually live.
    # The previous code used a global named `device`, which is not assigned in
    # the visible part of this file (it could only arrive via the wildcard
    # import) and which the model itself is never moved to.
    inputs = tokenizer(
        text, padding='longest', max_length=64, return_tensors='pt'
    ).to(model.device)
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=64,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
14
 
15
+ #Get SQL from question with corrected syntax (input: string; output:string)
16
def get_SQL(question):
    """Return a whitespace-normalised SQL string for *question*.

    The raw query comes from ``translate_to_sql``. It is then passed through
    ``correct_query`` and ``correct_mispelling`` (presumably provided by the
    ``SQL_helper`` wildcard import -- verify there). If any of those
    post-processing steps raises, the raw query is returned instead.

    Args:
        question: Prompt string starting with ``'translate to SQL: '``.

    Returns:
        The query with every run of whitespace collapsed to a single space.
    """
    sql = translate_to_sql(question)
    try:
        text, cols = get_text_and_cols(question)
        output = correct_query(sql, cols)
        # The helper gets the bare question, without the task prefix.
        text = text.replace('translate to SQL: ', '')
        output = correct_mispelling(text, output)
        return ' '.join(output.split())
    except Exception:
        # Best-effort clean-up: fall back to the uncorrected query.
        # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        return ' '.join(sql.split())
28
+
29
+ #Get SQL from question using IT to EN translation model for better quality (input: string, output:string or None)
30
def get_SQL_with_translation(question):
    """Build the SQL query from a machine-translated copy of *question*.

    The question text and the comma-joined column names are each passed
    through the module-level ``translator``. A new prompt is assembled from
    the translated pieces and handed to ``get_SQL``. Finally every
    upper-cased translated column name in the result is replaced by the
    corresponding original column name, stripped and wrapped in double quotes
    with a space on each side.

    Args:
        question: Prompt string in the ``'translate to SQL: ... table ID: ...'``
            format.

    Returns:
        The rewritten SQL string, or ``None`` when translating the column list
        changed the number of comma-separated items.
    """
    text, cols = get_text_and_cols(question)

    bare_question = text.replace('translate to SQL: ', '')
    t_text = translator(bare_question)[0]['translation_text']

    joined_cols = ",".join(cols)
    t_cols = translator(joined_cols)[0]['translation_text'].split(',')

    # Without a one-to-one column correspondence the names cannot be mapped back.
    if len(t_cols) != len(cols):
        return None

    translated_prompt = 'translate to SQL: ' + t_text + ' table ID: ' + ",".join(t_cols)
    t_sql = get_SQL(translated_prompt)

    for translated_col, original_col in zip(t_cols, cols):
        quoted_original = ' "' + original_col.strip() + '" '
        t_sql = t_sql.replace(translated_col.upper(), quoted_original)
    return t_sql
44
+
45
+ #Align translated values with the original values from the WHERE condition (input: string, string; output:string)
46
def align_translation(o_SQL, t_SQL):
    """Copy WHERE-clause values from *o_SQL* into *t_SQL* where they differ.

    For each WHERE column found in ``t_SQL`` (looked up with double quotes
    removed), the column with the same name is located in ``o_SQL``. If the
    two filter values differ when compared case-insensitively, every
    occurrence of the ``t_SQL`` value is replaced by the ``o_SQL`` value.

    Args:
        o_SQL: Query generated from the original question.
        t_SQL: Query generated from the translated question.

    Returns:
        ``t_SQL`` with the differing filter values substituted.
    """
    o_values = get_values_for_query_filter(o_SQL)
    t_values = get_values_for_query_filter(t_SQL)
    o_columns = get_cols_name_for_where(o_SQL)
    t_columns = get_cols_name_for_where(t_SQL.replace('"', ''))

    for i, t_col in enumerate(t_columns):
        try:
            idx = o_columns.index(t_col)
            t_value = t_values[i]
            o_value = o_values[idx]
            if o_value.upper() != t_value.upper():
                t_SQL = t_SQL.replace(t_value, o_value)
        except (ValueError, IndexError, AttributeError, TypeError):
            # Best-effort alignment: skip a column that is missing from the
            # original query (ValueError), has no matching value (IndexError),
            # or whose value is not a string (AttributeError/TypeError).
            continue
    return t_SQL
61
+
62
+ #Main function to convert a natural-language question (Italian) into a SQL query (input: string; output: string)
63
def convert_question_to_SQL(question):
    """Turn *question* into SQL, preferring the translation-assisted result.

    The query is produced twice: directly with ``get_SQL`` and via
    ``get_SQL_with_translation``. When the translated path yields a result,
    it is aligned against the direct query with ``align_translation`` and
    returned. Otherwise the direct query is returned.

    Args:
        question: Prompt string in the ``'translate to SQL: ... table ID: ...'``
            format.

    Returns:
        The final SQL string.
    """
    direct_sql = get_SQL(question)
    via_translation_sql = get_SQL_with_translation(question)

    if via_translation_sql is None:
        return direct_sql
    return align_translation(direct_sql, via_translation_sql)
70
+
71
# Both the tokenizer and the seq2seq weights come from the same Hub checkpoint.
_SQL_CHECKPOINT = 'gbarone77/t5-large-finetuned-wikisql-with-cols'

tokenizer = AutoTokenizer.from_pretrained(_SQL_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_SQL_CHECKPOINT)

# Italian -> English translation pipeline used by get_SQL_with_translation.
translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-it-en",
)
75
 
76
  #Input example: 'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'
77
# Two-line text box whose placeholder shows the expected prompt format.
_question_box = gradio.Textbox(
    lines=2,
    placeholder="translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees",
)

# Plain text in, plain text out; the callback is convert_question_to_SQL.
gradio_interface = gradio.Interface(
    fn=convert_question_to_SQL,
    inputs=_question_box,
    outputs="text",
)
81