from transformers import MllamaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from PIL import Image
import requests
import torch
from threading import Thread
import gradio as gr
from gradio import FileData
import time
import spaces
import re
ckpt = "Xkev/Llama-3.2V-11B-cot"
model = MllamaForConditionalGeneration.from_pretrained(ckpt,
    torch_dtype=torch.bfloat16).to("cuda")
processor = AutoProcessor.from_pretrained(ckpt)
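# Note: the 11B checkpoint in bfloat16 needs roughly 22 GB of GPU memory for the
# weights alone. If a single device is too small, loading with
# from_pretrained(ckpt, torch_dtype=torch.bfloat16, device_map="auto") is a
# possible alternative (an assumption about the deployment, not part of this Space).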
@spaces.GPU  # request a GPU for each call on ZeroGPU hardware (uses the `spaces` import above)
def bot_streaming(message, history, max_new_tokens=250):
    txt = message["text"]
    messages = []
    images = []

    # Rebuild the conversation from Gradio's history. An image turn arrives as a
    # (filepath,) tuple followed by a text turn, so both are consumed together.
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):  # image turn: pair it with the text of the next turn
            messages.append({"role": "user", "content": [{"type": "text", "text": history[i + 1][0]}, {"type": "image"}]})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i + 1][1]}]})
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif isinstance(history[i - 1][0], tuple) and isinstance(msg[0], str):
            # this text turn was already appended together with the preceding image turn
            pass
        elif isinstance(history[i - 1][0], str) and isinstance(msg[0], str):  # text-only turn
            messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
    # Add the current message (text plus an optional single image).
    if len(message["files"]) == 1:
        if isinstance(message["files"][0], str):  # filepath string (examples)
            image = Image.open(message["files"][0]).convert("RGB")
        else:  # regular upload: a dict with a "path" key
            image = Image.open(message["files"][0]["path"]).convert("RGB")
        images.append(image)
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)

    if images == []:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")

    # Run generation in a background thread and stream tokens so the UI updates incrementally.
    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.6, top_p=0.9)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        # LLaVA-CoT emits stage tags such as <SUMMARY> and <REASONING>; escape them so
        # Gradio's Markdown renderer does not treat them as HTML and hide the text.
        buffer = re.sub(r"<(\w+)>", r"\<\1\>", buffer)
        buffer = re.sub(r"</(\w+)>", r"\</\1\>", buffer)
        yield buffer
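# Minimal local sanity check (a sketch, not part of the original Space): call the
# generator directly, without the Gradio UI. "example.jpg" is a placeholder path.
# if __name__ == "__main__":
#     last = ""
#     for last in bot_streaming({"text": "Describe this image.", "files": ["example.jpg"]}, [], max_new_tokens=512):
#         pass
#     print(last)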
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="LLaVA-CoT",
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=512,
            maximum=1024,
            value=512,
            step=1,
            label="Maximum number of new tokens to generate",
        )
    ],
    examples=[
        [{"text": "What is on the flower?", "files": ["./Example1.webp"]}, 512],
        [{"text": "How to make this pastry?", "files": ["./Example2.png"]}, 512],
    ],
    cache_examples=False,
    description="Upload an image, and start chatting about it. To learn more about LLaVA-CoT, visit [our GitHub page](https://github.com/PKU-YuanGroup/LLaVA-CoT).",
    stop_btn="Stop Generation",
    fill_height=True,
    multimodal=True,
)

demo.launch(debug=True)
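# Running outside Spaces (an assumption about the environment, not stated in the
# original file): `python app.py` starts the Gradio server; a CUDA GPU is required
# because the model is moved to "cuda" at import time.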