Spaces: Build error

The Space's app.py (below) fails to build; notes on the likely cause follow the listing.
import gradio as gr
from huggingface_hub import InferenceClient
from openai import OpenAI

# Local modules; minimal sketches of both follow the listing below
from prompt_template import PromptTemplate, PromptLoader
from assistant import AIAssistant

# Load the named prompt templates from YAML
prompts = PromptLoader.load_prompts("prompts.yaml")
# Available models and their configurations
MODELS = {
    "Zephyr 7B Beta": {
        "name": "HuggingFaceH4/zephyr-7b-beta",
        "provider": "huggingface",
    },
    "Mistral 7B": {
        "name": "mistralai/Mistral-7B-v0.1",
        "provider": "huggingface",
    },
    "GPT-3.5 Turbo": {
        "name": "gpt-3.5-turbo",
        "provider": "openai",
    },
}

# Available prompt strategies, mapped to template names in prompts.yaml
PROMPT_STRATEGIES = {
    "Default": "system_context",
    "Chain of Thought": "cot_prompt",
    "Knowledge-based": "knowledge_prompt",
    "Few-shot Learning": "few_shot_prompt",
    "Meta-prompting": "meta_prompt",
}
def create_assistant(model_name):
    model_info = MODELS[model_name]
    if model_info["provider"] == "huggingface":
        client = InferenceClient(model_info["name"])
    else:  # OpenAI; reads the OPENAI_API_KEY environment variable
        client = OpenAI()
    return AIAssistant(client=client, model=model_info["name"])
def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    prompt_strategy,
    override_params: bool,
    max_tokens,
    temperature,
    top_p,
):
    assistant = create_assistant(model_name)

    # Get the selected prompt template and the shared system context
    prompt_template: PromptTemplate = prompts[PROMPT_STRATEGIES[prompt_strategy]]
    system_context: PromptTemplate = prompts["system_context"]

    # Embed the selected strategy's template into the system message
    formatted_system_message = system_context.format(prompt_strategy=prompt_template.template)

    # Rebuild the conversation: system message, prior turns, then the new user message
    messages = [{"role": "system", "content": formatted_system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Use the template's own generation parameters unless the user overrides them
    generation_params = prompt_template.parameters if not override_params else {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }

    # Stream the assistant's response back to the UI
    for response in assistant.generate_response(
        prompt_template=prompt_template,
        generation_params=generation_params,
        stream=True,
        messages=messages,
    ):
        yield response
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model",
            )
            prompt_strategy_dropdown = gr.Dropdown(
                choices=list(PROMPT_STRATEGIES.keys()),
                value=list(PROMPT_STRATEGIES.keys())[0],
                label="Select Prompt Strategy",
            )

    with gr.Row():
        override_params = gr.Checkbox(
            label="Override Template Parameters",
            value=False,
        )

    # Sliders stay hidden until the override checkbox is ticked
    with gr.Row():
        with gr.Column(visible=False) as param_controls:
            max_tokens = gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Max new tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            )

    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            model_dropdown,
            prompt_strategy_dropdown,
            override_params,
            max_tokens,
            temperature,
            top_p,
        ],
    )

    def toggle_param_controls(override):
        # Returning an updated Column toggles the sliders' visibility (Gradio 4+)
        return gr.Column(visible=override)

    override_params.change(
        toggle_param_controls,
        inputs=[override_params],
        outputs=[param_controls],
    )

if __name__ == "__main__":
    demo.launch()
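
The app imports PromptTemplate and PromptLoader from a local prompt_template module that is not part of this listing. Below is a minimal sketch consistent with the call sites above (load_prompts() returning a dict of named templates, each exposing .template, .parameters, and .format(...)); everything beyond those three members, including the YAML field names "template" and "parameters", is an assumption:

# prompt_template.py -- hypothetical sketch, not the Space's actual module
from dataclasses import dataclass, field

import yaml

@dataclass
class PromptTemplate:
    template: str  # prompt text with {placeholders} such as {prompt_strategy}
    parameters: dict = field(default_factory=dict)  # default generation params

    def format(self, **kwargs) -> str:
        # Fill the template's placeholders
        return self.template.format(**kwargs)

class PromptLoader:
    @staticmethod
    def load_prompts(path: str) -> dict[str, PromptTemplate]:
        # Assumes one top-level YAML key per named template
        with open(path) as f:
            raw = yaml.safe_load(f)
        return {
            name: PromptTemplate(
                template=entry.get("template", ""),
                parameters=entry.get("parameters", {}) or {},
            )
            for name, entry in raw.items()
        }

Under this reading, prompts.yaml would hold one entry per name referenced in PROMPT_STRATEGIES (system_context, cot_prompt, knowledge_prompt, few_shot_prompt, meta_prompt), with system_context containing a {prompt_strategy} placeholder for format() to fill.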
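
AIAssistant likewise comes from a local assistant module. From its usage, it takes client and model at construction, and generate_response(...) is a generator yielding growing response text. A sketch under those assumptions, covering only the stream=True path the app uses: both client types expose an OpenAI-style streaming chat-completion call (InferenceClient.chat_completion and OpenAI().chat.completions.create), so the sketch dispatches on that.

# assistant.py -- hypothetical sketch; the real module may differ
class AIAssistant:
    def __init__(self, client, model):
        self.client = client  # huggingface_hub.InferenceClient or openai.OpenAI
        self.model = model

    def generate_response(self, prompt_template, generation_params, stream, messages):
        # prompt_template is accepted for parity with the app's call; this
        # sketch only uses the already-formatted `messages`. Streaming only.
        if hasattr(self.client, "chat_completion"):  # InferenceClient
            chunks = self.client.chat_completion(
                messages, stream=True, **generation_params
            )
        else:  # OpenAI client
            chunks = self.client.chat.completions.create(
                model=self.model, messages=messages, stream=True, **generation_params
            )
        response = ""
        for chunk in chunks:
            response += chunk.choices[0].delta.content or ""
            yield response  # cumulative text, as Gradio streaming expects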
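
As for the build error itself: a Space build most often fails because a file imported by app.py is missing from the repo or a package is absent from requirements.txt. For this app, the repo needs prompt_template.py, assistant.py, and prompts.yaml alongside app.py, and requirements.txt should list at least gradio, huggingface_hub, and openai (plus pyyaml if the loader parses YAML as sketched above). The GPT-3.5 Turbo option additionally requires an OPENAI_API_KEY secret configured on the Space.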