import logging
import os
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

from langchain import chains
from langchain.memory import ConversationBufferWindowMemory

from financial_bot import constants
from financial_bot.chains import (
    ContextExtractorChain,
    FinancialBotQAChain,
    StatelessMemorySequentialChain,
)
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.handlers import CometLLMMonitoringHandler
from financial_bot.models import build_huggingface_pipeline
from financial_bot.qdrant import build_qdrant_client
from financial_bot.template import get_llm_template

logger = logging.getLogger(__name__)
class FinancialBot:
    """
    A language chain bot that uses a language model to generate responses to user inputs.

    Args:
        llm_model_id (str): The ID of the Hugging Face language model to use.
        llm_qlora_model_id (str): The ID of the Hugging Face QLoRA checkpoint to use.
        llm_template_name (str): The name of the LLM prompt template to use.
        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
        llm_inference_temperature (float): The temperature to use during inference.
        vector_collection_name (str): The name of the Qdrant vector collection to use.
        vector_db_search_topk (int): The number of nearest neighbors to search for in the Qdrant vector database.
        model_cache_dir (Path): The directory to use for caching the language model and embedding model.
        streaming (bool): Whether to use the Hugging Face streaming API for inference.
        embedding_model_device (str): The device to use for the embedding model.
        debug (bool): Whether to enable debug mode.

    Attributes:
        finbot_chain (Chain): The language chain that generates responses to user inputs.
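
    Example:
        A minimal, hypothetical call sequence (not from the original docs; it
        assumes a reachable Qdrant instance and the default constants):

            bot = FinancialBot(streaming=False)
            answer = bot.answer(
                about_me="I am a 28 year old engineer with a moderate risk appetite.",
                question="Should I keep my emergency fund in bonds or stocks?",
            )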
| """ | |
| def __init__( | |
| self, | |
| llm_model_id: str = constants.LLM_MODEL_ID, | |
| llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT, | |
| llm_template_name: str = constants.TEMPLATE_NAME, | |
| llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS, | |
| llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE, | |
| vector_collection_name: str = constants.VECTOR_DB_OUTPUT_COLLECTION_NAME, | |
| vector_db_search_topk: int = constants.VECTOR_DB_SEARCH_TOPK, | |
| model_cache_dir: Path = constants.CACHE_DIR, | |
| streaming: bool = False, | |
| embedding_model_device: str = "cuda:0", | |
| debug: bool = False, | |
| ): | |
| self._llm_model_id = llm_model_id | |
| self._llm_qlora_model_id = llm_qlora_model_id | |
| self._llm_template_name = llm_template_name | |
| self._llm_template = get_llm_template(name=self._llm_template_name) | |
| self._llm_inference_max_new_tokens = llm_inference_max_new_tokens | |
| self._llm_inference_temperature = llm_inference_temperature | |
| self._vector_collection_name = vector_collection_name | |
| self._vector_db_search_topk = vector_db_search_topk | |
| self._debug = debug | |
| self._qdrant_client = build_qdrant_client() | |
| self._embd_model = EmbeddingModelSingleton( | |
| cache_dir=model_cache_dir, device=embedding_model_device | |
| ) | |
| self._llm_agent, self._streamer = build_huggingface_pipeline( | |
| llm_model_id=llm_model_id, | |
| llm_lora_model_id=llm_qlora_model_id, | |
| max_new_tokens=llm_inference_max_new_tokens, | |
| temperature=llm_inference_temperature, | |
| use_streamer=streaming, | |
| cache_dir=model_cache_dir, | |
| debug=debug, | |
| ) | |
| self.finbot_chain = self.build_chain() | |

    @property
    def is_streaming(self) -> bool:
        return self._streamer is not None

    def build_chain(self) -> chains.SequentialChain:
        """
        Constructs and returns a financial bot chain.

        This chain takes the user description (`about_me`) and a `question` as input, connects to the
        vector DB, searches for financial news related to the user's question, and injects the retrieved
        context into the payload that is then passed as a prompt to a financial fine-tuned LLM, which
        generates the answer.

        The chain consists of two primary stages:

        1. Context Extractor: This stage is responsible for embedding the user's question,
        which means converting the textual question into a numerical representation.
        This embedded question is then used to retrieve relevant context from the VectorDB.
        The output of this stage is a dict payload.

        2. LLM Generator: Once the context is extracted,
        this stage uses it to format a full prompt for the LLM and
        then feeds it to the model to get a response that is relevant to the user's question.

        Returns
        -------
        chains.SequentialChain
            The constructed financial bot chain.

        Notes
        -----
        The actual processing flow within the chain can be visualized as:

        [about: str][question: str] > ContextChain >
        [about: str][question: str] + [context: str] > FinancialChain >
        [answer: str]
        """
| logger.info("Building 1/3 - ContextExtractorChain") | |
| context_retrieval_chain = ContextExtractorChain( | |
| embedding_model=self._embd_model, | |
| vector_store=self._qdrant_client, | |
| vector_collection=self._vector_collection_name, | |
| top_k=self._vector_db_search_topk, | |
| ) | |
| logger.info("Building 2/3 - FinancialBotQAChain") | |
| if self._debug: | |
| callabacks = [] | |
| else: | |
| try: | |
| comet_project_name = os.environ["COMET_PROJECT_NAME"] | |
| except KeyError: | |
| raise RuntimeError( | |
| "Please set the COMET_PROJECT_NAME environment variable." | |
| ) | |
| callabacks = [ | |
| CometLLMMonitoringHandler( | |
| project_name=f"{comet_project_name}-monitor-prompts", | |
| llm_model_id=self._llm_model_id, | |
| llm_qlora_model_id=self._llm_qlora_model_id, | |
| llm_inference_max_new_tokens=self._llm_inference_max_new_tokens, | |
| llm_inference_temperature=self._llm_inference_temperature, | |
| ) | |
| ] | |
| llm_generator_chain = FinancialBotQAChain( | |
| hf_pipeline=self._llm_agent, | |
| template=self._llm_template, | |
| callbacks=callabacks, | |
| ) | |
| logger.info("Building 3/3 - Connecting chains into SequentialChain") | |
| seq_chain = StatelessMemorySequentialChain( | |
| history_input_key="to_load_history", | |
| memory=ConversationBufferWindowMemory( | |
| memory_key="chat_history", | |
| input_key="question", | |
| output_key="answer", | |
| k=3, | |
| ), | |
| chains=[context_retrieval_chain, llm_generator_chain], | |
| input_variables=["about_me", "question", "to_load_history"], | |
| output_variables=["answer"], | |
| verbose=True, | |
| ) | |
| logger.info("Done building SequentialChain.") | |
| logger.info("Workflow:") | |
| logger.info( | |
| """ | |
| [about: str][question: str] > ContextChain > | |
| [about: str][question:str] + [context: str] > FinancialChain > | |
| [answer: str] | |
| """ | |
| ) | |
| return seq_chain | |

    def answer(
        self,
        about_me: str,
        question: str,
        to_load_history: Optional[List[Tuple[str, str]]] = None,
    ) -> str:
        """
        Given a short description about the user and a question, make the LLM
        generate a response.

        Parameters
        ----------
        about_me : str
            Short user description.
        question : str
            User question.
        to_load_history : Optional[List[Tuple[str, str]]]
            Optional conversation history to load into the chain's memory,
            given as (question, answer) pairs.

        Returns
        -------
        str
            LLM generated response.
        """

        inputs = {
            "about_me": about_me,
            "question": question,
            "to_load_history": to_load_history if to_load_history else [],
        }
        response = self.finbot_chain.run(inputs)

        return response

    def stream_answer(self) -> Iterable[str]:
        """Stream the answer from the LLM token by token while `answer()` is generating it."""

        assert (
            self.is_streaming
        ), "Stream answer not available. Build the bot with `streaming=True`."

        partial_answer = ""
        for new_token in self._streamer:
            if new_token != self._llm_template.eos:
                partial_answer += new_token

                yield partial_answer
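
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module). It assumes a
# reachable Qdrant instance, a CUDA device for the embedding model, and the
# COMET_PROJECT_NAME environment variable (unless debug=True). Because
# `answer()` blocks until generation finishes, the streaming variant runs it
# in a background thread and consumes `stream_answer()` on the main thread.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import threading

    bot = FinancialBot(streaming=True)
    kwargs = {
        "about_me": "I am a 30 year old engineer saving for retirement.",
        "question": "Is now a good time to rebalance into index funds?",
    }

    if bot.is_streaming:
        # Generate in the background and print partial answers as they arrive.
        worker = threading.Thread(target=bot.answer, kwargs=kwargs)
        worker.start()
        for partial_answer in bot.stream_answer():
            print(partial_answer, end="\r", flush=True)
        worker.join()
        print()
    else:
        print(bot.answer(**kwargs))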