Spaces:
Runtime error
Runtime error
File size: 1,012 Bytes
6fff48b 1704b25 b4f20c9 b845e1d 6fff48b b845e1d 6fff48b 1704b25 1bed5cf 6fff48b 1704b25 a1eb554 b171847 80a1529 6fff48b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import gradio
import torch
import transformers
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Single source of truth for the checkpoint both objects are loaded from.
_CHECKPOINT = 'gbarone77/t5-small-finetuned-wikisql-with-cols'
# Tokenizer first, then the T5 seq2seq model, both from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
def translate_to_sql(text):
    """Translate a natural-language prompt into an SQL query string.

    Example prompt:
        'translate to SQL: When was Olympics held in Rome? table ID: ID,
        city, year, cost, attendees'

    Args:
        text: Prompt string (e.g. from the Gradio textbox).

    Returns:
        The decoded model output with special tokens removed, or ``""`` when
        ``text`` is empty, whitespace-only, or ``None``.
    """
    # Nothing to translate: skip the model call instead of generating
    # output from an effectively empty input.
    if not text or not text.strip():
        return ""
    # ``max_length`` only limits the encoded input when truncation is
    # enabled; without ``truncation=True`` it is silently ignored.
    inputs = tokenizer(
        text,
        padding="longest",
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    # Caps the generated sequence length at 64 tokens.
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=64,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Example of the prompt format the model expects; shown as the textbox placeholder.
_EXAMPLE_PROMPT = "translate to SQL: When was Olympics held in Rome? table ID: ID, city, year, cost, attendees"

# Two-line text input feeding translate_to_sql; the result is shown as plain text.
_prompt_box = gradio.Textbox(lines=2, placeholder=_EXAMPLE_PROMPT)
gradio_interface = gradio.Interface(
    fn=translate_to_sql,
    inputs=_prompt_box,
    outputs="text",
)
gradio_interface.launch()
|