Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from lavis.models import load_model_and_preprocess | |
| import torch | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
# Always a torch.device so downstream code sees a single type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Name and variant of the LAVIS model to load.
model_name = "blip2_t5_instruct"
model_type = "flant5xl"

# Pass the module-level constants directly: this script defines no `args`
# namespace, so `args.model_name` / `args.model_type` would raise NameError
# at import time.
# The third return value is discarded; only the model and the visual
# processors are used by this script.
model, vis_processors, _ = load_model_and_preprocess(
    name=model_name,
    model_type=model_type,
    is_eval=True,
    device=device,
)
def infer(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
    """Generate text for ``image`` conditioned on ``prompt``.

    Args:
        image: Input image handed to the ``"eval"`` visual processor.
        prompt: Text prompt placed in the sample dict under ``"prompt"``.
        min_len: Forwarded to ``model.generate`` as ``min_length``.
        max_len: Forwarded to ``model.generate`` as ``max_length``.
        beam_size: Forwarded to ``model.generate`` as ``num_beams``.
        len_penalty: Forwarded to ``model.generate`` as ``length_penalty``.
        repetition_penalty: Forwarded to ``model.generate`` under the same name.
        top_p: Forwarded to ``model.generate`` under the same name.
        decoding_method: Nucleus sampling is enabled only when this equals
            ``"Nucleus sampling"``; any other value leaves it disabled.

    Returns:
        The first element of the sequence returned by ``model.generate``.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Fail with a readable UI message instead of an opaque error from the
    # visual processor when the image input is empty.
    if image is None:
        raise gr.Error("Please upload an image before running the model.")

    use_nucleus_sampling = decoding_method == "Nucleus sampling"
    print(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, use_nucleus_sampling)

    # Preprocess, add a leading batch dimension of 1, and move to the model device.
    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    samples = {
        "image": image,
        "prompt": prompt,
    }

    # UI numeric widgets may deliver floats; the length and beam arguments are
    # counts, so normalise them to int (and the rest to float) before use.
    output = model.generate(
        samples,
        length_penalty=float(len_penalty),
        repetition_penalty=float(repetition_penalty),
        num_beams=int(beam_size),
        max_length=int(max_len),
        min_length=int(min_len),
        top_p=float(top_p),
        use_nucleus_sampling=use_nucleus_sampling,
    )
    return output[0]
# Look and feel of the demo: Monochrome base theme with custom hues, small
# corner radii, and "Open Sans" first in the font list, followed by generic
# sans-serif fallbacks.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)

# Extra stylesheet passed to gr.Blocks: hides any element carrying the
# `generating` class.
# NOTE(review): presumably meant to suppress Gradio's in-progress styling on
# the output component -- confirm against the Gradio version in use.
css = ".generating {visibility: hidden}"
# Build the UI: a wide column with the image, prompt, output and run button,
# next to a narrow column holding the generation controls.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    # Column `scale` is relative to adjacent columns in the same row, so the
    # two columns must share a Row for the 3:1 split to take effect.
    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(type="pil")
            prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2)
            output = gr.Textbox(label="Output")
            submit = gr.Button("Run", variant="primary")

        with gr.Column(scale=1):
            min_len = gr.Slider(
                minimum=1,
                maximum=50,
                value=1,
                step=1,
                interactive=True,
                label="Min Length",
            )
            max_len = gr.Slider(
                minimum=10,
                maximum=500,
                value=250,
                step=5,
                interactive=True,
                label="Max Length",
            )
            # The selected label is compared against "Nucleus sampling" in
            # `infer` to choose the decoding strategy.
            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )
            top_p = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="Top p",
            )
            beam_size = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                interactive=True,
                label="Beam Size",
            )
            len_penalty = gr.Slider(
                minimum=-1,
                maximum=2,
                value=1,
                step=0.2,
                interactive=True,
                label="Length Penalty",
            )
            # NOTE(review): the range allows values <= 0; confirm the model's
            # generate() accepts them.
            repetition_penalty = gr.Slider(
                minimum=-1,
                maximum=3,
                value=1,
                step=0.2,
                interactive=True,
                label="Repetition Penalty",
            )

    # Event listeners are registered inside the Blocks context.
    # The input order must match the positional parameters of `infer`.
    submit.click(
        infer,
        inputs=[
            image_input,
            prompt_textbox,
            min_len,
            max_len,
            beam_size,
            len_penalty,
            repetition_penalty,
            top_p,
            sampling,
        ],
        outputs=[output],
    )

# Enable request queuing, then start the server in debug mode.
demo.queue(concurrency_count=16).launch(debug=True)