Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import gradio as gr | |
| from common.call_llm import chat, chat_stream_generator | |
| from plugin_task.model import Plugin, ReActStep | |
| from plugin_task.plugins import PLUGIN_JSON_SCHEMA, PLUGINS | |
| from plugin_task.prompt import ( | |
| FILLING_SLOT_PROMPT, | |
| FINAL_PROMPT, | |
| INTENT_RECOGNITION_PROMPT, | |
| ) | |
| from plugin_task.util import ( | |
| build_prompt_plugin_variables, | |
| parse_reAct_step, | |
| plugin_parameter_validator, | |
| ) | |
# Endpoint of the LLM service used for every plugin-related model call in this
# module; ``None`` when the PLUGIN_ENDPOINT environment variable is unset.
PLUGIN_ENDPOINT: Optional[str] = os.getenv("PLUGIN_ENDPOINT")
def api_plugin_chat(
    session: Dict,
    message: str,
    chat_history: List[List[str]],
    *radio_plugins,
):
    """Handle one user turn of the plugin chat (generator callback).

    Outside a plugin session (``session`` empty), the enabled plugins are
    collected from ``radio_plugins`` and the message goes through intent
    recognition. Inside a plugin session, the message is treated as the
    user's answer to a previous slot-filling question.

    Args:
        session: Mutable per-user state. This function writes the keys
            ``origin_message``, ``choice_plugin`` and ``reAct_step`` (the list
            of ReAct steps so far) and clears it when a plugin turn ends or
            fails.
        message: The latest user message.
        chat_history: ``[user, bot]`` pairs; the bot half of the last pair is
            filled in here (or by ``call_chat`` / ``call_final_answer``).
        *radio_plugins: One status per entry of ``PLUGIN_JSON_SCHEMA``; the
            value ``"开启"`` ("on") enables that plugin.

    Yields:
        ``(session, None, chat_history)`` tuples, or whatever ``call_chat`` /
        ``call_final_answer`` yield.
    """
    if not check_in_plugin_session(session):
        plugins = prepare_plugins(radio_plugins)
        if not plugins:
            gr.Warning("没有启用插件")  # "No plugin is enabled"
            return
        intention, reAct_step = intent_recognition(message, plugins)
        if intention in ("ask_user_for_required_params", "plugin"):
            session["origin_message"] = message
            # Record the plugin that will actually be used. For "plugin" that
            # is ``reAct_step.action`` (what call_final_answer looks up in
            # PLUGINS); the thought's ``tool_to_use_for_user`` may be missing
            # or name a different plugin in that case.
            if intention == "plugin":
                session["choice_plugin"] = reAct_step.action
            else:
                session["choice_plugin"] = reAct_step.thought["tool_to_use_for_user"]
            session["reAct_step"] = [reAct_step]
    else:
        intention, reAct_step = filling_slot_with_loop(session, message)
    print(
        f"[API_PLUGIN_CHAT]. message: {message},\n intention: {intention},\n session: {session}\n"
        + "=" * 25
        + "END"
        + "=" * 25
    )
    if intention == "fail":
        # On failure the payload is the raw model response text.
        chat_history[-1][1] = reAct_step
        session.clear()
        yield session, None, chat_history
    elif intention == "ask_user_for_required_params":
        chat_history[-1][1] = reAct_step.action_input.get("question", "")
        yield session, None, chat_history
    elif intention == "plugin":
        yield from call_final_answer(session, reAct_step, chat_history)
    elif intention == "chat":
        yield from call_chat(session, message, chat_history)
    elif intention == "end":
        session.clear()
        chat_history[-1][1] = "[系统消息]:当前插件对话结束"
        yield session, None, chat_history
