| """Developed by Ruslan Magana Vsevolodovna""" | |
| from collections.abc import Iterator | |
| from datetime import datetime | |
| from pathlib import Path | |
| from threading import Thread | |
| import io | |
| import base64 | |
| import random | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
| from themes.research_monochrome import theme | |
| # ============================================================================= | |
| # Constants & Prompts | |
| # ============================================================================= | |
| today_date = datetime.today().strftime("%B %-d, %Y") | |
| SYS_PROMPT = """ | |
| Respond in the following format: | |
| <reasoning> | |
| ... | |
| </reasoning> | |
| <answer> | |
| ... | |
| </answer> | |
| """ | |
| TITLE = "IBM Granite 3.1 8b Reasoning & Vision Preview" | |
| DESCRIPTION = """ | |
| <p>Granite 3.1 8b Reasoning is an open‐source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision‐language capabilities. Start with one of the sample prompts | |
| or enter your own. Keep in mind that AI can occasionally make mistakes. | |
| <span class="gr_docs_link"> | |
| <a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a> | |
| </span> | |
| </p> | |
| """ | |
MAX_INPUT_TOKEN_LENGTH = 128_000
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.5
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

# Vision defaults (advanced settings)
VISION_TEMPERATURE = 0.2
VISION_TOP_P = 0.95
VISION_TOP_K = 50
VISION_MAX_TOKENS = 128

if not torch.cuda.is_available():
    print("This demo may not work on CPU.")
# =============================================================================
# Text Model Loading
# =============================================================================
granite_text_model = "ruslanmv/granite-3.1-8b-Reasoning"
text_model = AutoModelForCausalLM.from_pretrained(
    granite_text_model,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(granite_text_model)
tokenizer.use_default_system_prompt = False

# =============================================================================
# Vision Model Loading
# =============================================================================
vision_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
vision_processor = LlavaNextProcessor.from_pretrained(vision_model_path, use_fast=True)
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
    vision_model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,  # Ensure the custom code is used so that weight shapes match.
)
# =============================================================================
# Unified Display Function
# =============================================================================
def get_text_from_content(content):
    """Helper to extract text from a list of content items."""
    texts = []
    for item in content:
        if isinstance(item, dict):
            if item.get("type") == "text":
                texts.append(item.get("text", ""))
            elif item.get("type") == "image":
                image = item.get("image")
                if image is not None:
                    buffered = io.BytesIO()
                    image.save(buffered, format="JPEG")
                    img_str = base64.b64encode(buffered.getvalue()).decode()
                    texts.append(f'<img src="data:image/jpeg;base64,{img_str}" style="max-width: 200px; max-height: 200px;">')
                else:
                    texts.append("<image>")
        else:
            texts.append(str(item))
    return " ".join(texts)
def display_unified_conversation(conversation):
    """
    Combine both text-only and vision messages.

    Each conversation entry is expected to be a dict with keys:
      - role: "user" or "assistant"
      - content: either a string (for text) or a list of content items (for vision)
    """
    chat_history = []
    i = 0
    while i < len(conversation):
        if conversation[i]["role"] == "user":
            user_content = conversation[i]["content"]
            if isinstance(user_content, list):
                user_msg = get_text_from_content(user_content)
            else:
                user_msg = user_content
            assistant_msg = ""
            if i + 1 < len(conversation) and conversation[i + 1]["role"] == "assistant":
                asst_content = conversation[i + 1]["content"]
                if isinstance(asst_content, list):
                    assistant_msg = get_text_from_content(asst_content)
                else:
                    assistant_msg = asst_content
                i += 2
            else:
                i += 1
            chat_history.append((user_msg, assistant_msg))
        else:
            i += 1
    return chat_history
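# Illustrative shape (not executed): display_unified_conversation expects entries such as
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]}]
# and returns gr.Chatbot-style (user, assistant) tuples, e.g. [("Hi", "Hello!")].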
# =============================================================================
# Text Generation Function (for text-only chat)
# =============================================================================
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """
    Generate function for text chat. It streams tokens and stops once the generated answer
    contains the closing </answer> tag.
    """
    conversation = []
    conversation.append({"role": "system", "content": SYS_PROMPT})
    conversation.extend(chat_history)
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
    )
    input_ids = input_ids.to(text_model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    t = Thread(target=text_model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    reasoning_started = False
    answer_started = False
    collected_reasoning = ""
    collected_answer = ""
    for text in streamer:
        outputs.append(text)
        current_output = "".join(outputs)
        if "<reasoning>" in current_output and not reasoning_started:
            # Entering the reasoning section: keep only the text after the opening tag.
            reasoning_started = True
            reasoning_start_index = current_output.find("<reasoning>") + len("<reasoning>")
            collected_reasoning = current_output[reasoning_start_index:]
            yield "[Reasoning]: "
            outputs = [collected_reasoning]
        elif reasoning_started and "<answer>" in current_output and not answer_started:
            # Entering the answer section. The <reasoning> tag was already stripped above,
            # so the reasoning text runs from the start of the buffer to the <answer> tag.
            answer_started = True
            reasoning_end_index = current_output.find("<answer>")
            collected_reasoning = current_output[:reasoning_end_index]
            answer_start_index = current_output.find("<answer>") + len("<answer>")
            collected_answer = current_output[answer_start_index:]
            yield "\n[Answer]: "
            outputs = [collected_answer]
            yield collected_answer
        elif reasoning_started and not answer_started:
            collected_reasoning += text
            yield text
        elif answer_started:
            collected_answer += text
            yield text
            if "</answer>" in collected_answer:
                break
        else:
            yield text
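# Minimal usage sketch (assumption: this module is run with the models loaded above;
# the prompt below is hypothetical). generate() can be exercised outside the UI:
#
#     for piece in generate("What is 2 + 2?", chat_history=[], max_new_tokens=256):
#         print(piece, end="", flush=True)
#
# The "[Reasoning]: " and "[Answer]: " markers are injected into the stream as the
# corresponding tags are reached.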
# =============================================================================
# Vision Chat Inference Function (for image+text chat)
# =============================================================================
def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, top_p=VISION_TOP_P,
                   top_k=VISION_TOP_K, max_tokens=VISION_MAX_TOKENS):
    if conversation is None:
        conversation = []
    user_content = []
    if image is not None:
        user_content.append({"type": "image", "image": image})
    if text and text.strip():
        user_content.append({"type": "text", "text": text.strip()})
    if not user_content:
        return display_unified_conversation(conversation), conversation
    conversation.append({"role": "user", "content": user_content})
    inputs = vision_processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(vision_model.device)  # Move inputs to wherever the vision model was placed.
    torch.manual_seed(random.randint(0, 10000))
    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": True,
    }
    output = vision_model.generate(**inputs, **generation_kwargs)
    assistant_response = vision_processor.decode(output[0], skip_special_tokens=True)
    if "<|assistant|>" in assistant_response:
        # Keep only the text after the final assistant marker.
        assistant_response_parts = assistant_response.split("<|assistant|>")
        assistant_response_text = assistant_response_parts[-1].strip()
    else:
        assistant_response_text = assistant_response.strip()
    conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_response_text}]})
    return display_unified_conversation(conversation), conversation
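# Minimal usage sketch (assumption: a local image file exists; the path is hypothetical):
#
#     from PIL import Image
#     img = Image.open("sample.jpg")
#     history, conv = chat_inference(img, "Describe this image.", conversation=None)
#     print(conv[-1]["content"][0]["text"])
#
# chat_inference() returns both the display-ready history and the raw conversation state,
# so callers can keep threading the same conversation list into follow-up turns.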
# =============================================================================
# Unified Send-Message Function
#
# We maintain two histories:
#   - unified_state: the complete conversation (for display)
#   - internal_text_state: only the text turns (for text generation)
# Vision turns update only unified_state (plus the separate vision_state).
# =============================================================================
def send_message(image, text,
                 text_temperature, text_repetition_penalty, text_top_p, text_top_k, text_max_new_tokens,
                 vision_temperature, vision_top_p, vision_top_k, vision_max_tokens,
                 unified_state, vision_state, internal_text_state):
    # Initialize states if empty.
    if unified_state is None:
        unified_state = []
    if internal_text_state is None:
        internal_text_state = []

    if image is not None:
        # Use vision inference.
        user_msg = []
        user_msg.append({"type": "image", "image": image})
        if text and text.strip():
            user_msg.append({"type": "text", "text": text.strip()})
        unified_state.append({"role": "user", "content": user_msg})
        chat_history, updated_vision_conv = chat_inference(image, text, vision_state,
                                                           temperature=vision_temperature,
                                                           top_p=vision_top_p,
                                                           top_k=vision_top_k,
                                                           max_tokens=vision_max_tokens)
        vision_state = updated_vision_conv
        if updated_vision_conv and updated_vision_conv[-1]["role"] == "assistant":
            unified_state.append(updated_vision_conv[-1])
        yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
    else:
        # Text-only mode: update both unified and internal text states.
        unified_state.append({"role": "user", "content": text})
        internal_text_state.append({"role": "user", "content": text})
        unified_state.append({"role": "assistant", "content": ""})
        internal_text_state.append({"role": "assistant", "content": ""})
        yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
        # Exclude the just-added user turn and the empty assistant placeholder;
        # generate() appends the current message itself.
        base_conv = internal_text_state[:-2]
        assistant_text = ""
        for chunk in generate(
            text, base_conv,
            temperature=text_temperature,
            repetition_penalty=text_repetition_penalty,
            top_p=text_top_p,
            top_k=text_top_k,
            max_new_tokens=text_max_new_tokens,
        ):
            assistant_text += chunk
            unified_state[-1]["content"] = assistant_text
            internal_text_state[-1]["content"] = assistant_text
            yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
        yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
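# Note: send_message() is a generator, so Gradio streams each yielded
# (chat_history, unified_state, vision_state, internal_text_state) tuple back to the UI.
# That is what produces the token-by-token chat updates for text-only turns.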
# =============================================================================
# Clear Chat Function
# =============================================================================
def clear_chat():
    # Reset the chatbot display, the three conversation states, the message box, and the image input.
    return [], [], [], [], "", None
# =============================================================================
# UI Layout with Gradio
# =============================================================================
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=[str(css_file_path)], head_paths=[str(head_file_path)],
               theme=theme, title=TITLE) as demo:
    gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    chatbot = gr.Chatbot(label="Chat History", height=500)
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image (optional)")
            text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
        with gr.Column(scale=1):
            with gr.Accordion("Text Advanced Settings", open=False):
                text_temperature_slider = gr.Slider(minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"])
                repetition_penalty_slider = gr.Slider(minimum=0, maximum=2.0, value=REPETITION_PENALTY, step=0.05, label="Repetition Penalty", elem_classes=["gr_accordion_element"])
                top_p_slider = gr.Slider(minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"])
                top_k_slider = gr.Slider(minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"])
                max_new_tokens_slider = gr.Slider(minimum=1, maximum=2000, value=MAX_NEW_TOKENS, step=1, label="Max New Tokens", elem_classes=["gr_accordion_element"])
            with gr.Accordion("Vision Advanced Settings", open=False):
                vision_temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=VISION_TEMPERATURE, step=0.01, label="Vision Temperature", elem_classes=["gr_accordion_element"])
                vision_top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=VISION_TOP_P, step=0.01, label="Vision Top p", elem_classes=["gr_accordion_element"])
                vision_top_k_slider = gr.Slider(minimum=0, maximum=100, value=VISION_TOP_K, step=1, label="Vision Top k", elem_classes=["gr_accordion_element"])
                vision_max_tokens_slider = gr.Slider(minimum=10, maximum=300, value=VISION_MAX_TOKENS, step=1, label="Vision Max Tokens", elem_classes=["gr_accordion_element"])
    send_button = gr.Button("Send Message")
    clear_button = gr.Button("Clear Chat")
    # Conversation state variables:
    #   - unified_state: complete conversation for display (text and vision)
    #   - vision_state: state for vision turns
    #   - internal_text_state: only text turns (for text generation)
    unified_state = gr.State([])
    vision_state = gr.State([])
    internal_text_state = gr.State([])

    send_button.click(
        send_message,
        inputs=[
            image_input, text_input,
            text_temperature_slider, repetition_penalty_slider, top_p_slider, top_k_slider, max_new_tokens_slider,
            vision_temperature_slider, vision_top_p_slider, vision_top_k_slider, vision_max_tokens_slider,
            unified_state, vision_state, internal_text_state,
        ],
        outputs=[chatbot, unified_state, vision_state, internal_text_state],
    )

    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, unified_state, vision_state, internal_text_state, text_input, image_input],
    )
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/cheetah1.jpg", "What is in this image?"],
            [None, "Compute Pi."],
            [None, "Explain quantum computing to a beginner."],
            [None, "What is OpenShift?"],
            [None, "Importance of low latency inference"],
            [None, "Boosting productivity habits"],
            [None, "Explain and document your code"],
            [None, "Generate Java Code"],
        ],
        inputs=[image_input, text_input],
        example_labels=[
            "Vision Example: What is in this image?",
            "Compute Pi.",
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
            "Explain and document your code",
            "Generate Java Code",
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.queue().launch(debug=True, share=False)