import os

import gradio as gr
from text_generation import Client

# HF-hosted endpoint for testing purposes (requires an HF API token)
API_TOKEN = os.environ.get("API_TOKEN", None)
CURRENT_CLIENT = Client(
    "https://afrts4trc759c6eq.us-east-1.aws.endpoints.huggingface.cloud/generate_stream",
    timeout=120,
    headers={
        "Accept": "application/json",
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json",
    },
)

DEFAULT_HEADER = os.environ.get("HEADER", "")
DEFAULT_USER_NAME = os.environ.get("USER_NAME", "user")
DEFAULT_ASSISTANT_NAME = os.environ.get("ASSISTANT_NAME", "assistant")
DEFAULT_SEPARATOR = os.environ.get("SEPARATOR", "<|im_end|>")
PROMPT_TEMPLATE = "<|im_start|>{user_name}\n{query}{separator}\n<|im_start|>{assistant_name}\n{response}"
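# Example of the rendered prompt for a single turn, assuming the default
# "user"/"assistant" names and the "<|im_end|>" separator (illustrative only):
#
#   <|im_start|>user
#   What is a croissant?<|im_end|>
#   <|im_start|>assistant
#
# The model's reply is generated after the final "assistant" line.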
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Concatenate past turns and the new input into a single prompt string.

    Note: kept for reference; the demo below builds prompts with
    PROMPT_TEMPLATE instead and never calls this helper.
    """
    past = []
    for data in chatbot:
        user_data, model_data = data

        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data

        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    return total_inputs
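# Illustrative only: with OpenAssistant-style role tokens (hypothetical values,
# not the defaults used by this demo), the helper above would produce:
#
#   get_total_inputs("How are you?", [("Hi", "Hello!")], "",
#                    "<|prompter|>", "<|assistant|>", "<|endoftext|>")
#   -> "<|prompter|>Hi<|endoftext|><|assistant|>Hello!<|endoftext|>"
#      "<|prompter|>How are you?<|endoftext|><|assistant|>"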
def has_no_history(chatbot, history):
    return not chatbot and not history
def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    max_new_tokens,
    repetition_penalty,
    header,
    user_name,
    assistant_name,
    separator,
):
    # Don't call the endpoint when the input is empty
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return

    history.append(user_message)
    past_messages = []
    for data in chatbot:
        user_data, model_data = data
        past_messages.extend(
            [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
        )
    print(past_messages)

    if len(past_messages) < 1:
        prompt = header + PROMPT_TEMPLATE.format(
            user_name=user_name,
            query=user_message,
            assistant_name=assistant_name,
            response="",
            separator=separator,
        )
    else:
        # Replay past turns, closing each with the separator, then append the new query
        prompt = header
        for i in range(0, len(past_messages), 2):
            intermediate_prompt = PROMPT_TEMPLATE.format(
                user_name=user_name,
                query=past_messages[i]["content"],
                assistant_name=assistant_name,
                response=past_messages[i + 1]["content"],
                separator=separator,
            )
            prompt = prompt + intermediate_prompt + separator + "\n"

        prompt = prompt + PROMPT_TEMPLATE.format(
            user_name=user_name,
            query=user_message,
            assistant_name=assistant_name,
            response="",
            separator=separator,
        )
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=40,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=1024,  # keep only the last 1024 prompt tokens
        stop_sequences=[DEFAULT_SEPARATOR],
    )
    stream = CURRENT_CLIENT.generate_stream(
        prompt,
        **generate_kwargs,
    )

    output = ""
    chat = chatbot
    appended = False
    for response in stream:
        # Skip special tokens (e.g. the separator) so they don't leak into the chat
        if response.token.special:
            continue
        output += response.token.text
        if not appended:
            # First visible token: start a new assistant message in the history
            history.append(" " + output)
            appended = True
        else:
            history[-1] = output
        chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
        yield chat, history, user_message, ""
    return chat, history, user_message, ""
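# Minimal standalone smoke test (a sketch, not part of the app): drives the same
# endpoint outside Gradio, assuming it is up and API_TOKEN is set.
#
#   prompt = PROMPT_TEMPLATE.format(user_name="user", query="Say hi in French.",
#                                   assistant_name="assistant", response="",
#                                   separator=DEFAULT_SEPARATOR)
#   text = ""
#   for r in CURRENT_CLIENT.generate_stream(prompt, max_new_tokens=32,
#                                           stop_sequences=[DEFAULT_SEPARATOR]):
#       if not r.token.special:
#           text += r.token.text
#   print(text)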
def clear_chat():
    return [], []
title = """<h1 align="center">CroissantLLMChat Playground 🥐</h1>"""

custom_css = """
#banner-image {
    display: block;
    margin-left: auto;
    margin-right: auto;
}
#chat-message {
    font-size: 14px;
    min-height: 300px;
}
"""
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ## Demo platform for 🥐 CroissantLLMChat

                The model is small (1.3B parameters), about 130 times smaller than GPT-3.
                As such, its generalist Chat version exhibits reduced understanding, reasoning and knowledge capacities.
                For industrial use, we recommend finetuning the model, but we trained this Chat version to allow for experimentation and to showcase what is possible at this size.

                ## Usage recommendations

                We recommend testing the chat model on open-ended writing tasks, tips, translations, etc.
                We find direct instructions to work best, and performance tends to drop after the first round of interactions.
                We limit the length of the conversation, so clear the chat between tests!

                ## Errors

                The demo is linked to an endpoint that shuts down automatically after 15 minutes of inactivity. If an error message appears, wait about 5 minutes and try again once the server is back up!
                The model can hallucinate and generate incorrect or even toxic content.
                """
            )
    with gr.Row():
        with gr.Group():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)
| with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"): | |
| temperature = gr.Slider( | |
| label="Temperature", | |
| value=0.5, | |
| minimum=0.1, | |
| maximum=1.0, | |
| step=0.1, | |
| interactive=True, | |
| info="Higher values produce more diverse outputs", | |
| ) | |
| top_p = gr.Slider( | |
| label="Top-p (nucleus sampling)", | |
| value=0.9, | |
| minimum=0.0, | |
| maximum=1, | |
| step=0.05, | |
| interactive=True, | |
| info="Higher values sample more low-probability tokens", | |
| ) | |
| max_new_tokens = gr.Slider( | |
| label="Max new tokens", | |
| value=512, | |
| minimum=0, | |
| maximum=1024, | |
| step=4, | |
| interactive=True, | |
| info="The maximum numbers of new tokens", | |
| ) | |
| repetition_penalty = gr.Slider( | |
| label="Repetition Penalty", | |
| value=1.2, | |
| minimum=0.0, | |
| maximum=10, | |
| step=0.1, | |
| interactive=True, | |
| info="The parameter for repetition penalty. 1.0 means no penalty.", | |
| ) | |
| with gr.Accordion(label="Prompt", open=False, elem_id="prompt-accordion"): | |
| header = gr.Textbox( | |
| label="Header instructions", | |
| value=DEFAULT_HEADER, | |
| interactive=True, | |
| info="Instructions given to the assistant at the beginning of the prompt", | |
| ) | |
| user_name = gr.Textbox( | |
| label="User name", | |
| value=DEFAULT_USER_NAME, | |
| interactive=True, | |
| info="Name to be given to the user in the prompt", | |
| ) | |
| assistant_name = gr.Textbox( | |
| label="Assistant name", | |
| value=DEFAULT_ASSISTANT_NAME, | |
| interactive=True, | |
| info="Name to be given to the assistant in the prompt", | |
| ) | |
| separator = gr.Textbox( | |
| label="Separator", | |
| value=DEFAULT_SEPARATOR, | |
| interactive=True, | |
| info="Character to be used when the speaker changes in the prompt", | |
| ) | |
| history = gr.State([]) | |
| last_user_message = gr.State("") | |
    user_message.submit(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            max_new_tokens,
            repetition_penalty,
            header,
            user_name,
            assistant_name,
            separator,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    send_button.click(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            max_new_tokens,
            repetition_penalty,
            header,
            user_name,
            assistant_name,
            separator,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    clear_chat_button.click(clear_chat, outputs=[chatbot, history])
demo.queue().launch()
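# To run locally (assuming gradio and text_generation are installed and the
# script is saved as app.py, the usual Space entry point):
#   API_TOKEN=<your HF token> python app.py
# A public share link can optionally be requested with demo.queue().launch(share=True).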