Spaces:
Sleeping
Sleeping
| from typing import List | |
| import os | |
| import shutil | |
| from langchain.embeddings.base import Embeddings | |
| from langchain.schema import Document | |
| from langchain.vectorstores.elasticsearch import ElasticsearchStore | |
| from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM | |
| from server.knowledge_base.kb_service.base import KBService, SupportedVSType | |
| from server.knowledge_base.utils import KnowledgeFile | |
| from server.utils import load_local_embeddings | |
| from elasticsearch import Elasticsearch,BadRequestError | |
| from configs import logger | |
| from configs import kbs_config | |
class ESKBService(KBService):
    """Knowledge-base service that stores document chunks in Elasticsearch.

    Two handles to the same cluster are kept:

    * ``es_client_python`` -- the raw ``elasticsearch`` client, used for index
      management, term queries and deletions.
    * ``db_init`` -- a LangChain ``ElasticsearchStore`` bound to the same index,
      used for similarity search.

    Connection settings are read from ``kbs_config[self.vs_type()]``
    (``host``, ``port`` and the optional ``user``, ``password``,
    ``dims_length`` keys).
    """

    # Upper bound for the ``size`` of a single search request. 10000 is
    # Elasticsearch's documented default for ``index.max_result_window``;
    # NOTE(review): adjust if the cluster overrides that setting.
    _MAX_SEARCH_SIZE = 10000
    # Number of chunk ids fetched per round when deleting a file's chunks.
    _DELETE_PAGE_SIZE = 50

    def do_init(self):
        """Read the ES settings, build both clients and make sure the index exists.

        Raises:
            Exception: whatever the client/store constructors raise is logged
                and re-raised unchanged (original type and traceback kept).
        """
        self.kb_path = self.get_kb_path(self.kb_name)
        # The index is named after the last path component of the KB directory.
        self.index_name = os.path.split(self.kb_path)[-1]
        es_config = kbs_config[self.vs_type()]
        self.IP = es_config['host']
        self.PORT = es_config['port']
        self.user = es_config.get("user", '')
        self.password = es_config.get("password", '')
        self.dims_length = es_config.get("dims_length", None)
        self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)

        try:
            # Raw ES python client (connection object only).
            if self._has_credentials():
                self.es_client_python = Elasticsearch(
                    self._es_url(),
                    basic_auth=(self.user, self.password),
                )
            else:
                logger.warning("ES未配置用户名和密码")
                self.es_client_python = Elasticsearch(self._es_url())
        except ConnectionError:
            logger.error("连接到 Elasticsearch 失败!")
            # Bare ``raise`` keeps the original message and traceback.
            raise
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise

        try:
            # First try to create the index through the raw client. Checking
            # for existence first avoids logging an error on every start-up
            # once the index has been created.
            if not self.es_client_python.indices.exists(index=self.index_name):
                mappings = {
                    "properties": {
                        "dense_vector": {
                            "type": "dense_vector",
                            "dims": self.dims_length,
                            "index": True,
                        }
                    }
                }
                self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
        except BadRequestError as e:
            # Best effort: the LangChain store below gets a second chance.
            logger.error("创建索引失败,重新")
            logger.error(e)

        try:
            # LangChain ES store: connect and bind to the index.
            if not self._has_credentials():
                logger.warning("ES未配置用户名和密码")
            self.db_init = ElasticsearchStore(
                es_url=self._es_url(),
                index_name=self.index_name,
                query_field="context",
                vector_query_field="dense_vector",
                embedding=self.embeddings_model,
                **self._store_auth_kwargs(),
            )
        except ConnectionError:
            print("### 初始化 Elasticsearch 失败!")
            logger.error("### 初始化 Elasticsearch 失败!")
            raise
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise

        try:
            # Second attempt at index creation, through the LangChain store.
            # NOTE(review): this is a private LangChain method; failures are
            # deliberately only logged.
            self.db_init._create_index_if_not_exists(
                index_name=self.index_name,
                dims_length=self.dims_length,
            )
        except Exception as e:
            logger.error("创建索引失败...")
            logger.error(e)

    def _es_url(self) -> str:
        """Return the HTTP URL of the configured Elasticsearch node."""
        return f"http://{self.IP}:{self.PORT}"

    def _has_credentials(self) -> bool:
        """Return True when both a user name and a password are configured."""
        return self.user != "" and self.password != ""

    def _store_auth_kwargs(self) -> dict:
        """Return the ``es_user``/``es_password`` kwargs for ElasticsearchStore, if any."""
        if self._has_credentials():
            return {"es_user": self.user, "es_password": self.password}
        return {}

    @staticmethod
    def _source_query(source) -> dict:
        """Build the search body matching chunks whose ``metadata.source`` equals *source*."""
        return {
            "query": {
                "term": {
                    "metadata.source.keyword": source
                }
            }
        }

    def _delete_ids(self, ids) -> int:
        """Delete each id from the index, logging failures; return the number deleted."""
        deleted = 0
        for doc_id in ids:
            try:
                self.es_client_python.delete(index=self.index_name,
                                             id=doc_id,
                                             refresh=True)
            except Exception as e:
                logger.error(f"ES Docs Delete Error! {e}")
            else:
                deleted += 1
        return deleted

    @staticmethod
    def get_kb_path(knowledge_base_name: str):
        """Return the directory of the named knowledge base under ``KB_ROOT_PATH``."""
        return os.path.join(KB_ROOT_PATH, knowledge_base_name)

    @staticmethod
    def get_vs_path(knowledge_base_name: str):
        """Return the ``vector_store`` directory of the named knowledge base."""
        return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")

    def do_create_kb(self):
        """Create the local ``vector_store`` directory if the doc directory exists."""
        if os.path.exists(self.doc_path):
            vs_path = os.path.join(self.kb_path, "vector_store")
            if not os.path.exists(vs_path):
                os.makedirs(vs_path)
            else:
                logger.warning("directory `vector_store` already exists.")

    def vs_type(self) -> str:
        """Return the vector-store type handled by this service."""
        return SupportedVSType.ES

    def _load_es(self, docs, embed_model):
        """Embed *docs* and write them to the index.

        Errors are logged and printed rather than raised.

        Returns:
            bool: True when the write call completed, False when it raised.
        """
        try:
            # Connect and write the documents in a single call.
            self.db = ElasticsearchStore.from_documents(
                documents=docs,
                embedding=embed_model,
                es_url=self._es_url(),
                index_name=self.index_name,
                distance_strategy="COSINE",
                query_field="context",
                vector_query_field="dense_vector",
                verify_certs=False,
                **self._store_auth_kwargs(),
            )
        except ConnectionError as ce:
            print(ce)
            print("连接到 Elasticsearch 失败!")
            logger.error("连接到 Elasticsearch 失败!")
            return False
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            print(e)
            return False
        return True

    def do_search(self, query:str, top_k: int, score_threshold: float):
        """Return the *top_k* chunks most similar to *query*, with their scores.

        NOTE(review): ``score_threshold`` is accepted but not applied; confirm
        the score orientation returned by the store before filtering on it.
        """
        # Text similarity search.
        docs = self.db_init.similarity_search_with_score(query=query,
                                                         k=top_k)
        return docs

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the given document ids from the index.

        Returns:
            bool: True when every id was deleted, False if any deletion failed
            (failures are logged, not raised).
        """
        ids = list(ids)
        return self._delete_ids(ids) == len(ids)

    def do_delete_doc(self, kb_file, **kwargs):
        """Delete every chunk of *kb_file* from this knowledge base's index.

        Chunks are matched on ``metadata.source.keyword == kb_file.filepath``
        and removed page by page, so files with more chunks than one page are
        fully removed. Individual delete failures are logged, not raised.
        """
        if not self.es_client_python.indices.exists(index=self.index_name):
            return None
        query = self._source_query(kb_file.filepath)
        while True:
            # Restrict the search to our own index: the ids found here are
            # deleted from ``self.index_name`` below.
            search_results = self.es_client_python.search(index=self.index_name,
                                                          body=query,
                                                          size=self._DELETE_PAGE_SIZE)
            delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
            if not delete_list:
                break
            # Deletes use refresh=True, so the next search no longer sees them.
            # Stop if nothing could be deleted, to avoid spinning on failures.
            if self._delete_ids(delete_list) == 0:
                break
        return None

    def do_add_doc(self, docs: List[Document], **kwargs):
        """Add document chunks to the knowledge base.

        Returns:
            list[dict]: ``[{"id": str, "metadata": dict}, ...]`` for the chunks
            in the index whose source matches ``docs[0]``; an empty list when
            *docs* is empty or the index does not exist.

        Raises:
            ValueError: if no chunk with that source is found after writing.
        """
        if not docs:
            # Nothing to write, and no source path to look up.
            return []
        print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}")
        print("*"*100)
        loaded = self._load_es(docs=docs, embed_model=self.embeddings_model)
        if loaded:
            # Only report success when the write did not raise.
            print("写入数据成功.")
        print("*"*100)

        if not self.es_client_python.indices.exists(index=self.index_name):
            return []
        # Collect id and metadata of the stored chunks.
        file_path = docs[0].metadata.get("source")
        # Ask for at least as many hits as chunks were written (never fewer
        # than the default page of 10), capped at the result-window limit.
        size = min(max(len(docs), 10), self._MAX_SEARCH_SIZE)
        search_results = self.es_client_python.search(index=self.index_name,
                                                      body=self._source_query(file_path),
                                                      size=size)
        hits = search_results["hits"]["hits"]
        if len(hits) == 0:
            raise ValueError("召回元素个数为0")
        info_docs = [{"id": hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in hits]
        return info_docs

    def do_clear_vs(self):
        """Delete all vectors of the knowledge base by dropping its index."""
        # Use index_name, like every other method, rather than kb_name.
        if self.es_client_python.indices.exists(index=self.index_name):
            self.es_client_python.indices.delete(index=self.index_name)

    def do_drop_kb(self):
        """Delete the knowledge base's local directory.

        NOTE(review): the Elasticsearch index itself is not removed here;
        confirm whether callers clear it separately.
        """
        if os.path.exists(self.kb_path):
            shutil.rmtree(self.kb_path)
if __name__ == '__main__':
    # Manual smoke test: index README.md into the "test" knowledge base and
    # run one similarity query against it.
    service = ESKBService("test")
    #service.clear_vs()
    #service.create_kb()
    readme = KnowledgeFile(filename="README.md", knowledge_base_name="test")
    service.add_doc(readme)
    hits = service.search_docs("如何启动api服务")
    print(hits)