File size: 1,012 Bytes
6fff48b
1704b25
b4f20c9
b845e1d
6fff48b
b845e1d
 
6fff48b
1704b25
 
 
 
 
 
 
 
 
1bed5cf
6fff48b
1704b25
a1eb554
b171847
80a1529
6fff48b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import gradio
import torch
import transformers
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Hugging Face Hub checkpoint id. The name suggests a T5-small model fine-tuned
# on WikiSQL with column names included in the prompt — confirm on the model card.
# Kept in one constant so the tokenizer and the model always come from the same
# checkpoint.
MODEL_NAME = 'gbarone77/t5-small-finetuned-wikisql-with-cols'

# Loaded once at module import so every request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

def translate_to_sql(text, *, max_input_length=None, max_output_length=64):
    """Translate a natural-language prompt into a SQL query string.

    The prompt is passed to the model verbatim, so it should follow the format
    the checkpoint expects, e.g.
    'translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees'.

    Args:
        text: Prompt to translate.
        max_input_length: If not None, truncate the tokenized prompt to this
            many tokens. Defaults to None, meaning no truncation, so trailing
            column names are never silently dropped.
        max_output_length: ``max_length`` forwarded to ``model.generate``,
            bounding the length of the generated SQL sequence (default 64).

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed. Returns "" when ``text`` is an empty or whitespace-only string.
    """
    # An empty textbox arrives as ""; there is nothing to translate, so skip
    # the model call instead of generating from an empty prompt.
    if isinstance(text, str) and not text.strip():
        return ""

    tokenizer_kwargs = {'padding': 'longest', 'return_tensors': 'pt'}
    # `max_length` only limits the input when truncation is enabled, so the two
    # are passed together (or not at all) to make the intent explicit.
    if max_input_length is not None:
        tokenizer_kwargs.update(truncation=True, max_length=max_input_length)
    inputs = tokenizer(text, **tokenizer_kwargs)

    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=max_output_length,
    )
    # Only the first sequence of the batch is decoded and returned.
    return tokenizer.decode(output[0], skip_special_tokens=True)



# Example prompt: an instruction prefix, the question, then what appear to be
# the table's column names. Shown to users as the textbox placeholder.
EXAMPLE_INPUT = ('translate to SQL: When was Olympics held in Rome? '
                 'table ID: ID, city, year, cost, attendees')

# Single text-in / text-out UI wired to the translation function.
gradio_interface = gradio.Interface(
    fn=translate_to_sql,
    inputs=gradio.Textbox(lines=2, placeholder=EXAMPLE_INPUT),
    outputs="text",
)

# Start the web server only when this file is run as a script, so importing the
# module (e.g. from tests or another app) has no server side effect.
if __name__ == "__main__":
    gradio_interface.launch()