import asyncio
import sys
from typing import List, Optional

from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.chains.combine_documents.map_reduce import (
    ReduceDocumentsChain,
    MapReduceDocumentsChain,
)
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel

from configs import logger
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
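
# SummaryAdapter wraps a LangChain map-reduce pipeline: each document is
# summarised on its own (map step), then the partial summaries are merged
# into a single final summary (reduce step), keeping the per-document
# summaries as intermediate results.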


class SummaryAdapter:
    _OVERLAP_SIZE: int
    token_max: int
    _separator: str = "\n\n"
    chain: MapReduceDocumentsChain

    def __init__(self, overlap_size: int, token_max: int,
                 chain: MapReduceDocumentsChain):
        self._OVERLAP_SIZE = overlap_size
        self.chain = chain
        self.token_max = token_max

    @classmethod
    def form_summary(cls,
                     llm: BaseLanguageModel,
                     reduce_llm: BaseLanguageModel,
                     overlap_size: int,
                     token_max: int = 1300):
        """
        Build an instance wired with a map-reduce summarisation chain.
        :param llm: LLM used to generate the per-document summaries
        :param reduce_llm: LLM used to merge the per-document summaries
        :param overlap_size: size of the overlap between adjacent documents
        :param token_max: maximum chunk size; each chunk produced during the
            reduce step is kept below token_max tokens, and a first-pass
            summary longer than token_max raises an error
        :return: a SummaryAdapter instance
        """
        # This controls how each individual document is formatted before it
        # is inserted into the map prompt below.
        document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="{page_content}"
        )
        # The prompt here must take the `document_variable_name` as an
        # input variable.
        prompt_template = (
            "Carry out the task on the text. Task briefing: "
            "{task_briefing}"
            "The text is as follows: "
            "\r\n"
            "{context}"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["task_briefing", "context"]
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)

        # We now define how to combine the per-document summaries.
        reduce_prompt = PromptTemplate.from_template(
            "Combine these summaries: {context}"
        )
        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)

        document_variable_name = "context"
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_llm_chain,
            document_prompt=document_prompt,
            document_variable_name=document_variable_name
        )
        reduce_documents_chain = ReduceDocumentsChain(
            token_max=token_max,
            combine_documents_chain=combine_documents_chain,
        )
        chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
            reduce_documents_chain=reduce_documents_chain,
            # Also return the per-document summaries from the map step.
            return_intermediate_steps=True
        )
        return cls(overlap_size=overlap_size,
                   chain=chain,
                   token_max=token_max)
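
    # Usage sketch, assuming `llm` and `reduce_llm` are already-initialised
    # BaseLanguageModel instances (both names are placeholders):
    #
    #     adapter = SummaryAdapter.form_summary(llm=llm,
    #                                           reduce_llm=reduce_llm,
    #                                           overlap_size=50)
    #     summary_docs = adapter.summarize(file_description="...", docs=docs)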

    def summarize(self,
                  file_description: str,
                  docs: List[DocumentWithVSId] = []
                  ) -> List[Document]:
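        # Bridge from synchronous code into the async summariser: on
        # Python 3.10+, asyncio.get_event_loop() no longer creates a loop
        # implicitly, so fetch the running loop or create a fresh one.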
        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

        # Call the coroutine synchronously and block until it completes.
        return loop.run_until_complete(self.asummarize(file_description=file_description,
                                                       docs=docs))

    async def asummarize(self,
                         file_description: str,
                         docs: List[DocumentWithVSId] = []) -> List[Document]:

        logger.info("start summary")
        """
        The process has two stages (paraphrasing the LangChain source):
        1. Process each document individually to obtain its summary:
            map_results = self.llm_chain.apply(
                # FYI - this is parallelized and so it is fast.
                [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
                callbacks=callbacks,
            )
        2. Merge the per-document summaries into the final summary; with
           return_intermediate_steps=True the intermediate steps are
           returned as well:
            result, extra_return_dict = self.reduce_documents_chain.combine_docs(
                result_docs, token_max=token_max, callbacks=callbacks, **kwargs
            )
        """
        summary_combine, summary_intermediate_steps = self.chain.combine_docs(
            docs=docs,
            task_briefing="Describe the proximity and similarity between the "
                          "different methods, to help the reader understand "
                          "the relationships among them."
        )
        print(summary_combine)
        print(summary_intermediate_steps)
        # if len(summary_combine) == 0:
        #     # Empty result: regenerate from half as many intermediate steps.
        #     result_docs = [
        #         Document(page_content=question_result_key, metadata=docs[i].metadata)
        #         # This uses metadata from the docs, and the textual results from `results`
        #         for i, question_result_key in enumerate(
        #             summary_intermediate_steps["intermediate_steps"][
        #                 :len(summary_intermediate_steps["intermediate_steps"]) // 2
        #             ])
        #     ]
        #     summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
        #         result_docs, token_max=self.token_max
        #     )
| logger.info("end summary") | |
| doc_ids = ",".join([doc.id for doc in docs]) | |
| _metadata = { | |
| "file_description": file_description, | |
| "summary_intermediate_steps": summary_intermediate_steps, | |
| "doc_ids": doc_ids | |
| } | |
| summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata) | |
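        # The combined summary is returned as a single Document whose
        # metadata records the file description, the map-step intermediate
        # summaries, and the comma-joined source doc ids.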
        return [summary_combine_doc]

    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
        """
        Remove the overlapping spans between consecutive page_content values.
        :param docs: documents whose adjacent chunks share an overlap
        :return: the page_content strings with the overlaps dropped
        """
        merge_docs = []
        pre_doc = None
        for doc in docs:
            # Add the first document as-is.
            if len(merge_docs) == 0:
                pre_doc = doc.page_content
                merge_docs.append(doc.page_content)
                continue

            # The tail of the previous document overlaps the head of the
            # current one, so the overlapping head is dropped. Shrink
            # pre_doc from the front one character per iteration and look
            # for a match, until len(pre_doc) falls below
            # self._OVERLAP_SIZE // 2 - 2 * len(self._separator).
            for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
                # Drop the leading character each iteration.
                pre_doc = pre_doc[1:]
                if doc.page_content[:len(pre_doc)] == pre_doc:
                    # Drop the overlapping head of the current document.
                    merge_docs.append(doc.page_content[len(pre_doc):])
                    break
            else:
                # No overlap found: keep the current document unchanged.
                merge_docs.append(doc.page_content)

            pre_doc = doc.page_content

        return merge_docs
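
    # Worked example (sketch): with _OVERLAP_SIZE = 4 and two documents whose
    # page_content values are "abcdef" and "defghi", pre_doc shrinks from
    # "abcdef" to "def", which matches the head of "defghi"; the second
    # entry is trimmed to "ghi" and merge_docs becomes ["abcdef", "ghi"].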

    def _join_docs(self, docs: List[str]) -> Optional[str]:
        text = self._separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text


if __name__ == '__main__':
    # The sample fragments are kept in the original Chinese because the demo
    # depends on the exact characters shared between consecutive fragments.
    docs = [
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"
    merge_docs = []
    # Remove the overlapping spans between consecutive fragments: the tail
    # of the previous fragment matches the head of the next, and the
    # matching head is dropped from the next fragment.
    pre_doc = None
    for doc in docs:
        # Add the first fragment as-is.
        if len(merge_docs) == 0:
            pre_doc = doc
            merge_docs.append(doc)
            continue

        # Shrink pre_doc from the front one character per iteration and
        # look for a match, until len(pre_doc) falls below
        # _OVERLAP_SIZE - 2 * len(separator).
        for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
            # Drop the leading character each iteration.
            pre_doc = pre_doc[1:]
            if doc[:len(pre_doc)] == pre_doc:
                # Drop the overlapping head of the next fragment.
                page_content = doc[len(pre_doc):]
                merge_docs.append(page_content)
                pre_doc = doc
                break

    # Join the de-overlapped fragments into a single text.
    text = separator.join(merge_docs)
    text = text.strip()
    print(text)
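    # Expected behaviour: each fragment's leading overlap with its
    # predecessor ("梦内容的", "使他们很难想象", "值与可靠性作各") is dropped,
    # so the printed text reads as the original passage with "\n\n" breaks
    # in place of the repeated characters.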