from dotenv import load_dotenv
from replicate.client import Client
from transformers import AutoTokenizer
import gradio as gr
import json
import time
import re
import os

# CSS styling
css = """
.category-legend{display:none}
button{height: 60px}
"""

# Constants
MASK_TOKEN = "[MASK]"

# Initialize environment and client
load_dotenv()
replicate = Client(api_token=os.environ.get("REPLICATE_API_TOKEN"))
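# Note: load_dotenv() above expects the token in the environment or in a local
# .env file, e.g. a line of the form: REPLICATE_API_TOKEN=<your token>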

# Load tokenizer for formatting the chat template properly
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True
)

def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(",")
    for part in parts:
        if ":" not in part:
            continue
        pos_str, word = part.split(":", 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                constraints[pos] = word
        except ValueError:
            continue

    return constraints
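# Example (mirrors the placeholder used in the UI below):
#   parse_constraints("0:Once, 5:upon, 10:time")
#   -> {0: "Once", 5: "upon", 10: "time"}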

def format_chat_history(history):
    """Format chat history for the LLaDA model"""
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Skip if None (for the latest user message)
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
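# Example: format_chat_history([["Hi", "Hello!"], ["How are you?", None]])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "How are you?"}]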

def generate_response_with_visualization(
    messages,
    gen_length=64,
    steps=32,
    constraints=None,
    temperature=0.5,
    cfg_scale=0.0,
    block_length=32,
    remasking="low_confidence",
):
    """Generate text using the Replicate API version of LLaDA with visualization"""
    # Process constraints
    if constraints is None:
        constraints = {}
    constraints_json = json.dumps(constraints)

    # Format chat using the tokenizer's chat template
    chat_input = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    # Call Replicate API
    output = replicate.run(
        "spuuntries/llada-8b-kcv:e8b3ac0457f822454d662dec90edcac05f6e5947a50b55f92b22aa996acbf780",
        input={
            "steps": steps,
            "prompt": chat_input,
            "cfg_scale": cfg_scale,
            "remasking": remasking,
            "max_tokens": gen_length,
            "constraints": constraints_json,
            "temperature": temperature,
            "block_length": block_length,
            "prompt_template": "{prompt}",  # Use the already formatted prompt
        },
        wait=False,
    )
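    # NOTE: the fields accessed below assume this model version returns a dict
    # containing the final decoded text ("final_output") and the list of
    # intermediate denoising states ("states").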
    # Extract final response and states
    final_output = output["final_output"]
    states = output["states"]

    # Extract only the last assistant response by finding the last occurrence
    # of the assistant header pattern
    last_assistant_pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>\n"
    last_assistant_matches = list(re.finditer(last_assistant_pattern, final_output))

    if last_assistant_matches:
        # Get the last match
        last_match = last_assistant_matches[-1]
        # Start position of the actual content (after the header)
        start_pos = last_match.end()
        # Extract everything from this position to the end or until an end token
        end_pattern = r"<\|endoftext\|>|<\|start_header_id\|>"
        end_match = re.search(end_pattern, final_output[start_pos:])
        if end_match:
            end_pos = start_pos + end_match.start()
            response_text = final_output[start_pos:end_pos].strip()
        else:
            response_text = final_output[start_pos:].strip()
    else:
        response_text = "Error: Could not parse the model response."
    # Process states for visualization
    visualization_states = []

    # Add initial state (all masked)
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)

    for state in states:
        # Same parsing as above, applied to each intermediate state
        last_assistant_matches = list(re.finditer(last_assistant_pattern, state))
        if last_assistant_matches:
            last_match = last_assistant_matches[-1]
            start_pos = last_match.end()
            tokens_text = state[start_pos:].strip()
            tokens = tokens_text.split()

            current_state = []
            for token in tokens:
                if token == MASK_TOKEN:
                    current_state.append((token, "#444444"))  # Dark gray for masks
                else:
                    current_state.append((token, "#6699CC"))  # Light blue for revealed tokens
            visualization_states.append(current_state)
        else:
            # Fallback if we can't parse properly: a red mask as an error indicator
            visualization_states.append([(MASK_TOKEN, "#FF6666")])

    return visualization_states, response_text.replace("<|eot_id|>", "")
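# Each visualization state is a list of (token, color) tuples rendered by
# gr.HighlightedText below, e.g. [("Once", "#6699CC"), ("[MASK]", "#444444"), ...].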

def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# LLaDA - Large Language Diffusion Model Demo")
        gr.Markdown(
            "[model](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct), [project page](https://ml-gsai.github.io/LLaDA-demo/)"
        )

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # Current response text box (hidden)
        current_response = gr.Textbox(
            label="Current Response",
            placeholder="The assistant's response will appear here...",
            lines=3,
            visible=False,
        )

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                        )
                        send_btn = gr.Button("Send")

                constraints_input = gr.Textbox(
                    label="Word Constraints",
                    info="Format: 'position:word, position:word, ...' Example: '0:Once, 5:upon, 10:time'",
                    placeholder="0:Once, 5:upon, 10:time",
                    value="",
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,
                )
        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=128, value=64, step=8, label="Generation Length"
                )
                steps = gr.Slider(
                    minimum=8, maximum=128, value=128, step=4, label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.5, step=0.1, label="Temperature"
                )
                cfg_scale = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.3, step=0.1, label="CFG Scale"
                )
            with gr.Row():
                block_length = gr.Slider(
                    minimum=8, maximum=128, value=32, step=8, label="Block Length"
                )
                remasking_strategy = gr.Radio(
                    choices=["low_confidence", "random"],
                    value="low_confidence",
                    label="Remasking Strategy",
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.05,
                    step=0.01,
                    label="Visualization Delay (seconds)",
                )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")
        def add_message(history, message, response):
            """Add a message pair to the history and return the updated history"""
            history = history.copy()
            history.append([message, response])
            return history

        def user_message_submitted(message, history):
            """Process a submitted user message"""
            # Skip empty messages
            if not message.strip():
                # Return current state unchanged
                history_for_display = history.copy()
                return history, history_for_display, "", [], ""

            # Add user message to history (with no assistant response yet)
            history = add_message(history, message, None)

            # Format for display - temporarily show user message with empty response
            history_for_display = history.copy()

            # Clear the input and return immediately so the UI shows the user message
            return history, history_for_display, "", [], ""
        def bot_response(
            history,
            gen_length,
            steps,
            constraints,
            delay,
            temperature,
            cfg_scale,
            block_length,
            remasking,
        ):
            """Generate the bot response for the latest message"""
            if not history:
                # This is a generator, so yield (not return) the unchanged state
                yield history, [], ""
                return

            try:
                # Format all messages except the last one (which has no response yet)
                messages = format_chat_history(history[:-1])

                # Add the last user message
                messages.append({"role": "user", "content": history[-1][0]})

                # Parse constraints
                parsed_constraints = parse_constraints(constraints)

                # Generate response with visualization
                vis_states, response_text = generate_response_with_visualization(
                    messages,
                    gen_length=gen_length,
                    steps=steps,
                    constraints=parsed_constraints,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    block_length=block_length,
                    remasking=remasking,
                )

                # Update history with the assistant's response
                history[-1][1] = response_text

                # Yield the initial state immediately ...
                yield history, vis_states[0], response_text

                # ... then animate through the remaining visualization states
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text
            except Exception as e:
                error_msg = f"Error: {str(e)}"
                print(error_msg)

                # Show the error in the visualization; don't update history
                error_vis = [(error_msg, "red")]
                yield history, error_vis, error_msg
        def clear_conversation():
            """Clear the conversation history"""
            return [], [], "", []

        # EVENT HANDLERS
        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, current_response, output_vis],
        )
        # User message submission flow (2-step process)
        # Step 1: Add the user message to history and update the UI
        msg_submit = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[
                chat_history,
                chatbot_ui,
                user_input,
                output_vis,
                current_response,
            ],
        )

        # Also connect the send button
        send_click = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[
                chat_history,
                chatbot_ui,
                user_input,
                output_vis,
                current_response,
            ],
        )
        # Step 2: Generate the bot response
        # This happens after the user message is displayed
        msg_submit.then(
            fn=bot_response,
            inputs=[
                chat_history,
                gen_length,
                steps,
                constraints_input,
                visualization_delay,
                temperature,
                cfg_scale,
                block_length,
                remasking_strategy,
            ],
            outputs=[chatbot_ui, output_vis, current_response],
        )
        send_click.then(
            fn=bot_response,
            inputs=[
                chat_history,
                gen_length,
                steps,
                constraints_input,
                visualization_delay,
                temperature,
                cfg_scale,
                block_length,
                remasking_strategy,
            ],
            outputs=[chatbot_ui, output_vis, current_response],
        )

    return demo

# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    demo.queue().launch(server_name="0.0.0.0")
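# To try the demo locally (assuming the packages imported above are installed):
#   pip install gradio replicate transformers python-dotenv
#   python <this file>
# Gradio serves on port 7860 by default, so open http://localhost:7860.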