def filling_slot_with_loop(
    session: Dict, message: str, retry: int = 3
) -> Tuple[str, Optional[Union[ReActStep, str]]]:
    """Continue slot filling for the plugin chosen earlier in this session.

    The user's ``message`` is attached as the observation of the latest ReAct
    step (unless that step already has one), the whole step history is sent
    to the model via ``FILLING_SLOT_PROMPT``, and the model's next step is
    interpreted.

    Args:
        session: Plugin session holding ``choice_plugin``, ``origin_message``
            and the ``reAct_step`` list; accepted steps are appended to it.
        message: The user's latest answer.
        retry: Number of additional model calls allowed after the first when
            a response is unusable (unparseable, an unexpected action, or
            parameters that fail validation). Once exhausted,
            ``("fail", model_response)`` is returned instead of looping.

    Returns:
        ``("end", step)``, ``("plugin", step)`` (``step.action`` is the plugin
        to run), ``("ask_user_for_required_params", step)``, or
        ``("fail", raw_model_text)``.
    """
    plugin = PLUGINS[session["choice_plugin"]]
    plugin_name = plugin.unique_name_for_model
    while True:
        latest_reAct_step = session["reAct_step"][-1]
        if not latest_reAct_step.observation:
            latest_reAct_step.observation = {"user_answer": message}
        reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])
        ask_content = FILLING_SLOT_PROMPT.format(
            plugin_name=plugin_name,
            description_for_human=plugin.description_for_human,
            parameter_schema=plugin.parameter_schema,
            question=session["origin_message"],
            reAct_step_str=reAct_step_str,
        )
        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        print(
            f"[FILLING_SLOT_WITH_LOOP] message: {message} ask_content: {ask_content}\n model_response: {model_response}\n"
            + "=" * 25
            + "END"
            + "=" * 25
        )
        reAct_step = parse_reAct_step(model_response)
        if reAct_step:
            tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
            known_parameter = reAct_step.thought.get("known_params", {})
            if (
                reAct_step.action == "end_conversation"
                or tool_to_use_for_user == "end_conversation"
            ):
                return "end", reAct_step
            if (
                reAct_step.action == "ASK_USER_FOR_REQUIRED_PARAMS"
                and tool_to_use_for_user == plugin_name
            ):
                passed, _ = plugin_parameter_validator(
                    known_parameter,
                    tool_to_use_for_user,
                )
                if passed:
                    # Parameters are complete: turn the question step into a
                    # call of the plugin. call_final_answer runs
                    # ``PLUGINS[action].run(**action_input)``, so action_input
                    # must hold the validated parameters, not the question.
                    reAct_step.action = tool_to_use_for_user
                    reAct_step.action_input = known_parameter
                    action = "plugin"
                else:
                    action = "ask_user_for_required_params"
                session["reAct_step"].append(reAct_step)
                return action, reAct_step
            if reAct_step.action == plugin_name and tool_to_use_for_user == plugin_name:
                passed, invalid_info = plugin_parameter_validator(
                    known_parameter,
                    tool_to_use_for_user,
                )
                session["reAct_step"].append(reAct_step)
                if passed:
                    return "plugin", reAct_step
                # Feed the validation error back to the model on the next try.
                reAct_step.observation = {"tool_parameters_verification": invalid_info}
        # Unparseable response, unexpected action, or invalid parameters:
        # consume one retry so a misbehaving model cannot loop forever.
        retry -= 1
        if retry < 0:
            return "fail", model_response
def call_chat(session: Dict, message: str, chat_history: List[List[str]]):
    """Delegate a plain (non-plugin) chat turn to ``generate_chat``.

    Each chunk produced by ``generate_chat`` is re-yielded with ``session``
    prepended, so the output tuple has the same leading element as the other
    plugin-chat outputs.
    """
    # Kept as a function-local import, as in the original.
    from chat_task.chat import generate_chat

    chunks = generate_chat(message, chat_history, PLUGIN_ENDPOINT)
    for piece in chunks:
        yield (session, *piece)
def check_in_plugin_session(session: Dict) -> bool:
    """Return True when a plugin conversation is in progress.

    Any non-empty ``session`` counts as an active plugin session; an empty
    (or falsy) one does not.
    """
    return True if session else False
def prepare_plugins(
    radio_plugins: List[str],
) -> List[Plugin]:
    """Resolve the enabled plugins from per-plugin radio-button states.

    ``radio_plugins[i]`` is the status of ``PLUGIN_JSON_SCHEMA[i]``; entries
    equal to ``"开启"`` ("on") are enabled. The matching ``Plugin`` objects
    are looked up in ``PLUGINS`` and returned in the same order.
    """
    enabled: List[Plugin] = []
    for index, status in enumerate(radio_plugins):
        if status != "开启":
            continue
        schema_name = PLUGIN_JSON_SCHEMA[index]["unique_name_for_model"]
        enabled.append(PLUGINS[schema_name])
    return enabled
