import gradio as gr
import torch
import json
from model import create_flan_T5_model
from timeit import default_timer as timer
from typing import Tuple, Dict

device = "cuda" if torch.cuda.is_available() else "cpu"
### Load example texts ###
with open("data.json", "r", encoding="utf-8") as f:
    loaded_data = json.load(f)

questions_texts = loaded_data["questions"]
system_prompts = loaded_data["system_prompts"]
response_texts = loaded_data["responses"]
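# Expected shape of data.json (assumed from the keys accessed above; the three
# lists are parallel, i.e. index i of each list refers to the same example):
#
#     {
#         "questions": ["What is the capital of France?", ...],
#         "system_prompts": ["You are a helpful assistant.", ...],
#         "responses": ["The capital of France is Paris.", ...]
#     }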
### Model and transforms preparation ###
# Create model and tokenizer
model, tokenizer = create_flan_T5_model()
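# `create_flan_T5_model` lives in model.py, which is not shown here. A minimal
# sketch of what it plausibly does, assuming the standard Hugging Face
# transformers API:
#
#     from transformers import AutoTokenizer, T5ForConditionalGeneration
#
#     def create_flan_T5_model():
#         tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
#         model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
#         return model, tokenizer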
# Load saved weights
model.load_state_dict(
    torch.load(f="flan-t5-small.pth",
               map_location=torch.device("cpu"))  # load to CPU
)

# Move the model to the target device so it matches the tokenized inputs below
model = model.to(device)
### Predict function ###
def predict(selection: str) -> Tuple[Dict, str, float]:
    start_time = timer()
    model.eval()

    # Extract the question part from the selection,
    # assuming the format "Prompt: {prompt} | Question: {question}"
    question = selection.split("| Question: ", 1)[1]

    # Find the index of the question
    idx = questions_texts.index(question)

    # Use the index to get the matching system prompt and reference response
    system_prompt = system_prompts[idx]
    response = response_texts[idx]

    input_text = f"context: {system_prompt} question: {question}"
    model_inputs = tokenizer(input_text,
                             return_tensors="pt",
                             max_length=512,
                             padding="max_length",
                             truncation=True).to(device)

    with torch.inference_mode():
        predicted_token_ids = model.generate(input_ids=model_inputs["input_ids"],
                                             attention_mask=model_inputs["attention_mask"],
                                             max_length=128)

    result = tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

    end_time = timer()
    pred_time = round(end_time - start_time, 4)

    # Return a plain string for the Textbox output so the return value matches
    # both the type hint and the Gradio output components below
    return {"Predicted Answer": result}, response, pred_time
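# Illustrative call (output values are made up):
#     predict(dropdown_choices[0])
#     -> ({"Predicted Answer": "Paris"}, "The capital of France is Paris.", 1.2345)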
### Gradio app ###
# Create title, description and article
title = "Prompt Answering with Google's flan-t5-small"
description = "A [google/flan-t5-small based model](https://huggingface.co/google/flan-t5-small) fine-tuned to answer prompts and questions from the [Hugging Face 🤗 Open-Orca/OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca). [Source Code Found Here](https://colab.research.google.com/drive/1sIScjt_hyNegHC15Y76JVXEOUvdD_2dh?usp=sharing)"
article = "Built with [Gradio](https://github.com/gradio-app/gradio) and [PyTorch](https://pytorch.org/). [Source Code Found Here](https://colab.research.google.com/drive/1sIScjt_hyNegHC15Y76JVXEOUvdD_2dh?usp=sharing)"
dropdown_choices = [f"Prompt: {prompt} | Question: {question}"
                    for prompt, question in zip(system_prompts, questions_texts)]

# Create the Gradio demo
demo = gr.Interface(fn=predict,
                    inputs=gr.Dropdown(choices=dropdown_choices,
                                       label="Select a Question and Prompt"),
                    outputs=[
                        gr.JSON(label="Predicted Answer"),
                        gr.Textbox(label="Actual Answer"),
                        gr.Number(label="Prediction time (s)")
                    ],
                    title=title,
                    description=description,
                    article=article)

# Launch the demo
demo.launch()