# txt2sql_poc_v1 / app.py
# Last commit: 1704b25 by gbarone77 ("Update app.py"); file size 806 bytes.
import gradio
import transformers
import torch
# Checkpoint identifier handed to `from_pretrained` for both the tokenizer and
# the model. Presumably a fine-tuned T5-small checkpoint (local directory or
# Hub id) — TODO confirm where it is hosted.
MODEL_CHECKPOINT = 't5-small-finetuned-wikisql-with-cols'

# Load once at import time so every request reuses the same tokenizer/weights.
# The classes are reached through the `transformers` module because that
# module (not its individual names) is what this file imports.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = transformers.T5ForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
def translate_to_sql(text):
    """Translate a natural-language question into a SQL query string.

    Args:
        text: Prompt for the seq2seq model. Presumably of the form
            'translate to SQL: <question> table ID: <col>, <col>, ...'
            (inferred from the input-example comment in this file — confirm).

    Returns:
        The first generated sequence decoded with special tokens removed,
        generated with ``max_length=64``. Returns an empty string when
        ``text`` is empty or whitespace-only.
    """
    # An empty textbox carries no question; skip generation instead of letting
    # the model invent a query from an empty prompt.
    if not text or not text.strip():
        return ""

    # A batch of one needs no padding. `max_length` is not passed here because,
    # without truncation enabled, it would not limit the encoded length anyway.
    inputs = tokenizer(text, return_tensors='pt')
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=64,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Input example: 'translate to SQL: When were the Olympic Games held in Rome? table ID: ID, city, year, cost, attendees'
# Expose translate_to_sql as a one-textbox web app that shows its result as
# plain text, then start serving it (launch() runs at module import).
gradio_interface = gradio.Interface(fn=translate_to_sql, inputs="text", outputs="text")
gradio_interface.launch()