def intent_recognition(
    message: str, choice_plugins: List[Plugin]
) -> Tuple[str, Union[ReActStep, str]]:
    """Classify the intent of a message that starts a new turn.

    The model is asked (via ``INTENT_RECOGNITION_PROMPT``) which of
    ``choice_plugins`` to use; the call is retried up to three times when the
    response cannot be parsed into a ``ReActStep``.

    Args:
        message: The user's message.
        choice_plugins: The plugins currently enabled by the user.

    Returns:
        ``(intention, payload)`` where ``intention`` is one of:

        * ``"fail"`` -- no parseable response; payload is the raw model text.
        * ``"chat"`` -- not a plugin request; payload is the parsed step.
        * ``"end"`` -- the model chose ``end_conversation``.
        * ``"plugin"`` -- ``payload.action`` names the plugin to run now.
        * ``"ask_user_for_required_params"`` -- a plugin was chosen but its
          parameters did not validate; the caller reads the question to show
          the user from ``payload.action_input``.
    """
    plugins, plugin_names = build_prompt_plugin_variables(choice_plugins)
    # plugin_names is a comma-separated string; strip each entry so a ", "
    # separator still matches. A tuple (not a set) keeps membership tests safe
    # even if the model emits an unhashable value.
    available_plugins = tuple(name.strip() for name in plugin_names.split(","))
    ask_content = INTENT_RECOGNITION_PROMPT.format(
        plugins=plugins, plugin_names=plugin_names, question=message
    )
    print(
        f"[INTENT_RECOGNITION] message:{message} ask_content: {ask_content}"
        + "=" * 25
        + "END"
        + "=" * 25
    )
    reAct_step = None
    model_response = None
    for _ in range(3):
        model_response = chat(
            [{"content": ask_content, "role": "user"}],
            stop="Observation",
            endpoint=PLUGIN_ENDPOINT,
        )
        reAct_step = parse_reAct_step(model_response)
        if reAct_step:
            break
    if not reAct_step:
        print(f"[INTENT_RECOGNITION] model fail: {model_response}")
        return "fail", model_response

    tool_to_use_for_user = reAct_step.thought.get("tool_to_use_for_user")
    known_params = reAct_step.thought.get("known_params", {})
    if reAct_step.action == "TOOL_OTHER":
        return "chat", reAct_step
    if (
        reAct_step.action == "end_conversation"
        and tool_to_use_for_user == "end_conversation"
    ):
        return "end", reAct_step
    if tool_to_use_for_user in available_plugins:
        if reAct_step.action in ("ASK_USER_FOR_INTENT", "ASK_USER_FOR_REQUIRED_PARAMS"):
            passed, _ = plugin_parameter_validator(
                known_params,
                tool_to_use_for_user,
            )
            if not passed:
                return "ask_user_for_required_params", reAct_step
            # Parameters are complete: turn the question step into a call of
            # the chosen plugin. call_final_answer runs
            # ``PLUGINS[action].run(**action_input)``, so action_input must
            # hold the validated parameters rather than the question payload.
            reAct_step.action = tool_to_use_for_user
            reAct_step.action_input = known_params
            return "plugin", reAct_step
        if reAct_step.action in available_plugins:
            return "plugin", reAct_step
    return "chat", reAct_step
def call_final_answer(session: Dict, reAct_step: ReActStep, history: List[List[str]]):
    """Run the chosen plugin and stream the model's final answer.

    The plugin ``PLUGINS[reAct_step.action]`` is called with
    ``reAct_step.action_input`` as keyword arguments. Its result is stored as
    the observation of the latest step in ``session["reAct_step"]``. The full
    step history is then rendered into ``FINAL_PROMPT`` and the model's answer
    is streamed into the bot half of the last ``history`` pair.

    The plugin conversation ends here: ``session`` is cleared when streaming
    finishes, and also if the plugin or the stream raises, so a failed call
    does not leave the user stuck in a stale plugin session.

    Yields:
        ``(session, None, history)`` after every streamed chunk.
    """
    try:
        plugin_result = PLUGINS[reAct_step.action].run(**reAct_step.action_input)
        latest_reAct_step = session["reAct_step"][-1]
        latest_reAct_step.observation = {"tool_response": plugin_result}
        reAct_step_str = "\n".join(step.to_str() for step in session["reAct_step"])
        final_prompt = FINAL_PROMPT.format(
            question=session["origin_message"],
            reAct_step_str=reAct_step_str,
        )
        print(
            f"[CALL_FINAL_ANSWER] final_prompt: {final_prompt}\n"
            + "=" * 25
            + "END"
            + "=" * 25
        )
        stream_response = chat_stream_generator(
            [{"content": final_prompt, "role": "user"}],
            endpoint=PLUGIN_ENDPOINT,
        )
        for character in stream_response:
            # The pending bot message may start out as None rather than "".
            history[-1][1] = (history[-1][1] or "") + character
            yield session, None, history
    finally:
        session.clear()