Upload 22 files
- financial_bot/__init__.py +55 -0
- financial_bot/__pycache__/__init__.cpython-310.pyc +0 -0
- financial_bot/__pycache__/base.cpython-310.pyc +0 -0
- financial_bot/__pycache__/chains.cpython-310.pyc +0 -0
- financial_bot/__pycache__/constants.cpython-310.pyc +0 -0
- financial_bot/__pycache__/embeddings.cpython-310.pyc +0 -0
- financial_bot/__pycache__/handlers.cpython-310.pyc +0 -0
- financial_bot/__pycache__/langchain_bot.cpython-310.pyc +0 -0
- financial_bot/__pycache__/models.cpython-310.pyc +0 -0
- financial_bot/__pycache__/qdrant.cpython-310.pyc +0 -0
- financial_bot/__pycache__/template.cpython-310.pyc +0 -0
- financial_bot/__pycache__/utils.cpython-310.pyc +0 -0
- financial_bot/base.py +38 -0
- financial_bot/chains.py +226 -0
- financial_bot/constants.py +23 -0
- financial_bot/embeddings.py +123 -0
- financial_bot/handlers.py +64 -0
- financial_bot/langchain_bot.py +223 -0
- financial_bot/models.py +264 -0
- financial_bot/qdrant.py +49 -0
- financial_bot/template.py +132 -0
- financial_bot/utils.py +106 -0
financial_bot/__init__.py
ADDED
@@ -0,0 +1,55 @@

import logging
import logging.config
from pathlib import Path

import yaml
from dotenv import find_dotenv, load_dotenv

logger = logging.getLogger(__name__)


def initialize(logging_config_path: str = "logging.yaml", env_file_path: str = ".env"):
    """
    Initializes the logger and environment variables.

    Args:
        logging_config_path (str): The path to the logging configuration file. Defaults to "logging.yaml".
        env_file_path (str): The path to the environment variables file. Defaults to ".env".
    """

    logger.info("Initializing logger...")
    try:
        initialize_logger(config_path=logging_config_path)
    except FileNotFoundError:
        logger.warning(
            f"No logging configuration file found at: {logging_config_path}. Setting logging level to INFO."
        )
        logging.basicConfig(level=logging.INFO)

    logger.info("Initializing env vars...")
    if env_file_path is None:
        env_file_path = find_dotenv(raise_error_if_not_found=True, usecwd=False)

    logger.info(f"Loading environment variables from: {env_file_path}")
    found_env_file = load_dotenv(env_file_path, verbose=True, override=True)
    if found_env_file is False:
        raise RuntimeError(f"Could not find environment file at: {env_file_path}")


def initialize_logger(
    config_path: str = "logging.yaml", logs_dir_name: str = "logs"
) -> logging.Logger:
    """Initialize logger from a YAML config file."""

    # Create logs directory.
    config_path_parent = Path(config_path).parent
    logs_dir = config_path_parent / logs_dir_name
    logs_dir.mkdir(parents=True, exist_ok=True)

    with open(config_path, "rt") as f:
        config = yaml.safe_load(f.read())

    # Make sure that existing loggers will still work.
    config["disable_existing_loggers"] = False

    logging.config.dictConfig(config)

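A minimal usage sketch (not part of this commit): the Space's entrypoint is expected to call `initialize()` once at startup so logging and the `.env` secrets are configured before the bot is built. The file names below are the defaults assumed by the function, not values confirmed elsewhere in this diff.

# Hypothetical startup snippet (assumes logging.yaml and .env live next to the app).
from financial_bot import initialize

initialize(logging_config_path="logging.yaml", env_file_path=".env")
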
financial_bot/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.96 kB).
financial_bot/__pycache__/base.cpython-310.pyc
ADDED
Binary file (936 Bytes).
financial_bot/__pycache__/chains.cpython-310.pyc
ADDED
Binary file (6.98 kB).
financial_bot/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (720 Bytes).
financial_bot/__pycache__/embeddings.cpython-310.pyc
ADDED
Binary file (4.37 kB).
financial_bot/__pycache__/handlers.cpython-310.pyc
ADDED
Binary file (2.59 kB).
financial_bot/__pycache__/langchain_bot.cpython-310.pyc
ADDED
Binary file (7.71 kB).
financial_bot/__pycache__/models.cpython-310.pyc
ADDED
Binary file (8.25 kB).
financial_bot/__pycache__/qdrant.cpython-310.pyc
ADDED
Binary file (1.56 kB).
financial_bot/__pycache__/template.cpython-310.pyc
ADDED
Binary file (3.84 kB).
financial_bot/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.34 kB).
financial_bot/base.py
ADDED
@@ -0,0 +1,38 @@

from threading import Lock


class SingletonMeta(type):
    """
    This is a thread-safe implementation of Singleton.
    """

    _instances = {}

    _lock: Lock = Lock()
    """
    We now have a lock object that will be used to synchronize threads during
    first access to the Singleton.
    """

    def __call__(cls, *args, **kwargs):
        """
        Possible changes to the value of the `__init__` argument do not affect
        the returned instance.
        """
        # Now, imagine that the program has just been launched. Since there's no
        # Singleton instance yet, multiple threads can simultaneously pass the
        # previous conditional and reach this point almost at the same time. The
        # first of them will acquire the lock and will proceed further, while the
        # rest will wait here.
        with cls._lock:
            # The first thread to acquire the lock, reaches this conditional,
            # goes inside and creates the Singleton instance. Once it leaves the
            # lock block, a thread that might have been waiting for the lock
            # release may then enter this section. But since the Singleton field
            # is already initialized, the thread won't create a new object.
            if cls not in cls._instances:
                instance = super().__call__(*args, **kwargs)
                cls._instances[cls] = instance

        return cls._instances[cls]

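A small illustrative check (not part of the commit) of what the metaclass guarantees: every construction of a class that uses `SingletonMeta` returns the same object, so heavy resources such as the embedding model are loaded only once per process. `Expensive` is a toy class invented for the example.

from financial_bot.base import SingletonMeta

class Expensive(metaclass=SingletonMeta):  # toy class, for illustration only
    def __init__(self, value: int):
        self.value = value

a = Expensive(1)
b = Expensive(2)  # constructor arguments are ignored after the first call
assert a is b and a.value == 1
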
financial_bot/chains.py
ADDED
@@ -0,0 +1,226 @@

import time
from typing import Any, Dict, List, Optional

import qdrant_client
from langchain import chains
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.llms import HuggingFacePipeline
from unstructured.cleaners.core import (
    clean,
    clean_extra_whitespace,
    clean_non_ascii_chars,
    group_broken_paragraphs,
    replace_unicode_quotes,
)

from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.template import PromptTemplate


class StatelessMemorySequentialChain(chains.SequentialChain):
    """
    A sequential chain that uses a stateless memory to store context between calls.

    This chain overrides the _call and prep_outputs methods to load and clear the memory
    before and after each call, respectively.
    """

    history_input_key: str = "to_load_history"

    def _call(self, inputs: Dict[str, str], **kwargs) -> Dict[str, str]:
        """
        Override _call to load history before calling the chain.

        This method loads the history from the input dictionary and saves it to the
        stateless memory. It then updates the inputs dictionary with the memory values
        and removes the history input key. Finally, it calls the parent _call method
        with the updated inputs and returns the results.
        """

        to_load_history = inputs[self.history_input_key]
        for (
            human,
            ai,
        ) in to_load_history:
            self.memory.save_context(
                inputs={self.memory.input_key: human},
                outputs={self.memory.output_key: ai},
            )
        memory_values = self.memory.load_memory_variables({})
        inputs.update(memory_values)

        del inputs[self.history_input_key]

        return super()._call(inputs, **kwargs)

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """
        Override prep_outputs to clear the internal memory after each call.

        This method calls the parent prep_outputs method to get the results, then
        clears the stateless memory and removes the memory key from the results
        dictionary. It then returns the updated results.
        """

        results = super().prep_outputs(inputs, outputs, return_only_outputs)

        # Clear the internal memory.
        self.memory.clear()
        if self.memory.memory_key in results:
            results[self.memory.memory_key] = ""

        return results


class ContextExtractorChain(Chain):
    """
    Encode the question, search the vector store for top-k articles and return
    context news from documents collection of Alpaca news.

    Attributes:
    -----------
    top_k : int
        The number of top matches to retrieve from the vector store.
    embedding_model : EmbeddingModelSingleton
        The embedding model to use for encoding the question.
    vector_store : qdrant_client.QdrantClient
        The vector store to search for matches.
    vector_collection : str
        The name of the collection to search in the vector store.
    """

    top_k: int = 1
    embedding_model: EmbeddingModelSingleton
    vector_store: qdrant_client.QdrantClient
    vector_collection: str

    @property
    def input_keys(self) -> List[str]:
        return ["about_me", "question"]

    @property
    def output_keys(self) -> List[str]:
        return ["context"]

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        _, quest_key = self.input_keys
        question_str = inputs[quest_key]

        cleaned_question = self.clean(question_str)
        # TODO: Instead of cutting the question at 'max_input_length', chunk the question in 'max_input_length' chunks,
        # pass them through the model and average the embeddings.
        cleaned_question = cleaned_question[: self.embedding_model.max_input_length]
        embeddings = self.embedding_model(cleaned_question)

        # TODO: Using the metadata, use the filter to take into consideration only the news from the last 24 hours
        # (or other time frame).
        matches = self.vector_store.search(
            query_vector=embeddings,
            k=self.top_k,
            collection_name=self.vector_collection,
        )

        context = ""
        for match in matches:
            context += match.payload["summary"] + "\n"

        return {
            "context": context,
        }

    def clean(self, question: str) -> str:
        """
        Clean the input question by removing unwanted characters.

        Parameters:
        -----------
        question : str
            The input question to clean.

        Returns:
        --------
        str
            The cleaned question.
        """
        question = clean(question)
        question = replace_unicode_quotes(question)
        question = clean_non_ascii_chars(question)

        return question


class FinancialBotQAChain(Chain):
    """This custom chain handles LLM generation upon given prompt"""

    hf_pipeline: HuggingFacePipeline
    template: PromptTemplate

    @property
    def input_keys(self) -> List[str]:
        """Returns a list of input keys for the chain"""

        return ["context"]

    @property
    def output_keys(self) -> List[str]:
        """Returns a list of output keys for the chain"""

        return ["answer"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Calls the chain with the given inputs and returns the output"""

        inputs = self.clean(inputs)
        prompt = self.template.format_infer(
            {
                "user_context": inputs["about_me"],
                "news_context": inputs["context"],
                "chat_history": inputs["chat_history"],
                "question": inputs["question"],
            }
        )

        start_time = time.time()
        response = self.hf_pipeline(prompt["prompt"])
        end_time = time.time()
        duration_milliseconds = (end_time - start_time) * 1000

        if run_manager:
            run_manager.on_chain_end(
                outputs={
                    "answer": response,
                },
                # TODO: Count tokens instead of using len().
                metadata={
                    "prompt": prompt["prompt"],
                    "prompt_template_variables": prompt["payload"],
                    "prompt_template": self.template.infer_raw_template,
                    "usage.prompt_tokens": len(prompt["prompt"]),
                    "usage.total_tokens": len(prompt["prompt"]) + len(response),
                    "usage.actual_new_tokens": len(response),
                    "duration_milliseconds": duration_milliseconds,
                },
            )

        return {"answer": response}

    def clean(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Cleans the inputs by removing extra whitespace and grouping broken paragraphs"""

        for key, input in inputs.items():
            cleaned_input = clean_extra_whitespace(input)
            cleaned_input = group_broken_paragraphs(cleaned_input)

            inputs[key] = cleaned_input

        return inputs

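A hedged sketch (not part of the commit) of driving `ContextExtractorChain` on its own. It assumes a reachable Qdrant collection named as in `constants.py`, the `QDRANT_URL`/`QDRANT_API_KEY` environment variables, and a GPU for the default embedding-model device; the example question and user description are invented.

from financial_bot import constants
from financial_bot.chains import ContextExtractorChain
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.qdrant import build_qdrant_client

context_chain = ContextExtractorChain(
    embedding_model=EmbeddingModelSingleton(),
    vector_store=build_qdrant_client(),
    vector_collection=constants.VECTOR_DB_OUTPUT_COLLECTION_NAME,
    top_k=constants.VECTOR_DB_SEARCH_TOPK,
)
result = context_chain(
    {"about_me": "I am a long-term investor.", "question": "What is new about Tesla?"}
)
print(result["context"])  # concatenated news summaries retrieved from Qdrant
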
financial_bot/constants.py
ADDED
@@ -0,0 +1,23 @@

from pathlib import Path

# == Embeddings model ==
EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_MODEL_MAX_INPUT_LENGTH = 384

# == VECTOR Database ==
VECTOR_DB_OUTPUT_COLLECTION_NAME = "alpaca_financial_news"
VECTOR_DB_SEARCH_TOPK = 1

# == LLM Model ==
LLM_MODEL_ID = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
LLM_QLORA_CHECKPOINT = "plantbased/mistral-7b-instruct-v0.2-4bit"

LLM_INFERNECE_MAX_NEW_TOKENS = 500
LLM_INFERENCE_TEMPERATURE = 1.0


# == Prompt Template ==
TEMPLATE_NAME = "mistral"

# === Misc ===
CACHE_DIR = Path.home() / ".cache" / "hands-on-llms"

financial_bot/embeddings.py
ADDED
@@ -0,0 +1,123 @@

import logging
import traceback
from typing import Optional, Union

import numpy as np
from transformers import AutoModel, AutoTokenizer

from financial_bot import constants
from financial_bot.base import SingletonMeta

logger = logging.getLogger(__name__)


class EmbeddingModelSingleton(metaclass=SingletonMeta):
    """
    A singleton class that provides a pre-trained transformer model for generating embeddings of input text.

    Args:
        model_id (str): The identifier of the pre-trained transformer model to use.
        max_input_length (int): The maximum length of input text to tokenize.
        device (str): The device to use for running the model (e.g. "cpu", "cuda").
        cache_dir (Optional[Path]): The directory to cache the pre-trained model files.
            If None, the default cache directory is used.

    Attributes:
        max_input_length (int): The maximum length of input text to tokenize.
        tokenizer (AutoTokenizer): The tokenizer used to tokenize input text.
    """

    def __init__(
        self,
        model_id: str = constants.EMBEDDING_MODEL_ID,
        max_input_length: int = constants.EMBEDDING_MODEL_MAX_INPUT_LENGTH,
        device: str = "cuda:0",
        cache_dir: Optional[str] = None,
    ):
        """
        Initializes the EmbeddingModelSingleton instance.

        Args:
            model_id (str): The identifier of the pre-trained transformer model to use.
            max_input_length (int): The maximum length of input text to tokenize.
            device (str): The device to use for running the model (e.g. "cpu", "cuda").
            cache_dir (Optional[Path]): The directory to cache the pre-trained model files.
                If None, the default cache directory is used.
        """

        self._model_id = model_id
        self._device = device
        self._max_input_length = max_input_length

        self._tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._model = AutoModel.from_pretrained(
            model_id,
            cache_dir=str(cache_dir) if cache_dir else None,
        ).to(self._device)
        self._model.eval()

    @property
    def max_input_length(self) -> int:
        """
        Returns the maximum length of input text to tokenize.

        Returns:
            int: The maximum length of input text to tokenize.
        """

        return self._max_input_length

    @property
    def tokenizer(self) -> AutoTokenizer:
        """
        Returns the tokenizer used to tokenize input text.

        Returns:
            AutoTokenizer: The tokenizer used to tokenize input text.
        """

        return self._tokenizer

    def __call__(
        self, input_text: str, to_list: bool = True
    ) -> Union[np.ndarray, list]:
        """
        Generates embeddings for the input text using the pre-trained transformer model.

        Args:
            input_text (str): The input text to generate embeddings for.
            to_list (bool): Whether to return the embeddings as a list or numpy array. Defaults to True.

        Returns:
            Union[np.ndarray, list]: The embeddings generated for the input text.
        """

        try:
            tokenized_text = self._tokenizer(
                input_text,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=self._max_input_length,
            ).to(self._device)
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(f"Error tokenizing the following input text: {input_text}")

            return [] if to_list else np.array([])

        try:
            result = self._model(**tokenized_text)
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(
                f"Error generating embeddings for the following model_id: {self._model_id} and input text: {input_text}"
            )

            return [] if to_list else np.array([])

        embeddings = result.last_hidden_state[:, 0, :].cpu().detach().numpy()
        if to_list:
            embeddings = embeddings.flatten().tolist()

        return embeddings

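A quick illustrative call (not part of the commit), assuming a CUDA device is available as per the default `device="cuda:0"`; the input sentence is invented.

from financial_bot.embeddings import EmbeddingModelSingleton

model = EmbeddingModelSingleton()          # loads all-MiniLM-L6-v2 once per process
vector = model("Tesla reported record quarterly deliveries.")
print(len(vector))                         # 384-dimensional list with the default to_list=True
assert model is EmbeddingModelSingleton()  # SingletonMeta returns the cached instance
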
financial_bot/handlers.py
ADDED
@@ -0,0 +1,64 @@

from typing import Any, Dict

import comet_llm
from langchain.callbacks.base import BaseCallbackHandler

from financial_bot import constants


class CometLLMMonitoringHandler(BaseCallbackHandler):
    """
    A callback handler for monitoring LLM models using Comet.ml.

    Args:
        project_name (str): The name of the Comet.ml project to log to.
        llm_model_id (str): The ID of the LLM model to use for inference.
        llm_qlora_model_id (str): The ID of the PEFT model to use for inference.
        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
        llm_inference_temperature (float): The temperature to use during inference.
    """

    def __init__(
        self,
        project_name: str = None,
        llm_model_id: str = constants.LLM_MODEL_ID,
        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
        llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
    ):
        self._project_name = project_name
        self._llm_model_id = llm_model_id
        self._llm_qlora_model_id = llm_qlora_model_id
        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
        self._llm_inference_temperature = llm_inference_temperature

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """
        A callback function that logs the prompt and output to Comet.ml.

        Args:
            outputs (Dict[str, Any]): The output of the LLM model.
            **kwargs (Any): Additional arguments passed to the function.
        """

        should_log_prompt = "metadata" in kwargs
        if should_log_prompt:
            metadata = kwargs["metadata"]

            comet_llm.log_prompt(
                project=self._project_name,
                prompt=metadata["prompt"],
                output=outputs["answer"],
                prompt_template=metadata["prompt_template"],
                prompt_template_variables=metadata["prompt_template_variables"],
                metadata={
                    "usage.prompt_tokens": metadata["usage.prompt_tokens"],
                    "usage.total_tokens": metadata["usage.total_tokens"],
                    "usage.max_new_tokens": self._llm_inference_max_new_tokens,
                    "usage.temperature": self._llm_inference_temperature,
                    "usage.actual_new_tokens": metadata["usage.actual_new_tokens"],
                    "model": self._llm_model_id,
                    "peft_model": self._llm_qlora_model_id,
                },
                duration=metadata["duration_milliseconds"],
            )

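Note that the handler only logs when an `on_chain_end` call carries a `metadata` kwarg with the keys produced by `FinancialBotQAChain._call`. A sketch (not part of the commit) of attaching it; the environment variables are assumptions about the Comet setup.

import os

from financial_bot.handlers import CometLLMMonitoringHandler

# Assumes COMET_API_KEY / COMET_WORKSPACE / COMET_PROJECT_NAME are set for comet_llm.
handler = CometLLMMonitoringHandler(
    project_name=f"{os.environ['COMET_PROJECT_NAME']}-monitor-prompts"
)
# FinancialBotQAChain is built with callbacks=[handler]; its _call() forwards
# "prompt", "prompt_template", "prompt_template_variables", the "usage.*" counters
# and "duration_milliseconds" through run_manager.on_chain_end(..., metadata=...).
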
financial_bot/langchain_bot.py
ADDED
@@ -0,0 +1,223 @@

import logging
import os
from pathlib import Path
from typing import Iterable, List, Tuple

from langchain import chains
from langchain.memory import ConversationBufferWindowMemory

from financial_bot import constants
from financial_bot.chains import (
    ContextExtractorChain,
    FinancialBotQAChain,
    StatelessMemorySequentialChain,
)
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.handlers import CometLLMMonitoringHandler
from financial_bot.models import build_huggingface_pipeline
from financial_bot.qdrant import build_qdrant_client
from financial_bot.template import get_llm_template

logger = logging.getLogger(__name__)


class FinancialBot:
    """
    A language chain bot that uses a language model to generate responses to user inputs.

    Args:
        llm_model_id (str): The ID of the Hugging Face language model to use.
        llm_qlora_model_id (str): The ID of the Hugging Face QLoRA model to use.
        llm_template_name (str): The name of the LLM template to use.
        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
        llm_inference_temperature (float): The temperature to use during inference.
        vector_collection_name (str): The name of the Qdrant vector collection to use.
        vector_db_search_topk (int): The number of nearest neighbors to search for in the Qdrant vector database.
        model_cache_dir (Path): The directory to use for caching the language model and embedding model.
        streaming (bool): Whether to use the Hugging Face streaming API for inference.
        embedding_model_device (str): The device to use for the embedding model.
        debug (bool): Whether to enable debug mode.

    Attributes:
        finbot_chain (Chain): The language chain that generates responses to user inputs.
    """

    def __init__(
        self,
        llm_model_id: str = constants.LLM_MODEL_ID,
        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
        llm_template_name: str = constants.TEMPLATE_NAME,
        llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
        vector_collection_name: str = constants.VECTOR_DB_OUTPUT_COLLECTION_NAME,
        vector_db_search_topk: int = constants.VECTOR_DB_SEARCH_TOPK,
        model_cache_dir: Path = constants.CACHE_DIR,
        streaming: bool = False,
        embedding_model_device: str = "cuda:0",
        debug: bool = False,
    ):
        self._llm_model_id = llm_model_id
        self._llm_qlora_model_id = llm_qlora_model_id
        self._llm_template_name = llm_template_name
        self._llm_template = get_llm_template(name=self._llm_template_name)
        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
        self._llm_inference_temperature = llm_inference_temperature
        self._vector_collection_name = vector_collection_name
        self._vector_db_search_topk = vector_db_search_topk
        self._debug = debug

        self._qdrant_client = build_qdrant_client()

        self._embd_model = EmbeddingModelSingleton(
            cache_dir=model_cache_dir, device=embedding_model_device
        )
        self._llm_agent, self._streamer = build_huggingface_pipeline(
            llm_model_id=llm_model_id,
            llm_lora_model_id=llm_qlora_model_id,
            max_new_tokens=llm_inference_max_new_tokens,
            temperature=llm_inference_temperature,
            use_streamer=streaming,
            cache_dir=model_cache_dir,
            debug=debug,
        )
        self.finbot_chain = self.build_chain()

    @property
    def is_streaming(self) -> bool:
        return self._streamer is not None

    def build_chain(self) -> chains.SequentialChain:
        """
        Constructs and returns a financial bot chain.

        This chain takes the user description (`about_me`) and a `question` as input, connects to the
        VectorDB, searches for financial news relevant to the user's question, and injects the results
        into the payload that is passed as a prompt to a financial fine-tuned LLM, which provides the answer.

        The chain consists of two primary stages:
        1. Context Extractor: This stage is responsible for embedding the user's question,
        which means converting the textual question into a numerical representation.
        This embedded question is then used to retrieve relevant context from the VectorDB.
        The output of this chain will be a dict payload.

        2. LLM Generator: Once the context is extracted,
        this stage uses it to format a full prompt for the LLM and
        then feeds it to the model to get a response that is relevant to the user's question.

        Returns
        -------
        chains.SequentialChain
            The constructed financial bot chain.

        Notes
        -----
        The actual processing flow within the chain can be visualized as:
        [about: str][question: str] > ContextChain >
        [about: str][question:str] + [context: str] > FinancialChain >
        [answer: str]
        """

        logger.info("Building 1/3 - ContextExtractorChain")
        context_retrieval_chain = ContextExtractorChain(
            embedding_model=self._embd_model,
            vector_store=self._qdrant_client,
            vector_collection=self._vector_collection_name,
            top_k=self._vector_db_search_topk,
        )

        logger.info("Building 2/3 - FinancialBotQAChain")
        if self._debug:
            callbacks = []
        else:
            try:
                comet_project_name = os.environ["COMET_PROJECT_NAME"]
            except KeyError:
                raise RuntimeError(
                    "Please set the COMET_PROJECT_NAME environment variable."
                )
            callbacks = [
                CometLLMMonitoringHandler(
                    project_name=f"{comet_project_name}-monitor-prompts",
                    llm_model_id=self._llm_model_id,
                    llm_qlora_model_id=self._llm_qlora_model_id,
                    llm_inference_max_new_tokens=self._llm_inference_max_new_tokens,
                    llm_inference_temperature=self._llm_inference_temperature,
                )
            ]
        llm_generator_chain = FinancialBotQAChain(
            hf_pipeline=self._llm_agent,
            template=self._llm_template,
            callbacks=callbacks,
        )

        logger.info("Building 3/3 - Connecting chains into SequentialChain")
        seq_chain = StatelessMemorySequentialChain(
            history_input_key="to_load_history",
            memory=ConversationBufferWindowMemory(
                memory_key="chat_history",
                input_key="question",
                output_key="answer",
                k=3,
            ),
            chains=[context_retrieval_chain, llm_generator_chain],
            input_variables=["about_me", "question", "to_load_history"],
            output_variables=["answer"],
            verbose=True,
        )

        logger.info("Done building SequentialChain.")
        logger.info("Workflow:")
        logger.info(
            """
            [about: str][question: str] > ContextChain >
            [about: str][question:str] + [context: str] > FinancialChain >
            [answer: str]
            """
        )

        return seq_chain

    def answer(
        self,
        about_me: str,
        question: str,
        to_load_history: List[Tuple[str, str]] = None,
    ) -> str:
        """
        Given a short description about the user and a question, make the LLM
        generate a response.

        Parameters
        ----------
        about_me : str
            Short user description.
        question : str
            User question.

        Returns
        -------
        str
            LLM generated response.
        """

        inputs = {
            "about_me": about_me,
            "question": question,
            "to_load_history": to_load_history if to_load_history else [],
        }
        response = self.finbot_chain.run(inputs)

        return response

    def stream_answer(self) -> Iterable[str]:
        """Stream the answer from the LLM after each token is generated after calling `answer()`."""

        assert (
            self.is_streaming
        ), "Stream answer not available. Build the bot with `use_streamer=True`."

        partial_answer = ""
        for new_token in self._streamer:
            if new_token != self._llm_template.eos:
                partial_answer += new_token

                yield partial_answer

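A minimal end-to-end sketch (not part of the commit) of how an app is expected to drive `FinancialBot`. With `debug=True` the real LLM is swapped for the mocked pipeline, so only the Qdrant and embedding pieces need real credentials here; the user description and question are invented.

from financial_bot import initialize
from financial_bot.langchain_bot import FinancialBot

initialize()  # logging + .env (QDRANT_URL, QDRANT_API_KEY, COMET_* ...)

bot = FinancialBot(debug=True)  # mocked LLM; set debug=False to load the real QLoRA model
answer = bot.answer(
    about_me="I am a 30 year old engineer who invests monthly.",
    question="Should I be worried about rising interest rates?",
    to_load_history=[],
)
print(answer)
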
financial_bot/models.py
ADDED
@@ -0,0 +1,264 @@

import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from comet_ml import API
from langchain.llms import HuggingFacePipeline
from peft import LoraConfig, PeftConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    pipeline,
)

from financial_bot import constants
from financial_bot.utils import MockedPipeline

logger = logging.getLogger(__name__)


def download_from_model_registry(
    model_id: str, cache_dir: Optional[Path] = None
) -> Path:
    """
    Downloads a model from the Comet ML model registry.

    Args:
        model_id (str): The ID of the model to download, in the format "workspace/model_name:version".
        cache_dir (Optional[Path]): The directory to cache the downloaded model in. Defaults to the value of
            `constants.CACHE_DIR`.

    Returns:
        Path: The path to the downloaded model directory.
    """

    if cache_dir is None:
        cache_dir = constants.CACHE_DIR
    output_folder = cache_dir / "models" / model_id

    already_downloaded = output_folder.exists()
    if not already_downloaded:
        workspace, model_id = model_id.split("/")
        model_name, version = model_id.split(":")

        api = API()
        model = api.get_model(workspace=workspace, model_name=model_name)
        model.download(version=version, output_folder=output_folder, expand=True)
    else:
        logger.info(f"Model {model_id=} already downloaded to: {output_folder}")

    subdirs = [d for d in output_folder.iterdir() if d.is_dir()]
    if len(subdirs) == 1:
        model_dir = subdirs[0]
    else:
        raise RuntimeError(
            f"There should be only one directory inside the model folder. \
                Check the downloaded model at: {output_folder}"
        )

    logger.info(f"Model {model_id=} downloaded from the registry to: {model_dir}")

    return model_dir


class StopOnTokens(StoppingCriteria):
    """
    A stopping criteria that stops generation when a specific token is generated.

    Args:
        stop_ids (List[int]): A list of token ids that will trigger the stopping criteria.
    """

    def __init__(self, stop_ids: List[int]):
        super().__init__()

        self._stop_ids = stop_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """
        Check if the last generated token is in the stop_ids list.

        Args:
            input_ids (torch.LongTensor): The input token ids.
            scores (torch.FloatTensor): The scores of the generated tokens.

        Returns:
            bool: True if the last generated token is in the stop_ids list, False otherwise.
        """

        for stop_id in self._stop_ids:
            if input_ids[0][-1] == stop_id:
                return True

        return False


def build_huggingface_pipeline(
    llm_model_id: str,
    llm_lora_model_id: str,
    max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
    temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
    gradient_checkpointing: bool = False,
    use_streamer: bool = False,
    cache_dir: Optional[Path] = None,
    debug: bool = False,
) -> Tuple[HuggingFacePipeline, Optional[TextIteratorStreamer]]:
    """
    Builds a HuggingFace pipeline for text generation using a custom LLM + finetuned checkpoint.

    Args:
        llm_model_id (str): The ID or path of the LLM model.
        llm_lora_model_id (str): The ID or path of the LLM LoRA model.
        max_new_tokens (int, optional): The maximum number of new tokens to generate.
            Defaults to `constants.LLM_INFERNECE_MAX_NEW_TOKENS`.
        temperature (float, optional): The temperature to use for sampling.
            Defaults to `constants.LLM_INFERENCE_TEMPERATURE`.
        gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
        use_streamer (bool, optional): Whether to use a text iterator streamer. Defaults to False.
        cache_dir (Optional[Path], optional): The directory to use for caching. Defaults to None.
        debug (bool, optional): Whether to use a mocked pipeline for debugging. Defaults to False.

    Returns:
        Tuple[HuggingFacePipeline, Optional[TextIteratorStreamer]]: A tuple containing the HuggingFace pipeline
            and the text iterator streamer (if used).
    """

    if debug is True:
        return (
            HuggingFacePipeline(
                pipeline=MockedPipeline(f=lambda _: "You are doing great!")
            ),
            None,
        )

    model, tokenizer, _ = build_qlora_model(
        pretrained_model_name_or_path=llm_model_id,
        peft_pretrained_model_name_or_path=llm_lora_model_id,
        gradient_checkpointing=gradient_checkpointing,
        cache_dir=cache_dir,
    )
    model.eval()

    if use_streamer:
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        stop_on_tokens = StopOnTokens(stop_ids=[tokenizer.eos_token_id])
        stopping_criteria = StoppingCriteriaList([stop_on_tokens])
    else:
        streamer = None
        stopping_criteria = StoppingCriteriaList([])

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
    )
    hf = HuggingFacePipeline(pipeline=pipe)

    return hf, streamer


def build_qlora_model(
    pretrained_model_name_or_path: str = "tiiuae/falcon-7b-instruct",
    peft_pretrained_model_name_or_path: Optional[str] = None,
    gradient_checkpointing: bool = True,
    cache_dir: Optional[Path] = None,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer, PeftConfig]:
    """
    Function that builds a QLoRA LLM model based on the given HuggingFace name:
        1. Create and prepare the bitsandbytes configuration for QLoRA's quantization
        2. Download, load, and quantize the base model on the fly
        3. Create and prepare the LoRA configuration
        4. Load and configure the base model's tokenizer

    Args:
        pretrained_model_name_or_path (str): The name or path of the pretrained model to use.
        peft_pretrained_model_name_or_path (Optional[str]): The name or path of the PEFT pretrained model to use.
        gradient_checkpointing (bool): Whether to use gradient checkpointing or not.
        cache_dir (Optional[Path]): The directory to cache the downloaded models.

    Returns:
        Tuple[AutoModelForCausalLM, AutoTokenizer, PeftConfig]:
            A tuple containing the QLoRA LLM model, tokenizer, and PEFT config.
    """

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        revision="main",
        quantization_config=bnb_config,
        load_in_4bit=True,
        device_map="auto",
        trust_remote_code=False,
        cache_dir=str(cache_dir) if cache_dir else None,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=False,
        truncation=True,
        cache_dir=str(cache_dir) if cache_dir else None,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        with torch.no_grad():
            model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id

    if peft_pretrained_model_name_or_path:
        is_model_name = not os.path.isdir(peft_pretrained_model_name_or_path)
        if is_model_name:
            logger.info(
                f"Downloading {peft_pretrained_model_name_or_path} from the Comet ML Model Registry:"
            )
            peft_pretrained_model_name_or_path = download_from_model_registry(
                model_id=peft_pretrained_model_name_or_path,
                cache_dir=cache_dir,
            )

        logger.info(f"Loading LoRA config from: {peft_pretrained_model_name_or_path}")
        lora_config = LoraConfig.from_pretrained(peft_pretrained_model_name_or_path)
        assert (
            lora_config.base_model_name_or_path == pretrained_model_name_or_path
        ), f"LoRA model trained on a different base model than the one requested: \
            {lora_config.base_model_name_or_path} != {pretrained_model_name_or_path}"

        logger.info(f"Loading PEFT model from: {peft_pretrained_model_name_or_path}")
        model = PeftModel.from_pretrained(model, peft_pretrained_model_name_or_path)
    else:
        lora_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["query_key_value"],
        )

    if gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = (
            False  # Gradient checkpointing is not compatible with caching.
        )
    else:
        model.gradient_checkpointing_disable()
        model.config.use_cache = True  # It is good practice to enable caching when using the model for inference.

    return model, tokenizer, lora_config

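An illustrative call (not part of the commit) showing the debug path of `build_huggingface_pipeline`, which avoids downloading any weights and is useful for wiring tests:

from financial_bot import constants
from financial_bot.models import build_huggingface_pipeline

hf_pipeline, streamer = build_huggingface_pipeline(
    llm_model_id=constants.LLM_MODEL_ID,
    llm_lora_model_id=constants.LLM_QLORA_CHECKPOINT,
    debug=True,  # returns a HuggingFacePipeline wrapping MockedPipeline; streamer is None
)
print(hf_pipeline("Hello"))  # the mocked pipeline always answers "You are doing great!"
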
financial_bot/qdrant.py
ADDED
@@ -0,0 +1,49 @@

import logging
import os
from typing import Optional

import qdrant_client

logger = logging.getLogger(__name__)


def build_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
):
    """
    Builds a Qdrant client object using the provided URL and API key.

    Args:
        url (Optional[str]): The URL of the Qdrant server. If not provided, the function will attempt
            to read it from the QDRANT_URL environment variable.
        api_key (Optional[str]): The API key to use for authentication. If not provided, the function will attempt
            to read it from the QDRANT_API_KEY environment variable.

    Raises:
        KeyError: If the URL or API key is not provided and cannot be read from the environment variables.

    Returns:
        qdrant_client.QdrantClient: A Qdrant client object.
    """

    logger.info("Building QDrant Client")
    if url is None:
        try:
            url = os.environ["QDRANT_URL"]
        except KeyError:
            raise KeyError(
                "QDRANT_URL must be set as environment variable or manually passed as an argument."
            )

    if api_key is None:
        try:
            api_key = os.environ["QDRANT_API_KEY"]
        except KeyError:
            raise KeyError(
                "QDRANT_API_KEY must be set as environment variable or manually passed as an argument."
            )

    client = qdrant_client.QdrantClient(url, api_key=api_key)

    return client

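Usage is a one-liner once the two environment variables are exported; a sketch (not part of the commit) with placeholder values:

import os

from financial_bot.qdrant import build_qdrant_client

os.environ.setdefault("QDRANT_URL", "https://your-cluster.qdrant.io:6333")  # placeholder
os.environ.setdefault("QDRANT_API_KEY", "xxx")                              # placeholder
client = build_qdrant_client()
print(client.get_collections())
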
financial_bot/template.py
ADDED
@@ -0,0 +1,132 @@

"""
This script defines a PromptTemplate class that assists in generating
conversation/prompt templates. The script facilitates formatting prompts
for inference and training by combining various context elements and user inputs.
"""


import dataclasses
from typing import Dict, List, Union


@dataclasses.dataclass
class PromptTemplate:
    """A class that manages prompt templates"""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The template for the system context
    context_template: str = "{user_context}\n{news_context}"
    # The template for the conversation history
    chat_history_template: str = "{chat_history}"
    # The template of the user question
    question_template: str = "{question}"
    # The template of the system answer
    answer_template: str = "{answer}"
    # The system message
    system_message: str = ""
    # Separator
    sep: str = "\n"
    eos: str = "</s>"

    @property
    def input_variables(self) -> List[str]:
        """Returns a list of input variables for the prompt template"""

        return ["user_context", "news_context", "chat_history", "question", "answer"]

    @property
    def train_raw_template(self):
        """Returns the training prompt template format"""

        system = self.system_template.format(system_message=self.system_message)
        context = f"{self.sep}{self.context_template}"
        chat_history = f"{self.sep}{self.chat_history_template}"
        question = f"{self.sep}{self.question_template}"
        answer = f"{self.sep}{self.answer_template}"

        return f"{system}{context}{chat_history}{question}{answer}{self.eos}"

    @property
    def infer_raw_template(self):
        """Returns the inference prompt template format"""

        system = self.system_template.format(system_message=self.system_message)
        context = f"{self.sep}{self.context_template}"
        chat_history = f"{self.sep}{self.chat_history_template}"
        question = f"{self.sep}{self.question_template}"

        return f"{system}{context}{chat_history}{question}{self.eos}"

    def format_train(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a training sample"""

        prompt = self.train_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
            answer=sample["answer"],
        )
        return {"prompt": prompt, "payload": sample}

    def format_infer(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a testing sample"""

        prompt = self.infer_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
        )
        return {"prompt": prompt, "payload": sample}


# Global Templates registry
templates: Dict[str, PromptTemplate] = {}


def register_llm_template(template: PromptTemplate):
    """Register a new template to the global templates registry"""

    templates[template.name] = template


def get_llm_template(name: str) -> PromptTemplate:
    """Returns the template assigned to the given name"""

    return templates[name]


##### Register Templates #####
# - Mistral 7B Instruct v0.2 Template
register_llm_template(
    PromptTemplate(
        name="mistral",
        system_template="<s>{system_message}",
        system_message="You are a helpful assistant, with financial expertise.",
        context_template="{user_context}\n{news_context}",
        chat_history_template="Summary: {chat_history}",
        question_template="[INST] {question} [/INST]",
        answer_template="{answer}",
        sep="\n",
        eos=" </s>",
    )
)

# - FALCON (spec: https://huggingface.co/tiiuae/falcon-7b/blob/main/tokenizer.json)
register_llm_template(
    PromptTemplate(
        name="falcon",
        system_template=">>INTRODUCTION<< {system_message}",
        system_message="You are a helpful assistant, with financial expertise.",
        context_template=">>DOMAIN<< {user_context}\n{news_context}",
        chat_history_template=">>SUMMARY<< {chat_history}",
        question_template=">>QUESTION<< {question}",
        answer_template=">>ANSWER<< {answer}",
        sep="\n",
        eos="<|endoftext|>",
    )
)

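A short sketch (not part of the commit) of how the registered "mistral" template is rendered at inference time; the sample values are invented, and the keys match `PromptTemplate.input_variables`:

from financial_bot.template import get_llm_template

template = get_llm_template("mistral")
sample = {
    "user_context": "I am a student with a small monthly budget.",
    "news_context": "Tesla announced a new gigafactory.",
    "chat_history": "",
    "question": "Is now a good time to buy TSLA?",
}
rendered = template.format_infer(sample)
print(rendered["prompt"])
# <s>You are a helpful assistant, with financial expertise.
# I am a student with a small monthly budget.
# Tesla announced a new gigafactory.
# Summary:
# [INST] Is now a good time to buy TSLA? [/INST] </s>
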
financial_bot/utils.py
ADDED
@@ -0,0 +1,106 @@

import logging
import os
import subprocess
from typing import Callable, Dict, List

import psutil
import torch

logger = logging.getLogger(__name__)


def log_available_gpu_memory():
    """
    Logs the available GPU memory for each available GPU device.

    If no GPUs are available, logs "No GPUs available".

    Returns:
        None
    """

    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            memory_info = subprocess.check_output(
                f"nvidia-smi -i {i} --query-gpu=memory.free --format=csv,nounits,noheader",
                shell=True,
            )
            memory_info = str(memory_info).split("\\")[0][2:]

            logger.info(f"GPU {i} memory available: {memory_info} MiB")
    else:
        logger.info("No GPUs available")


def log_available_ram():
    """
    Logs the amount of available RAM in gigabytes.

    Returns:
        None
    """

    memory_info = psutil.virtual_memory()

    # Convert bytes to GB.
    logger.info(f"Available RAM: {memory_info.available / (1024.0 ** 3):.2f} GB")


def log_files_and_subdirs(directory_path: str):
    """
    Logs all files and subdirectories in the specified directory.

    Args:
        directory_path (str): The path to the directory to log.

    Returns:
        None
    """

    # Check if the directory exists.
    if os.path.exists(directory_path) and os.path.isdir(directory_path):
        for dirpath, dirnames, filenames in os.walk(directory_path):
            logger.info(f"Directory: {dirpath}")
            for filename in filenames:
                logger.info(f"File: {os.path.join(dirpath, filename)}")
            for dirname in dirnames:
                logger.info(f"Sub-directory: {os.path.join(dirpath, dirname)}")
    else:
        logger.info(f"The directory '{directory_path}' does not exist")


class MockedPipeline:
    """
    A mocked pipeline class that is used as a replacement to the HF pipeline class.

    Attributes:
    -----------
    task : str
        The task of the pipeline, which is text-generation.
    f : Callable[[str], str]
        A function that takes a prompt string as input and returns a generated text string.
    """

    task: str = "text-generation"

    def __init__(self, f: Callable[[str], str]):
        self.f = f

    def __call__(self, prompt: str) -> List[Dict[str, str]]:
        """
        Calls the pipeline with a given prompt and returns a list of generated text.

        Parameters:
        -----------
        prompt : str
            The prompt string to generate text from.

        Returns:
        --------
        List[Dict[str, str]]
            A list of dictionaries, where each dictionary contains a generated_text key with the generated text string.
        """

        result = self.f(prompt)

        return [{"generated_text": f"{prompt}{result}"}]

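These helpers are plain logging utilities; an illustrative call (not part of the commit):

import logging

from financial_bot import utils

logging.basicConfig(level=logging.INFO)
utils.log_available_ram()
utils.log_available_gpu_memory()  # logs "No GPUs available" on CPU-only machines
utils.log_files_and_subdirs(".")
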