# From project chatglm-langchain

import threading
import os
import shutil
import uuid
import numpy as np
from tqdm import tqdm
from typing import List, Tuple
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from toolbox import Singleton
from crazy_functions.vector_fns.general_file_loader import load_file
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec-base": "shibing624/text2vec-base-chinese",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
}
# Embedding model name
EMBEDDING_MODEL = "text2vec"

# Embedding running device
EMBEDDING_DEVICE = "cpu"

# Context-based prompt template; the "{question}" and "{context}" placeholders must be kept.
# (The template instructs the model, in Chinese, to answer concisely and professionally from
# the given context, and to say the question cannot be answered when the context is insufficient.)
PROMPT_TEMPLATE = """已知信息:
{context}

根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""

# Sentence length used when splitting text
SENTENCE_SIZE = 100

# Length of the merged context around each matched chunk
CHUNK_SIZE = 250

# LLM input history length
LLM_HISTORY_LEN = 3

# Return top-k text chunks from the vector store
VECTOR_SEARCH_TOP_K = 5

# Relevance score threshold for knowledge retrieval (FAISS L2 distance, roughly 0-1100).
# 0 disables the filter; in testing, values below 500 gave more precise matches.
VECTOR_SEARCH_SCORE_THRESHOLD = 0

NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

FLAG_USER_NAME = uuid.uuid4().hex

# Whether to allow cross-origin requests; False by default, set to True to enable
OPEN_CROSS_DOMAIN = False
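# A note on the configuration above (explanatory, not part of the original project):
# EMBEDDING_MODEL indexes into embedding_model_dict, so the effective embedding model is
#     embedding_model_dict[EMBEDDING_MODEL]  ->  "GanymedeNil/text2vec-large-chinese"
# which matches the model that knowledge_archive_interface.get_chinese_text2vec() downloads below.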
def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4
) -> List[Tuple[Document, float]]:

    def separate_list(ls: List[int]) -> List[List[int]]:
        # Split a sorted list of indices into runs of consecutive integers,
        # e.g. [1, 2, 3, 7, 8] -> [[1, 2, 3], [7, 8]].
        if not ls:
            return []
        lists = []
        ls1 = [ls[0]]
        for i in range(1, len(ls)):
            if ls[i - 1] + 1 == ls[i]:
                ls1.append(ls[i])
            else:
                lists.append(ls1)
                ls1 = [ls[i]]
        lists.append(ls1)
        return lists
    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
    docs = []
    id_set = set()
    store_len = len(self.index_to_docstore_id)
    for j, i in enumerate(indices[0]):
        if i == -1 or 0 < self.score_threshold < scores[0][j]:
            # This happens when not enough docs are returned.
            continue
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        if not self.chunk_conent:
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            doc.metadata["score"] = int(scores[0][j])
            docs.append(doc)
            continue
        id_set.add(i)
        docs_len = len(doc.page_content)
        # Expand each hit outward (i+1, i-1, i+2, i-2, ...) and absorb neighboring
        # chunks from the same source file until the merged length would exceed
        # chunk_size (the loop variable is named `width` to avoid shadowing the
        # top-k parameter `k`).
        for width in range(1, max(i, store_len - i)):
            break_flag = False
            for l in [i + width, i - width]:
                if 0 <= l < len(self.index_to_docstore_id):
                    _id0 = self.index_to_docstore_id[l]
                    doc0 = self.docstore.search(_id0)
                    if docs_len + len(doc0.page_content) > self.chunk_size:
                        break_flag = True
                        break
                    elif doc0.metadata["source"] == doc.metadata["source"]:
                        docs_len += len(doc0.page_content)
                        id_set.add(l)
            if break_flag:
                break
    if not self.chunk_conent:
        return docs
    if len(id_set) == 0 and self.score_threshold > 0:
        return []
    id_list = sorted(list(id_set))
    id_lists = separate_list(id_list)
    for id_seq in id_lists:
        # Stitch each run of consecutive chunks back into a single Document.
        for id in id_seq:
            if id == id_seq[0]:
                _id = self.index_to_docstore_id[id]
                doc = self.docstore.search(_id)
            else:
                _id0 = self.index_to_docstore_id[id]
                doc0 = self.docstore.search(_id0)
                doc.page_content += " " + doc0.page_content
        if not isinstance(doc, Document):
            raise ValueError(f"Could not find document for id {_id}, got {doc}")
        # The merged document keeps the best (lowest) distance among its original hits.
        doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
        doc.metadata["score"] = int(doc_score)
        docs.append(doc)
    return docs
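# Usage sketch (illustrative, not in the original file): this helper is written as an
# unbound method and is invoked with a FAISS vector store as `self`, after the three
# attributes it reads have been attached, exactly as get_knowledge_based_conent_test()
# does below:
#
#     store = FAISS.load_local("some_vs_path", text2vec)  # hypothetical path
#     store.chunk_conent = True
#     store.score_threshold = VECTOR_SEARCH_SCORE_THRESHOLD
#     store.chunk_size = CHUNK_SIZE
#     query_vec = store.embedding_function.embed_query("some question")
#     docs = similarity_search_with_score_by_vector(store, query_vec, k=VECTOR_SEARCH_TOP_K)
#
# Each returned Document carries its FAISS L2 distance in metadata["score"]; lower means
# more similar, which is why score_threshold acts as an upper bound.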
class LocalDocQA:
    llm: object = None
    embeddings: object = None
    top_k: int = VECTOR_SEARCH_TOP_K
    chunk_size: int = CHUNK_SIZE
    chunk_conent: bool = True
    score_threshold: int = VECTOR_SEARCH_SCORE_THRESHOLD

    def init_cfg(self, top_k=VECTOR_SEARCH_TOP_K):
        self.llm = None
        self.top_k = top_k
    def init_knowledge_vector_store(self,
                                    filepath,  # a single path (str) or a list of paths
                                    vs_path: str = None,  # str or os.PathLike
                                    sentence_size=SENTENCE_SIZE,
                                    text2vec=None):
        loaded_files = []
        failed_files = []
        if isinstance(filepath, str):
            if not os.path.exists(filepath):
                print("Path does not exist")
                return None
            elif os.path.isfile(filepath):
                file = os.path.split(filepath)[-1]
                try:
                    docs = load_file(filepath, sentence_size)
                    print(f"{file} loaded successfully")
                    loaded_files.append(filepath)
                except Exception as e:
                    print(e)
                    print(f"{file} failed to load")
                    return None
            elif os.path.isdir(filepath):
                docs = []
                for file in tqdm(os.listdir(filepath), desc="Loading files"):
                    fullfilepath = os.path.join(filepath, file)
                    try:
                        docs += load_file(fullfilepath, sentence_size)
                        loaded_files.append(fullfilepath)
                    except Exception as e:
                        print(e)
                        failed_files.append(file)
                if len(failed_files) > 0:
                    print("The following files failed to load:")
                    for file in failed_files:
                        print(f"{file}\n")
        else:
            docs = []
            for file in filepath:
                docs += load_file(file, sentence_size)
                print(f"{file} loaded successfully")
                loaded_files.append(file)
        if len(docs) > 0:
            print("Files loaded, building the vector store")
            if vs_path and os.path.isdir(vs_path):
                try:
                    self.vector_store = FAISS.load_local(vs_path, text2vec)
                    self.vector_store.add_documents(docs)
                except Exception:
                    self.vector_store = FAISS.from_documents(docs, text2vec)
            else:
                self.vector_store = FAISS.from_documents(docs, text2vec)  # docs is a list of Document
            self.vector_store.save_local(vs_path)
            return vs_path, loaded_files
        else:
            raise RuntimeError("Failed to load files, please check the file format")
    def get_loaded_file(self, vs_path):
        # Return the set of source files currently in the docstore, relative to vs_path.
        ds = self.vector_store.docstore
        return set([ds._dict[k].metadata['source'].split(vs_path)[-1] for k in ds._dict])
    # query: the question text
    # vs_path: path of the knowledge base
    # chunk_conent: whether to enable context expansion around matches
    # score_threshold: score threshold for retrieval matches
    # vector_search_top_k: how many entries to retrieve from the knowledge base (default 5)
    # chunk_size: length of the merged context around each matched chunk
    def get_knowledge_based_conent_test(self, query, vs_path, chunk_conent,
                                        score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
                                        vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE,
                                        text2vec=None):
        self.vector_store = FAISS.load_local(vs_path, text2vec)
        self.vector_store.chunk_conent = chunk_conent
        self.vector_store.score_threshold = score_threshold
        self.vector_store.chunk_size = chunk_size
        embedding = self.vector_store.embedding_function.embed_query(query)
        related_docs_with_score = similarity_search_with_score_by_vector(self.vector_store, embedding, k=vector_search_top_k)
        if not related_docs_with_score:
            response = {"query": query, "source_documents": []}
            return response, ""
        # prompt = f"{query}. You should answer this question using information from following documents: \n\n"
        prompt = f"{query}. 你必须利用以下文档中包含的信息回答这个问题: \n\n---\n\n"
        prompt += "\n\n".join([f"({k}): " + doc.page_content for k, doc in enumerate(related_docs_with_score)])
        prompt += "\n\n---\n\n"
        prompt = prompt.encode('utf-8', 'ignore').decode()  # avoid reading non-utf8 chars
        # print(prompt)
        response = {"query": query, "source_documents": related_docs_with_score}
        return response, prompt
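# Usage sketch for LocalDocQA (illustrative; paths and file names are hypothetical):
#
#     qa = LocalDocQA()
#     qa.init_cfg()
#     vs_path, loaded = qa.init_knowledge_vector_store(
#         ["./docs/a.txt", "./docs/b.md"], "./vs/demo",
#         sentence_size=SENTENCE_SIZE, text2vec=text2vec)
#     resp, prompt = qa.get_knowledge_based_conent_test(
#         "some question", vs_path, chunk_conent=True, text2vec=text2vec)
#
# resp["source_documents"] holds the retrieved (and context-expanded) chunks, and
# prompt is the retrieval-augmented prompt ready to be sent to the LLM.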
def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
    for file in files:
        assert os.path.exists(file), "Input file does not exist: " + file
    import nltk
    if NLTK_DATA_PATH not in nltk.data.path:
        nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg()
    filelist = []
    if not os.path.exists(os.path.join(vs_path, vs_id)):
        os.makedirs(os.path.join(vs_path, vs_id))
    for file in files:
        file_name = file.name if not isinstance(file, str) else file
        filename = os.path.split(file_name)[-1]
        shutil.copyfile(file_name, os.path.join(vs_path, vs_id, filename))
        filelist.append(os.path.join(vs_path, vs_id, filename))
    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, os.path.join(vs_path, vs_id), sentence_size, text2vec)

    if len(loaded_files):
        file_status = f"Added {', '.join([os.path.split(i)[-1] for i in loaded_files if i])} to the knowledge base; it is loaded, you can start asking questions"
    else:
        pass
        # file_status = "Files failed to load, please re-upload"
    # print(file_status)
    return local_doc_qa, vs_path
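# Note: construct_vector_store copies every input file into {vs_path}/{vs_id}/ and then
# builds (or incrementally updates) the FAISS index inside that same directory, so the
# returned vs_path points at a self-contained, reloadable knowledge base.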
@Singleton  # one shared archive interface per process (Singleton is imported from toolbox above)
class knowledge_archive_interface():
    def __init__(self) -> None:
        self.threadLock = threading.Lock()
        self.current_id = ""
        self.kai_path = None
        self.qa_handle = None
        self.text2vec_large_chinese = None

    def get_chinese_text2vec(self):
        if self.text2vec_large_chinese is None:
            # < ------------------- warm up the text2vec embedding model --------------- >
            from toolbox import ProxyNetworkActivate
            print('Checking Text2vec ...')
            from langchain.embeddings.huggingface import HuggingFaceEmbeddings
            with ProxyNetworkActivate('Download_LLM'):  # temporarily activate the proxy network
                self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
        return self.text2vec_large_chinese
    def feed_archive(self, file_manifest, vs_path, id="default"):
        self.threadLock.acquire()
        self.current_id = id
        self.qa_handle, self.kai_path = construct_vector_store(
            vs_id=self.current_id,
            vs_path=vs_path,
            files=file_manifest,
            sentence_size=100,
            history=[],
            one_conent="",
            one_content_segmentation="",
            text2vec=self.get_chinese_text2vec(),
        )
        self.threadLock.release()
    def get_current_archive_id(self):
        return self.current_id

    def get_loaded_file(self, vs_path):
        return self.qa_handle.get_loaded_file(vs_path)
    def answer_with_archive_by_id(self, txt, id, vs_path):
        self.threadLock.acquire()
        if not self.current_id == id:
            # Switching to a different archive: reload its vector store from disk.
            self.current_id = id
            self.qa_handle, self.kai_path = construct_vector_store(
                vs_id=self.current_id,
                vs_path=vs_path,
                files=[],
                sentence_size=100,
                history=[],
                one_conent="",
                one_content_segmentation="",
                text2vec=self.get_chinese_text2vec(),
            )
        # Per-query retrieval settings; these locals deliberately override the module-level defaults.
        VECTOR_SEARCH_SCORE_THRESHOLD = 0
        VECTOR_SEARCH_TOP_K = 4
        CHUNK_SIZE = 512
        resp, prompt = self.qa_handle.get_knowledge_based_conent_test(
            query=txt,
            vs_path=self.kai_path,
            score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
            vector_search_top_k=VECTOR_SEARCH_TOP_K,
            chunk_conent=True,
            chunk_size=CHUNK_SIZE,
            text2vec=self.get_chinese_text2vec(),
        )
        self.threadLock.release()
        return resp, prompt
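# End-to-end sketch (assumes it runs inside the project so that `toolbox` and
# `crazy_functions` are importable; the manifest and paths are hypothetical):
if __name__ == "__main__":
    kai = knowledge_archive_interface()
    kai.feed_archive(file_manifest=["./docs/readme.md"], vs_path="./vs_store", id="demo")
    resp, prompt = kai.answer_with_archive_by_id("文档讲了什么?", "demo", "./vs_store")  # "What does the document say?"
    print(prompt)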