# This file wraps requests to api.py so that it can be reused by different webuis.
# ApiRequest and AsyncApiRequest support synchronous / asynchronous calls respectively.

from typing import *
from pathlib import Path
# The configs imported here belong to the machine issuing the requests (e.g. the WEBUI)
# and are mainly used to set frontend defaults. In a distributed deployment they may
# differ from the configs on the server.
from configs import (
    EMBEDDING_MODEL,
    DEFAULT_VS_TYPE,
    LLM_MODELS,
    TEMPERATURE,
    SCORE_THRESHOLD,
    CHUNK_SIZE,
    OVERLAP_SIZE,
    ZH_TITLE_ENHANCE,
    VECTOR_SEARCH_TOP_K,
    SEARCH_ENGINE_TOP_K,
    HTTPX_DEFAULT_TIMEOUT,
    logger, log_verbose,
)
import httpx
import contextlib
import json
import os
from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint
from langchain_core._api import deprecated


set_httpx_config()

class ApiRequest:
    '''
    Synchronous wrapper around api.py calls, simplifying how the API is invoked.
    '''
    def __init__(
        self,
        base_url: str = api_address(),
        timeout: float = HTTPX_DEFAULT_TIMEOUT,
    ):
        self.base_url = base_url
        self.timeout = timeout
        self._use_async = False
        self._client = None

    @property
    def client(self):
        # lazily (re)create the httpx client if it is missing or was closed
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(base_url=self.base_url,
                                            use_async=self._use_async,
                                            timeout=self.timeout)
        return self._client

    def get(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("GET", url, params=params, **kwargs)
                else:
                    return self.client.get(url, params=params, **kwargs)
            except Exception as e:
                msg = f"error when get {url}: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                retry -= 1

    def post(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("POST", url, data=data, json=json, **kwargs)
                else:
                    return self.client.post(url, data=data, json=json, **kwargs)
            except Exception as e:
                msg = f"error when post {url}: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                retry -= 1

    def delete(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("DELETE", url, data=data, json=json, **kwargs)
                else:
                    return self.client.delete(url, data=data, json=json, **kwargs)
            except Exception as e:
                msg = f"error when delete {url}: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                retry -= 1

    def _httpx_stream2generator(
        self,
        response: contextlib._GeneratorContextManager,
        as_json: bool = False,
    ):
        '''
        Convert the GeneratorContextManager returned by httpx.stream into a plain generator.
        '''
        async def ret_async(response, as_json):
            try:
                async with response as r:
                    async for chunk in r.aiter_text(None):
                        if not chunk:  # fastchat api yields empty bytes at start and end
                            continue
                        if as_json:
                            try:
                                if chunk.startswith("data: "):
                                    data = json.loads(chunk[6:-2])
                                elif chunk.startswith(":"):  # skip sse comment line
                                    continue
                                else:
                                    data = json.loads(chunk)
                                yield data
                            except Exception as e:
                                msg = f"The API returned invalid json: '{chunk}'. Error: {e}."
                                logger.error(f'{e.__class__.__name__}: {msg}',
                                             exc_info=e if log_verbose else None)
                        else:
                            # print(chunk, end="", flush=True)
                            yield chunk
            except httpx.ConnectError as e:
                msg = f"Unable to connect to the API server. Please make sure 'api.py' is running. ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except httpx.ReadTimeout as e:
                msg = f"API request timed out. Please make sure FastChat and the API server are running (see Wiki '5. Start the API service or Web UI'). ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except Exception as e:
                msg = f"Error while communicating with the API: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                yield {"code": 500, "msg": msg}

        def ret_sync(response, as_json):
            try:
                with response as r:
                    for chunk in r.iter_text(None):
                        if not chunk:  # fastchat api yields empty bytes at start and end
                            continue
                        if as_json:
                            try:
                                if chunk.startswith("data: "):
                                    data = json.loads(chunk[6:-2])
                                elif chunk.startswith(":"):  # skip sse comment line
                                    continue
                                else:
                                    data = json.loads(chunk)
                                yield data
                            except Exception as e:
                                msg = f"The API returned invalid json: '{chunk}'. Error: {e}."
                                logger.error(f'{e.__class__.__name__}: {msg}',
                                             exc_info=e if log_verbose else None)
                        else:
                            # print(chunk, end="", flush=True)
                            yield chunk
            except httpx.ConnectError as e:
                msg = f"Unable to connect to the API server. Please make sure 'api.py' is running. ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except httpx.ReadTimeout as e:
                msg = f"API request timed out. Please make sure FastChat and the API server are running (see Wiki '5. Start the API service or Web UI'). ({e})"
                logger.error(msg)
                yield {"code": 500, "msg": msg}
            except Exception as e:
                msg = f"Error while communicating with the API: {e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                yield {"code": 500, "msg": msg}

        if self._use_async:
            return ret_async(response, as_json)
        else:
            return ret_sync(response, as_json)
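
    # Consumption sketch: the streaming endpoints below return this generator;
    # with as_json=True it yields parsed dicts, otherwise raw text chunks, e.g.
    # (assuming `resp` came from self.post(..., stream=True)):
    #     for chunk in self._httpx_stream2generator(resp, as_json=True):
    #         print(chunk)
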
    def _get_response_value(
        self,
        response: httpx.Response,
        as_json: bool = False,
        value_func: Callable = None,
    ):
        '''
        Convert the response of a synchronous or asynchronous request.
        `as_json`: parse the response as json
        `value_func`: an optional post-processing function that receives the response (or its json)
        '''
        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "The API did not return valid JSON. " + str(e)
                if log_verbose:
                    logger.error(f'{e.__class__.__name__}: {msg}',
                                 exc_info=e if log_verbose else None)
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = (lambda r: r)

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            else:
                return value_func(await response)

        if self._use_async:
            return ret_async(response)
        else:
            if as_json:
                return value_func(to_json(response))
            else:
                return value_func(response)
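
    # Example sketch: extract only the "data" field from a JSON response:
    #     resp = self.post("/server/configs")
    #     data = self._get_response_value(resp, as_json=True,
    #                                     value_func=lambda r: r.get("data", {}))
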
    # Server information

    def get_server_configs(self, **kwargs) -> Dict:
        response = self.post("/server/configs", **kwargs)
        return self._get_response_value(response, as_json=True)

    def get_prompt_template(
        self,
        type: str = "llm_chat",
        name: str = "default",
        **kwargs,
    ) -> str:
        data = {
            "type": type,
            "name": name,
        }
        response = self.post("/server/get_prompt_template", json=data, **kwargs)
        return self._get_response_value(response, value_func=lambda r: r.text)

    # Chat operations

    def chat_chat(
        self,
        query: str,
        conversation_id: str = None,
        history_len: int = -1,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
        **kwargs,
    ):
        '''
        Corresponds to the /chat/chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "conversation_id": conversation_id,
            "history_len": history_len,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post("/chat/chat", json=data, stream=True, **kwargs)
        return self._httpx_stream2generator(response, as_json=True)
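
    # Usage sketch (assumes api.py is reachable; the chunk key "text" follows
    # the server's streamed payload and is an assumption here):
    #     api = ApiRequest()
    #     for chunk in api.chat_chat("你好"):
    #         print(chunk.get("text", ""), end="", flush=True)
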
    def agent_chat(
        self,
        query: str,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
    ):
        '''
        Corresponds to the /chat/agent_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post("/chat/agent_chat", json=data, stream=True)
        return self._httpx_stream2generator(response, as_json=True)

    def knowledge_base_chat(
        self,
        query: str,
        knowledge_base_name: str,
        top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
    ):
        '''
        Corresponds to the /chat/knowledge_base_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "knowledge_base_name": knowledge_base_name,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post(
            "/chat/knowledge_base_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)
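
    # Usage sketch (assumes a knowledge base named "samples" exists; the
    # "answer"/"docs" chunk keys are an assumption based on the server's
    # streamed payload):
    #     for chunk in api.knowledge_base_chat("你好", knowledge_base_name="samples"):
    #         print(chunk.get("answer", ""), end="", flush=True)
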
    def upload_temp_docs(
        self,
        files: List[Union[str, Path, bytes]],
        knowledge_id: str = None,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
    ):
        '''
        Corresponds to the /knowledge_base/upload_temp_docs endpoint in api.py.
        '''
        def convert_file(file, filename=None):
            if isinstance(file, bytes):  # raw bytes
                file = BytesIO(file)
            elif hasattr(file, "read"):  # a file-like object
                filename = filename or file.name
            else:  # a local path
                file = Path(file).absolute().open("rb")
                filename = filename or os.path.split(file.name)[-1]
            return filename, file

        files = [convert_file(file) for file in files]
        data = {
            "knowledge_id": knowledge_id,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
        }
        response = self.post(
            "/knowledge_base/upload_temp_docs",
            data=data,
            files=[("files", (filename, file)) for filename, file in files],
        )
        return self._get_response_value(response, as_json=True)
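
    # Usage sketch (assumes the server returns the temporary knowledge id under
    # data["id"]; the id can then be passed to file_chat below; the file path
    # is hypothetical):
    #     ret = api.upload_temp_docs(["./docs/report.pdf"])
    #     knowledge_id = ret.get("data", {}).get("id")
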
    def file_chat(
        self,
        query: str,
        knowledge_id: str,
        top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
    ):
        '''
        Corresponds to the /chat/file_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "knowledge_id": knowledge_id,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }
        response = self.post(
            "/chat/file_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)

    def search_engine_chat(
        self,
        query: str,
        search_engine_name: str,
        top_k: int = SEARCH_ENGINE_TOP_K,
        history: List[Dict] = [],
        stream: bool = True,
        model: str = LLM_MODELS[0],
        temperature: float = TEMPERATURE,
        max_tokens: int = None,
        prompt_name: str = "default",
        split_result: bool = False,
    ):
        '''
        Corresponds to the /chat/search_engine_chat endpoint in api.py.
        '''
        data = {
            "query": query,
            "search_engine_name": search_engine_name,
            "top_k": top_k,
            "history": history,
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
            "split_result": split_result,
        }
        response = self.post(
            "/chat/search_engine_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)

    # Knowledge base operations

    def list_knowledge_bases(
        self,
    ):
        '''
        Corresponds to the /knowledge_base/list_knowledge_bases endpoint in api.py.
        '''
        response = self.get("/knowledge_base/list_knowledge_bases")
        return self._get_response_value(response,
                                        as_json=True,
                                        value_func=lambda r: r.get("data", []))

    def create_knowledge_base(
        self,
        knowledge_base_name: str,
        vector_store_type: str = DEFAULT_VS_TYPE,
        embed_model: str = EMBEDDING_MODEL,
    ):
        '''
        Corresponds to the /knowledge_base/create_knowledge_base endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "vector_store_type": vector_store_type,
            "embed_model": embed_model,
        }
        response = self.post(
            "/knowledge_base/create_knowledge_base",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def delete_knowledge_base(
        self,
        knowledge_base_name: str,
    ):
        '''
        Corresponds to the /knowledge_base/delete_knowledge_base endpoint in api.py.
        '''
        response = self.post(
            "/knowledge_base/delete_knowledge_base",
            json=f"{knowledge_base_name}",
        )
        return self._get_response_value(response, as_json=True)

    def list_kb_docs(
        self,
        knowledge_base_name: str,
    ):
        '''
        Corresponds to the /knowledge_base/list_files endpoint in api.py.
        '''
        response = self.get(
            "/knowledge_base/list_files",
            params={"knowledge_base_name": knowledge_base_name}
        )
        return self._get_response_value(response,
                                        as_json=True,
                                        value_func=lambda r: r.get("data", []))

    def search_kb_docs(
        self,
        knowledge_base_name: str,
        query: str = "",
        top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
        file_name: str = "",
        metadata: dict = {},
    ) -> List:
        '''
        Corresponds to the /knowledge_base/search_docs endpoint in api.py.
        '''
        data = {
            "query": query,
            "knowledge_base_name": knowledge_base_name,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "file_name": file_name,
            "metadata": metadata,
        }
        response = self.post(
            "/knowledge_base/search_docs",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def update_docs_by_id(
        self,
        knowledge_base_name: str,
        docs: Dict[str, Dict],
    ) -> bool:
        '''
        Corresponds to the /knowledge_base/update_docs_by_id endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "docs": docs,
        }
        response = self.post(
            "/knowledge_base/update_docs_by_id",
            json=data
        )
        return self._get_response_value(response)

    def upload_kb_docs(
        self,
        files: List[Union[str, Path, bytes]],
        knowledge_base_name: str,
        override: bool = False,
        to_vector_store: bool = True,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
        docs: Dict = {},
        not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the /knowledge_base/upload_docs endpoint in api.py.
        '''
        def convert_file(file, filename=None):
            if isinstance(file, bytes):  # raw bytes
                file = BytesIO(file)
            elif hasattr(file, "read"):  # a file-like object
                filename = filename or file.name
            else:  # a local path
                file = Path(file).absolute().open("rb")
                filename = filename or os.path.split(file.name)[-1]
            return filename, file

        files = [convert_file(file) for file in files]
        data = {
            "knowledge_base_name": knowledge_base_name,
            "override": override,
            "to_vector_store": to_vector_store,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
            "docs": docs,
            "not_refresh_vs_cache": not_refresh_vs_cache,
        }
        if isinstance(data["docs"], dict):
            data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
        response = self.post(
            "/knowledge_base/upload_docs",
            data=data,
            files=[("files", (filename, file)) for filename, file in files],
        )
        return self._get_response_value(response, as_json=True)
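
    # Usage sketch: files may be local paths, raw bytes, or file-like objects;
    # the example path and knowledge base name are hypothetical:
    #     api.upload_kb_docs(["./docs/readme.md"],
    #                        knowledge_base_name="samples",
    #                        override=True)
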
    def delete_kb_docs(
        self,
        knowledge_base_name: str,
        file_names: List[str],
        delete_content: bool = False,
        not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the /knowledge_base/delete_docs endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "file_names": file_names,
            "delete_content": delete_content,
            "not_refresh_vs_cache": not_refresh_vs_cache,
        }
        response = self.post(
            "/knowledge_base/delete_docs",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def update_kb_info(self, knowledge_base_name, kb_info):
        '''
        Corresponds to the /knowledge_base/update_info endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "kb_info": kb_info,
        }
        response = self.post(
            "/knowledge_base/update_info",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def update_kb_docs(
        self,
        knowledge_base_name: str,
        file_names: List[str],
        override_custom_docs: bool = False,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
        docs: Dict = {},
        not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the /knowledge_base/update_docs endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "file_names": file_names,
            "override_custom_docs": override_custom_docs,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
            "docs": docs,
            "not_refresh_vs_cache": not_refresh_vs_cache,
        }
        if isinstance(data["docs"], dict):
            data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
        response = self.post(
            "/knowledge_base/update_docs",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def recreate_vector_store(
        self,
        knowledge_base_name: str,
        allow_empty_kb: bool = True,
        vs_type: str = DEFAULT_VS_TYPE,
        embed_model: str = EMBEDDING_MODEL,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=OVERLAP_SIZE,
        zh_title_enhance=ZH_TITLE_ENHANCE,
    ):
        '''
        Corresponds to the /knowledge_base/recreate_vector_store endpoint in api.py.
        '''
        data = {
            "knowledge_base_name": knowledge_base_name,
            "allow_empty_kb": allow_empty_kb,
            "vs_type": vs_type,
            "embed_model": embed_model,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "zh_title_enhance": zh_title_enhance,
        }
        response = self.post(
            "/knowledge_base/recreate_vector_store",
            json=data,
            stream=True,
            timeout=None,
        )
        return self._httpx_stream2generator(response, as_json=True)

    # LLM model operations

    def list_running_models(
        self,
        controller_address: str = None,
    ):
        '''
        Get the list of models currently running in Fastchat.
        '''
        data = {
            "controller_address": controller_address,
        }
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:data: {data}')
        response = self.post(
            "/llm_model/list_running_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

    def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
        '''
        Get the LLM model currently running on the server.
        When local_first=True, prefer a running local model; otherwise prefer
        the order configured in LLM_MODELS.
        Returns (model_name, is_local_model).
        '''
        def ret_sync():
            running_models = self.list_running_models()
            if not running_models:
                return "", False

            model = ""
            for m in LLM_MODELS:
                if m not in running_models:
                    continue
                is_local = not running_models[m].get("online_api")
                if local_first and not is_local:
                    continue
                else:
                    model = m
                    break

            if not model:  # none of the models configured in LLM_MODELS is running
                model = list(running_models)[0]
            is_local = not running_models[model].get("online_api")
            return model, is_local

        async def ret_async():
            running_models = await self.list_running_models()
            if not running_models:
                return "", False

            model = ""
            for m in LLM_MODELS:
                if m not in running_models:
                    continue
                is_local = not running_models[m].get("online_api")
                if local_first and not is_local:
                    continue
                else:
                    model = m
                    break

            if not model:  # none of the models configured in LLM_MODELS is running
                model = list(running_models)[0]
            is_local = not running_models[model].get("online_api")
            return model, is_local

        if self._use_async:
            return ret_async()
        else:
            return ret_sync()
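
    # Usage sketch:
    #     model, is_local = api.get_default_llm_model()
    #     # returns ("", False) when no model is running at all
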
    def list_config_models(
        self,
        types: List[str] = ["local", "online"],
    ) -> Dict[str, Dict]:
        '''
        Get the models configured in the server's configs,
        returned as {"type": {model_name: config}, ...}.
        '''
        data = {
            "types": types,
        }
        response = self.post(
            "/llm_model/list_config_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def get_model_config(
        self,
        model_name: str = None,
    ) -> Dict:
        '''
        Get a model's configuration from the server.
        '''
        data = {
            "model_name": model_name,
        }
        response = self.post(
            "/llm_model/get_model_config",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def list_search_engines(self) -> List[str]:
        '''
        Get the search engines supported by the server.
        '''
        response = self.post(
            "/server/list_search_engines",
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

    def stop_llm_model(
        self,
        model_name: str,
        controller_address: str = None,
    ):
        '''
        Stop an LLM model.
        Note: due to the way Fastchat is implemented, this actually stops the
        model_worker hosting the LLM model.
        '''
        data = {
            "model_name": model_name,
            "controller_address": controller_address,
        }
        response = self.post(
            "/llm_model/stop",
            json=data,
        )
        return self._get_response_value(response, as_json=True)

    def change_llm_model(
        self,
        model_name: str,
        new_model_name: str,
        controller_address: str = None,
    ):
        '''
        Ask the fastchat controller to switch the LLM model.
        '''
        if not model_name or not new_model_name:
            return {
                "code": 500,
                "msg": "No model name specified."
            }

        def ret_sync():
            running_models = self.list_running_models()
            if new_model_name == model_name or new_model_name in running_models:
                return {
                    "code": 200,
                    "msg": "No switch needed."
                }

            if model_name not in running_models:
                return {
                    "code": 500,
                    "msg": f"The specified model '{model_name}' is not running. Currently running models: {running_models}"
                }

            config_models = self.list_config_models()
            if new_model_name not in config_models.get("local", {}):
                return {
                    "code": 500,
                    "msg": f"The target model '{new_model_name}' is not configured in configs."
                }

            data = {
                "model_name": model_name,
                "new_model_name": new_model_name,
                "controller_address": controller_address,
            }
            response = self.post(
                "/llm_model/change",
                json=data,
            )
            return self._get_response_value(response, as_json=True)

        async def ret_async():
            running_models = await self.list_running_models()
            if new_model_name == model_name or new_model_name in running_models:
                return {
                    "code": 200,
                    "msg": "No switch needed."
                }

            if model_name not in running_models:
                return {
                    "code": 500,
                    "msg": f"The specified model '{model_name}' is not running. Currently running models: {running_models}"
                }

            config_models = await self.list_config_models()
            if new_model_name not in config_models.get("local", {}):
                return {
                    "code": 500,
                    "msg": f"The target model '{new_model_name}' is not configured in configs."
                }

            data = {
                "model_name": model_name,
                "new_model_name": new_model_name,
                "controller_address": controller_address,
            }
            response = self.post(
                "/llm_model/change",
                json=data,
            )
            return self._get_response_value(response, as_json=True)

        if self._use_async:
            return ret_async()
        else:
            return ret_sync()

    def embed_texts(
        self,
        texts: List[str],
        embed_model: str = EMBEDDING_MODEL,
        to_query: bool = False,
    ) -> List[List[float]]:
        '''
        Embed the given texts. Available models include local embed_models and
        online models that support embeddings.
        '''
        data = {
            "texts": texts,
            "embed_model": embed_model,
            "to_query": to_query,
        }
        resp = self.post(
            "/other/embed_texts",
            json=data,
        )
        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
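
    # Usage sketch (the embedding dimension depends on embed_model):
    #     vectors = api.embed_texts(["hello", "world"])
    #     # vectors is a list of float lists, one vector per input text
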
    def chat_feedback(
        self,
        message_id: str,
        score: int,
        reason: str = "",
    ) -> int:
        '''
        Submit feedback (a rating) for a chat message.
        '''
        data = {
            "message_id": message_id,
            "score": score,
            "reason": reason,
        }
        resp = self.post("/chat/feedback", json=data)
        return self._get_response_value(resp)


class AsyncApiRequest(ApiRequest):
    def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
        super().__init__(base_url, timeout)
        self._use_async = True
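

# Usage sketch for the async client, a minimal example assuming it is driven
# from an asyncio event loop:
#     import asyncio
#
#     async def main():
#         aapi = AsyncApiRequest()
#         async for chunk in aapi.chat_chat("你好"):
#             print(chunk)
#
#     asyncio.run(main())
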
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    '''
    Return the error message if an error occurred when requesting the API.
    '''
    if isinstance(data, dict):
        if key in data:
            return data[key]
        if "code" in data and data["code"] != 200:
            return data["msg"]
    return ""


def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    '''
    Return the success message if the API request succeeded.
    '''
    if (isinstance(data, dict)
            and key in data
            and "code" in data
            and data["code"] == 200):
        return data[key]
    return ""
| if __name__ == "__main__": | |
| api = ApiRequest() | |
| aapi = AsyncApiRequest() | |
| # with api.chat_chat("你好") as r: | |
| # for t in r.iter_text(None): | |
| # print(t) | |
| # r = api.chat_chat("你好", no_remote_api=True) | |
| # for t in r: | |
| # print(t) | |
| # r = api.duckduckgo_search_chat("室温超导最新研究进展", no_remote_api=True) | |
| # for t in r: | |
| # print(t) | |
| # print(api.list_knowledge_bases()) | |