Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import requests | |
| import json | |
| import os | |
| from screenshot import ( | |
| before_prompt, | |
| prompt_to_generation, | |
| after_generation, | |
| js_save, | |
| js_load_script, | |
| ) | |
| from spaces_info import description, examples, initial_prompt_value | |
# Endpoint URL and bearer token for the hosted text-generation API, both read
# from the process environment; each is None when the variable is not set.
API_URL = os.environ.get("API_URL")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
def query(payload, timeout=120):
    """POST *payload* as JSON to ``API_URL`` and return the decoded reply.

    Args:
        payload: JSON-serialisable request body; it is sent with a
            ``Bearer {HF_API_TOKEN}`` Authorization header.
        timeout: Seconds ``requests`` waits for the endpoint before raising
            ``requests.exceptions.Timeout``. Pass ``None`` to wait forever
            (the previous behaviour).

    Returns:
        The parsed JSON body. If the endpoint answers with something that is
        not JSON (for example an HTML error page), a dict of the form
        ``{"error": "HTTP <status>: <body prefix>"}`` is returned instead, so
        callers that check for an ``"error"`` key can report it uniformly.
    """
    print(payload)
    response = requests.post(
        API_URL,
        json=payload,
        headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
        # Without a timeout a stalled endpoint would block the handler forever.
        timeout=timeout,
    )
    print(response)
    try:
        return response.json()
    except ValueError:
        # requests' JSON decode error subclasses ValueError in all versions.
        # Truncate so a large error page does not flood the UI log.
        return {"error": f"HTTP {response.status_code}: {response.text[:500]}"}
def _generation_parameters(max_length, sample_or_greedy, seed):
    """Build the ``parameters`` dict for one text-generation request."""
    do_sample = sample_or_greedy == "Sample"
    parameters = {
        # NOTE(review): depending on the Gradio version the Slider value may
        # arrive as a float; the token count is sent as an int to be safe.
        "max_new_tokens": int(max_length),
        "do_sample": do_sample,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if do_sample:
        # Nucleus sampling is only meaningful when sampling is enabled.
        parameters["top_p"] = 0.9
    return parameters


def _strip_prompt(generated_text, prompt):
    """Return what follows the first occurrence of *prompt* in *generated_text*.

    Falls back to the whole text when *prompt* is empty or does not appear.
    ``str.split`` would raise in those cases: ``ValueError`` for an empty
    separator, and ``IndexError`` on ``[1]`` when nothing is found.
    """
    if prompt:
        _, found, continuation = generated_text.partition(prompt)
        if found:
            return continuation
    return generated_text


def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Send one generation request and format the result for the UI.

    Args:
        input_sentence: Prompt text, sent as the request's ``inputs``.
        max_length: Number of new tokens to generate (``max_new_tokens``).
        sample_or_greedy: ``"Sample"`` enables sampling with ``top_p=0.9``;
            any other value requests greedy decoding.
        seed: Sent as the request's ``seed``. The UI in this file appears to
            pass the index of the selected "Sample N" radio button.

    Returns:
        A 3-tuple ``(html, full_text, log_html)``. On success these are the
        prompt and continuation wrapped in the ``screenshot`` HTML fragments,
        the raw ``generated_text``, and an empty string. When the API reply
        contains an ``"error"`` entry they are ``(None, None, <red error span>)``.
    """
    import html  # only needed to escape API error text

    payload = {
        "inputs": input_sentence,
        "parameters": _generation_parameters(max_length, sample_or_greedy, seed),
        # Ask the Inference API not to return a cached result for a repeated
        # identical request, so sampling produces fresh output.
        "options": {"use_cache": False},
    }
    data = query(payload)

    if "error" in data:
        # The error text comes from the remote service; escape it before
        # embedding it in HTML.
        message = html.escape(str(data["error"]))
        return (None, None, f"<span style='color:red'>ERROR: {message} </span>")

    full_text = data[0]["generated_text"]
    generation = _strip_prompt(full_text, input_sentence)
    # NOTE(review): prompt and model output are concatenated into HTML
    # unescaped; confirm this is intended by the screenshot template.
    rendered = (
        before_prompt
        + input_sentence
        + prompt_to_generation
        + generation
        + after_generation
    )
    return (rendered, full_text, "")
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Top row: the Space description text.
        with gr.Row():
            gr.Markdown(value=description)

        with gr.Row():
            # Left column: prompt box and generation controls.
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Input",
                    # Initial prompt is a single space; the original note says
                    # this is the value to use against a real API.
                    # NOTE(review): `initial_prompt_value` is imported from
                    # spaces_info but unused; confirm it was not meant to be here.
                    value=" ",
                )
                tokens_slider = gr.Slider(
                    1, 64, value=32, step=1, label="Tokens to generate"
                )
                mode_radio = gr.Radio(
                    ["Sample", "Greedy"],
                    label="Sample or greedy",
                    value="Sample",
                )
                # type="index" hands the handler 0..4 instead of the label.
                seed_radio = gr.Radio(
                    ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
                    value="Sample 1",
                    label="Sample other generations (only work in 'Sample' mode)",
                    type="index",
                )
                with gr.Row():
                    submit_button = gr.Button("Submit")
                    image_button = gr.Button("Generate Image")

            # Right column: status log, raw text output, rendered HTML.
            with gr.Column():
                log_markdown = gr.Markdown(label="Log information")
                output_box = gr.Textbox(label="Output")
                screenshot_html = gr.HTML(label="Image")
                # Run the screenshot helper JS once when the page loads.
                screenshot_html.set_event_trigger(
                    "load",
                    fn=None,
                    inputs=None,
                    outputs=None,
                    no_target=True,
                    js=js_load_script,
                )

        # The same four controls feed both the examples table and inference.
        generation_inputs = [prompt_box, tokens_slider, mode_radio, seed_radio]

        with gr.Row():
            gr.Examples(examples=examples, inputs=generation_inputs)

        submit_button.click(
            inference,
            inputs=generation_inputs,
            outputs=[screenshot_html, output_box, log_markdown],
        )
        # Client-side only: saves the rendered HTML via the js_save snippet.
        image_button.click(fn=None, inputs=None, outputs=None, _js=js_save)

    demo.launch()