Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from example_strings import example1, example2, example3 | |
# Single-slot cache holding the most recently loaded (tokenizer, model) pair,
# keyed by model name. Keeping only one entry bounds memory to a single
# checkpoint even when users switch between model sizes in the UI.
_loaded_models = {}


def load_model(model_name: str):
    """Return ``(tokenizer, model)`` for ``NumbersStation/<model_name>``.

    The pair is cached so repeated requests for the same model reuse it
    instead of re-reading the checkpoint on every call. Only the most
    recently requested model is kept: the cache is emptied *before* a
    different model is loaded, so it never holds two checkpoints at once.

    Args:
        model_name: Repository name under the ``NumbersStation`` organisation,
            e.g. ``"nsql-350M"``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    cached = _loaded_models.get(model_name)
    if cached is None:
        # Drop the previously cached model first so the cache does not keep
        # the old checkpoint alive while the new one is being loaded.
        _loaded_models.clear()
        repo_id = f"NumbersStation/{model_name}"
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForCausalLM.from_pretrained(repo_id)
        cached = (tokenizer, model)
        _loaded_models[model_name] = cached
    return cached
def infer(input_text, model_choice, max_new_tokens: int = 500):
    """Generate a SQL completion for ``input_text`` with the chosen NSQL model.

    Args:
        input_text: Prompt text, expected to contain the table schemas
            followed by a natural-language question (see the bundled
            examples).
        model_choice: NSQL checkpoint name, e.g. ``"nsql-6B"``; forwarded to
            ``load_model``.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The decoded output of ``model.generate`` (prompt followed by the
        generated continuation), or ``""`` when the prompt is empty or
        whitespace-only.
    """
    # Nothing to complete: skip model loading and generation entirely.
    if not (input_text and input_text.strip()):
        return ""
    tokenizer, model = load_model(model_choice)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Budget the generated tokens separately from the prompt. With
    # ``max_length`` the prompt counted against the limit, so long table
    # schemas left little or no room for the SQL itself.
    generated_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Build the Gradio UI: a free-text prompt plus a dropdown choosing the NSQL
# checkpoint, both passed positionally to `infer`; its text result is shown.
iface = gr.Interface(
    title="Text to SQL with NSQL",
    description="""The NSQL model family was published by [Numbers Station](https://www.numbersstation.ai/) and is available in three flavors:
- [nsql-6B](https://huggingface.co/NumbersStation/nsql-6B)
- [nsql-2B](https://huggingface.co/NumbersStation/nsql-2B)
- [nsql-350M](https://huggingface.co/NumbersStation/nsql-350M)

This demo lets you choose which one you want to use and provides the three examples you can also find in their model cards.
In general you should first provide the table schemas of the tables you have questions about and then prompt it with a natural language question.
The model will then generate a SQL query that you can run against your database.
""",
    fn=infer,
    inputs=[
        "text",
        gr.Dropdown(["nsql-6B", "nsql-2B", "nsql-350M"], value="nsql-6B"),
    ],
    outputs="text",
    # One bundled example prompt per model size.
    examples=[
        [example1, "nsql-6B"],
        [example2, "nsql-2B"],
        [example3, "nsql-350M"],
    ],
)
iface.launch()