# txt2sql_poc_v1 / app.py
# Author: gbarone77 — commit a1eb554 ("Update app.py"), 1.01 kB
import gradio
import torch
import transformers
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Hugging Face Hub checkpoint used for both the tokenizer and the model.
# (Judging by its name, a T5-small fine-tuned on WikiSQL with column names in
# the prompt — confirm against the model card.)
MODEL_ID = 'gbarone77/t5-small-finetuned-wikisql-with-cols'

# Loaded once at import time so every request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
def translate_to_sql(text, max_length=64):
    """Translate a natural-language question into SQL using the T5 model.

    The prompt is expected in the format shown in the UI placeholder, e.g.
    ``"translate to SQL: When was Olympics held in Rome? table ID: ID, city,
    year, cost, attendees"``.

    Args:
        text: Prompt string from the Gradio textbox. ``None`` or a blank
            string returns ``""`` without running the model.
        max_length: Maximum length, in tokens, of the generated sequence
            (forwarded to ``model.generate``). Defaults to 64.

    Returns:
        The decoded model output with special tokens removed.
    """
    # Nothing to translate: skip generation instead of running the model on
    # an empty prompt (tokenizer(None) would raise).
    if text is None or not text.strip():
        return ""

    # Tokenize without truncation on purpose: the table's column list comes
    # last in the prompt, so truncating a long prompt would drop the schema.
    # (Passing `max_length` here without `truncation=True` has no effect, and
    # a single sequence never needs padding.)
    inputs = tokenizer(text, return_tensors='pt')
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=max_length,
    )
    # generate() returns a batch; this function always sends exactly one prompt.
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Example of the prompt format the model expects; shown as the textbox
# placeholder so users can copy its structure.
EXAMPLE_INPUT = "translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees"

gradio_interface = gradio.Interface(
    fn=translate_to_sql,
    inputs=gradio.Textbox(lines=2, placeholder=EXAMPLE_INPUT),
    outputs="text",
)

# Start the web server only when executed as a script (`python app.py`), so
# importing this module (e.g. from tests) has no side effect beyond building
# the interface object.
if __name__ == "__main__":
    gradio_interface.launch()