import time

import gradio as gr
import torch
from transformers import AutoTokenizer, GemmaForCausalLM

# Initialize the tokenizer and model once at startup.
model_id = "NexaAIDev/Octopus-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = GemmaForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)


def inference(input_text):
    """Greedy-decode a response and report wall-clock latency."""
    start_time = time.time()
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    input_length = inputs["input_ids"].shape[1]
    # max_length counts the prompt tokens plus the generated tokens;
    # passing **inputs also forwards the attention mask to generate().
    outputs = model.generate(**inputs, max_length=1024, do_sample=False)
    # Slice off the prompt so only the newly generated tokens are decoded.
    generated_sequence = outputs[:, input_length:].tolist()
    res = tokenizer.decode(generated_sequence[0])
    end_time = time.time()
    return {"output": res, "latency": f"{end_time - start_time:.2f} seconds"}


def gradio_interface(input_text):
    # Wrap the raw query in the function-calling prompt template Octopus-v2 expects.
    nexa_query = f"Below is the query from the users, please call the correct function and generate the parameters to call the function.\n\nQuery: {input_text} \n\nResponse:"
    result = inference(nexa_query)
    return result["output"], result["latency"]


# gr.inputs and gr.outputs were removed in Gradio 4.x; use the top-level components instead.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs=[gr.Textbox(label="Output"), gr.Textbox(label="Latency")],
    title="Gemma Model Inference",
    description="This application uses the Gemma model for generating responses based on the input query.",
)

if __name__ == "__main__":
    iface.launch()
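
For a quick smoke test outside the web UI, the handler can also be called directly. A minimal sketch, assuming the model above loaded successfully; the query string is a hypothetical example of the kind of function-calling request Octopus-v2 is tuned for:

# Hypothetical smoke test: exercises the same code path as the Gradio UI.
output, latency = gradio_interface("Take a selfie with the front camera")
print("Output:", output)
print("Latency:", latency)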