import os

import gradio as gr
from huggingface_hub import InferenceClient

# Initialize the Inference Client
client = InferenceClient(model="RekaAI/reka-flash-3", token=os.getenv("HF_TOKEN"))
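# Note: this assumes an HF_TOKEN environment variable is available (on Spaces it
# is usually configured as a repository secret); without a valid token, requests
# to the Inference API may be rejected or rate-limited.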

# Helper function to format the conversation history into a prompt
def format_history(history):
    prompt = "You are a helpful and harmless assistant.\n\n"
    for item in history:
        if item["role"] == "user":
            prompt += f"Human: {item['content']}\n"
        elif item["role"] == "assistant":
            prompt += f"Assistant: {item['content']}\n"
    prompt += "Assistant:"
    return prompt
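# For illustration, a one-turn history such as
#   [{"role": "user", "content": "What is the capital of France?"}]
# is rendered by format_history as:
#   You are a helpful and harmless assistant.
#
#   Human: What is the capital of France?
#   Assistant: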

# Function to handle message submission and response generation
def submit(message, history, temperature, max_new_tokens, top_p, top_k):
    # Add the user's message to the history
    history = history + [{"role": "user", "content": message}]

    # Add a "Thinking..." message to simulate the model's reasoning phase
    thinking_message = {"role": "assistant", "content": "Thinking..."}
    history = history + [thinking_message]
    yield history, history  # Update chatbot and state

    # Format the prompt, excluding the "Thinking..." message
    prompt = format_history(history[:-1])

    # Stream the response from the Inference API
    response = client.text_generation(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=1.0,
        stop_sequences=["\nHuman:", "\nAssistant:"],
        stream=True,
    )
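    # With stream=True (and details left at its default), text_generation yields
    # plain text chunks as they arrive, which is what allows the incremental
    # UI updates below.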
    # Simulate a "thinking" phase with the first 5 chunks
    thought_chunks = 0
    max_thought_chunks = 5
    accumulated_thought = ""
    answer_message = None
    for chunk in response:
        if thought_chunks < max_thought_chunks:
            accumulated_thought += chunk
            thinking_message["content"] = "Thinking: " + accumulated_thought
            thought_chunks += 1
            if thought_chunks == max_thought_chunks:
                # Finalize the "Thought" message and start the "Answer" message
                thinking_message["content"] = "Thought: " + accumulated_thought
                answer_message = {"role": "assistant", "content": "Answer:"}
                history = history + [answer_message]
        else:
            # Append subsequent chunks to the "Answer" message
            answer_message["content"] += chunk
        yield history, history  # Update the UI with each chunk

    # Finalize the response
    if answer_message is not None:
        answer_message["content"] += "\n\n[End of response]"
    else:
        thinking_message["content"] += "\n\n[No response generated]"
    yield history, history
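
# Note that submit is a generator: each `yield history, history` pushes the
# updated message list to both the Chatbot component and the session state,
# which is why it is wired to outputs=[chatbot, history_state] below.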

# Build the Gradio interface
with gr.Blocks() as demo:
    # State to store the conversation history
    history_state = gr.State([])

    # Chatbot component to display messages
    chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")

    # Layout with settings and input area
    with gr.Row():
        with gr.Column(scale=1):
            # Advanced settings in a collapsible panel
            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7)
                max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=1024, step=1, value=512)
                top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.9)
                top_k = gr.Slider(label="Top K", minimum=1, maximum=100, step=1, value=50)
        with gr.Column(scale=4):
            # Textbox for user input and buttons
            textbox = gr.Textbox(label="Your message")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")

    # Connect the submit button to the submit function
    submit_btn.click(
        submit,
        inputs=[textbox, history_state, temperature, max_tokens, top_p, top_k],
        outputs=[chatbot, history_state],
    )

    # Clear button resets the conversation
    clear_btn.click(lambda: ([], []), outputs=[chatbot, history_state])

# Launch the application
if __name__ == "__main__":
    demo.queue().launch()
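
# Usage sketch (assuming this file is saved as app.py and the dependencies are
# installed; the token value is a placeholder):
#   pip install gradio huggingface_hub
#   export HF_TOKEN=<your-hugging-face-token>
#   python app.py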