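# Gradio demo app for VDebugger: CodeLlama first writes a visual program for the
# question, the program is executed with step-by-step tracing, and the VDebugger
# critic and refiner models then localize and fix errors in that program.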
import inspect
import json
import os
import random
from typing import Literal, cast

import gradio as gr
import torch
from PIL import Image
from gradio.data_classes import InterfaceTypes
from gradio.flagging import CSVLogger
from torchvision import transforms
from transformers import AutoTokenizer, LlamaForCausalLM

from trace_exec import run_program_with_trace, CompileTimeError
from vision_processes import load_models
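
# Pre-load the vision models used when executing the generated visual programs.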
print("-" * 10, "Loading models...")
load_models()

with open('joint.prompt') as f:
    prompt_template = f.read().strip()

INPUT_TYPE = 'image'
OUTPUT_TYPE = 'str'
SIGNATURE = f'def execute_command({INPUT_TYPE}) -> {OUTPUT_TYPE}:'
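

# Load an LLM checkpoint on demand, greedily decode a completion for `input_text`,
# then free the model so only one large model occupies GPU memory at a time.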
def generate(model, input_text):
    torch.cuda.empty_cache()
    print("-" * 10, "Before loading LLM:")
    print(torch.cuda.memory_summary())
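
    # CODELLAMA_DTYPE selects the precision/quantization used for LLM inference.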
    dtype = os.environ.get("CODELLAMA_DTYPE")
    assert dtype in ['bfloat16', '8bit', '4bit']
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = LlamaForCausalLM.from_pretrained(
        model,
        device_map="auto",
        load_in_8bit=dtype == "8bit",
        load_in_4bit=dtype == "4bit",
        torch_dtype=torch.bfloat16 if dtype == "bfloat16" else None,
    )
    print("-" * 10, "LLM loaded:")
    print(model)
    print(torch.cuda.memory_summary())
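
    # Greedy decoding; generation stops at the first blank line ("\n\n").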
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    generated_ids = model.generate(
        input_ids.to('cuda'), max_new_tokens=256, stop_strings=["\n\n"], do_sample=False, tokenizer=tokenizer
    )
    generated_ids = generated_ids[0][input_ids.shape[1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
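
    # Release the LLM immediately so the next call (possibly a different checkpoint) fits in GPU memory.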
    del model
    torch.cuda.empty_cache()
    print("-" * 10, "After releasing LLM:")
    print(torch.cuda.memory_summary())
    return text
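

# Format an execution result and its trace into the text block fed to the critic and refiner.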
def to_custom_trace(result, error, traced):
    if traced is None:
        assert isinstance(error, CompileTimeError)
        traced = 'Compile Error'
    return "-> {}\n\n--- Trace\n\n{}".format(result, traced)
def answer_from_trace(x):
    assert x.startswith("->")
    return x[2:].splitlines()[0].strip()
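

# One critic-refiner round: the critic judges whether the traced program is correct;
# if it is judged wrong, the refiner rewrites it and the new program is executed again.
# Yields (program, execution, critic output, refiner output, new execution, final answer).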
def debug(image, question, code, traced_info):
    # critic
    prompt = f"# Given an image: {question}\n{code}\n\n{traced_info}\n\n# Program is"
    print("--- For debug: critic prompt is ---")
    print(prompt)
    print("---\n")
    critic_out = generate("VDebugger/VDebugger-critic-generalist-7B", prompt)
    incorrect = critic_out.strip().startswith('wrong')
    critic_out = "# Program is" + critic_out
    if not incorrect:
        yield code, traced_info, critic_out, "N/A", "N/A", answer_from_trace(traced_info)
        return
    else:
        yield code, traced_info, critic_out, "RUNNING IN PROGRESS...", "", ""

    # refiner
    critic_code = ('def execute_command' + critic_out.split('def execute_command')[1]).strip()
    if '# Program is' in code:
        critic_code = critic_code.split("# Program is")[0].strip()  # awkward fix: truncate at a stray "# Program is" marker
    prompt = f"# Given an image: {question}\n{critic_code}\n\n{traced_info}\n\n# Correction"
    print("--- For debug: refiner prompt is ---")
    print(prompt)
    print("---\n")
    refiner_out = generate("VDebugger/VDebugger-refiner-generalist-7B", prompt).strip()
    yield code, traced_info, critic_out, refiner_out, "RUNNING IN PROGRESS...", ""

    # execute (again)
    result, error, traced = run_program_with_trace(refiner_out, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info_2 = to_custom_trace(result, error, traced)
    yield code, traced_info, critic_out, refiner_out, traced_info_2, answer_from_trace(traced_info_2)
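

# Full pipeline behind the Submit button: validate the inputs, have CodeLlama write the
# visual program, execute it with tracing, then run one VDebugger debugging round.
# Intermediate yields stream partial results into the six output boxes.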
def predict(image, question):
    if image is None:
        gr.Warning("Please provide an image", duration=5)
        return
    image = transforms.Compose([transforms.ToTensor()])(image)
    question = question.strip()
    if question == "":
        gr.Warning("Please provide a question", duration=5)
        return

    # codellama
    prompt = prompt_template.replace("INSERT_QUERY_HERE", f"Given an image: {question}\n{SIGNATURE}")
    code = generate("codellama/CodeLlama-7b-Python-hf", prompt)
    code = (SIGNATURE + code).strip()
    yield code, "RUNNING IN PROGRESS...", "", "", "", ""

    # execute
    result, error, traced = run_program_with_trace(code, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info = to_custom_trace(result, error, traced)
    yield code, traced_info, "RUNNING IN PROGRESS...", "", "", ""

    for tup in debug(image, question, code, traced_info):
        yield tup
    return
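

# Run an additional debugging round on the latest refined program (o4) and its execution trace (o5).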
def re_debug(image, question, code, traced_info):
    if code is None or code == "" or traced_info is None or traced_info == "":
        gr.Warning("No prior debugging round", duration=5)
        return
    yield code, traced_info, "RUNNING IN PROGRESS...", "", "", ""
    for tup in debug(image, question, code, traced_info):
        yield tup
    return
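

# Markdown shown at the top of the demo page.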
DESCRIPTION = """# VDebugger

| [Paper](https://arxiv.org/abs/2406.13444) | [Project](https://shirley-wu.github.io/vdebugger/) | [Code](https://github.com/shirley-wu/vdebugger/) | [Models and Data](https://huggingface.co/VDebugger) |

**VDebugger** is a novel critic-refiner framework trained to localize and debug *visual programs* by tracking execution step by step. In this demo, we show the generated visual programs, the outputs of both the critic and the refiner, and the final result.

**Warning:** Due to the resource limitations of Hugging Face Spaces, this demo runs Llama inference in 4-bit quantization and uses smaller foundation VLMs, so reduced performance and accuracy may be observed. For full capacity, please use the original code."""
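

# gr.Interface subclass that builds the Blocks layout by hand while reusing the
# Interface helpers for the title/description, the clear button, and the examples gallery.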
class MyInterface(gr.Interface):
    def __init__(self):
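        # Skip gr.Interface.__init__ and call gr.Blocks.__init__ directly, then set the
        # Interface attributes we need by hand.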
        super(gr.Interface, self).__init__(
            title=None,
            theme=None,
            analytics_enabled=None,
            mode="tabbed_interface",
            css=None,
            js=None,
            head=None,
        )
        self.interface_type = InterfaceTypes.STANDARD
        self.description = DESCRIPTION
        self.cache_examples = None
        self.examples_per_page = 5
        self.example_labels = None
        self.batch = False
        self.live = False
        self.api_name = "predict"
        self.max_batch_size = 4
        self.concurrency_limit = 'default'
        self.show_progress = "full"
        self.allow_flagging = 'auto'
        self.flagging_options = [("Flag", "")]
        self.flagging_callback = CSVLogger()
        self.flagging_dir = 'flagged'

        # Load examples
        with open('examples/questions.json') as f:
            example_questions = json.load(f)
        self.examples = []
        for question in example_questions:
            self.examples.append([
                Image.open('examples/{}.jpg'.format(question['imageId'])), question['question'],
            ])
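
        # Fill the image/question inputs with a random example and clear all six output boxes.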
        def load_random_example():
            image, question = random.choice(self.examples)
            return image, question, "", "", "", "", "", ""

        # Render the Gradio UI
        with self:
            self.render_title_description()
            with gr.Row():
                image = gr.Image(label="Image", type="pil", width="30%", scale=1)
                question = gr.Textbox(label="Question", scale=2)
            with gr.Row():
                _clear_btn = gr.ClearButton(value="Clear", variant="secondary")
                _random_eg_btn = gr.Button("Random Example Input")
                _submit_btn = gr.Button("Submit", variant="primary")
                if inspect.isgeneratorfunction(predict) or inspect.isasyncgenfunction(predict):
                    _stop1_btn = gr.Button("Stop", variant="stop", visible=False)
                _redebug_btn = gr.Button("Debug for Another Round", variant="primary")
                if inspect.isgeneratorfunction(re_debug) or inspect.isasyncgenfunction(re_debug):
                    _stop2_btn = gr.Button("Stop", variant="stop", visible=False)
            with gr.Row():
                o1 = gr.Textbox(label="No debugging: program")
                o2 = gr.Textbox(label="No debugging: execution")
            with gr.Row():
                o3 = gr.Textbox(label="VDebugger: critic")
                o4 = gr.Textbox(label="VDebugger: refiner")
            with gr.Row():
                o5 = gr.Textbox(label="VDebugger: execution")
                o6 = gr.Textbox(label="VDebugger: final answer")

            question.submit(fn=predict, inputs=[image, question], outputs=[o1, o2, o3, o4, o5, o6])
            _random_eg_btn.click(fn=load_random_example, outputs=[image, question, o1, o2, o3, o4, o5, o6])
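
            # Once a run finishes or is cancelled, restore the trigger button and hide the Stop button.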
            async def cleanup():
                return [gr.Button(visible=True), gr.Button(visible=False)]

            # Set up the re-debug event
            triggers = [_redebug_btn.click]
            extra_output = [_redebug_btn, _stop2_btn]
            predict_event = gr.on(
                triggers,
                gr.utils.async_lambda(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    )
                ),
                inputs=None,
                outputs=[_redebug_btn, _stop2_btn],
                queue=False,
                show_api=False,
            ).then(
                re_debug,
                [image, question, o4, o5],
                [o1, o2, o3, o4, o5, o6],
                api_name=self.api_name,
                scroll_to_output=False,
                preprocess=not (self.api_mode),
                postprocess=not (self.api_mode),
                batch=self.batch,
                max_batch_size=self.max_batch_size,
                concurrency_limit=self.concurrency_limit,
                show_progress=cast(
                    Literal["full", "minimal", "hidden"], self.show_progress
                ),
            )
            redebug_event = predict_event.then(
                cleanup,
                inputs=None,
                outputs=extra_output,  # type: ignore
                queue=False,
                show_api=False,
            )
            _stop2_btn.click(
                cleanup,
                inputs=None,
                outputs=[_redebug_btn, _stop2_btn],
                cancels=predict_event,
                queue=False,
                show_api=False,
            )

            # Set up the submit event
            triggers = [_submit_btn.click, question.submit]
            extra_output = [_submit_btn, _stop1_btn]
            predict_event = gr.on(
                triggers,
                gr.utils.async_lambda(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    )
                ),
                inputs=None,
                outputs=[_submit_btn, _stop1_btn],
                queue=False,
                show_api=False,
            ).then(
                predict,
                [image, question],
                [o1, o2, o3, o4, o5, o6],
                api_name=self.api_name,
                scroll_to_output=False,
                preprocess=not (self.api_mode),
                postprocess=not (self.api_mode),
                batch=self.batch,
                max_batch_size=self.max_batch_size,
                concurrency_limit=self.concurrency_limit,
                show_progress=cast(
                    Literal["full", "minimal", "hidden"], self.show_progress
                ),
            )
            submit_event = predict_event.then(
                cleanup,
                inputs=None,
                outputs=extra_output,  # type: ignore
                queue=False,
                show_api=False,
            )
            _stop1_btn.click(
                cleanup,
                inputs=None,
                outputs=[_submit_btn, _stop1_btn],
                cancels=predict_event,
                queue=False,
                show_api=False,
            )

            # Finally, reuse the remaining gr.Interface machinery
            self.input_components = [image, question]
            self.output_components = [o1, o2, o3, o4, o5, o6]
            self.fn = predict
            self.attach_clear_events(_clear_btn, None)
            self.render_examples()
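

# Set the SHARE environment variable to any non-empty value to create a public share link.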
if __name__ == "__main__":
    MyInterface().launch(share=os.environ.get("SHARE", '') != "")