Spaces:
Runtime error
Runtime error
| import argparse | |
| import logging | |
| from pathlib import Path | |
| from threading import Thread | |
| from typing import List | |
| import os | |
| import gradio as gr | |
# Module-level logger; configured later by financial_bot.initialize() via logging.yaml.
logger = logging.getLogger(__name__)

# Secrets / connection settings expected to be injected by the hosting environment
# (e.g. a Hugging Face Space). They are read here but not referenced elsewhere in
# this file — presumably consumed inside the financial_bot package; verify there.
COMET_API_KEY = os.getenv("COMET_API_KEY")
COMET_WORKSPACE = os.getenv("COMET_WORKSPACE")
COMET_PROJECT_NAME = os.getenv("COMET_PROJECT_NAME")
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
def parseargs() -> argparse.Namespace:
    """Parse command-line options for the Financial Assistant Bot.

    Returns:
        argparse.Namespace: Parsed arguments with attributes
        ``env_file_path``, ``logging_config_path``, ``model_cache_dir``,
        ``embedding_model_device`` and ``debug``.
    """
    parser = argparse.ArgumentParser(description="Financial Assistant Bot")

    # (flag, default, help) triples — all plain string options share the same shape,
    # so register them from a small table instead of repeating add_argument calls.
    string_options = (
        ("--env-file-path", ".env", "Path to the environment file"),
        (
            "--logging-config-path",
            "logging.yaml",
            "Path to the logging configuration file",
        ),
        (
            "--model-cache-dir",
            "./model_cache",
            "Path to the directory where the model cache will be stored",
        ),
        (
            "--embedding-model-device",
            "cuda:0",
            "Device to use for the embedding model (e.g. 'cpu', 'cuda:0', etc.)",
        ),
    )
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    # Boolean switch: present -> True, absent -> False.
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Enable debug mode",
    )

    return parser.parse_args()
| args = parseargs() | |
# === Load Bot ===
def load_bot(
    # env_file_path: str = ".env",  # disabled: env vars are supplied by the hosting platform
    logging_config_path: str = "logging.yaml",
    model_cache_dir: str = "./model_cache",
    embedding_model_device: str = "cuda:0",
    debug: bool = False,
):
    """
    Load the financial assistant bot in production or development mode based on the `debug` flag.

    In DEV mode the embedding model runs on CPU and the fine-tuned LLM is mocked.
    Otherwise, the embedding model runs on GPU and the fine-tuned LLM is used.

    Args:
        logging_config_path (str): Path to the logging configuration file.
        model_cache_dir (str): Path to the directory where the model cache is stored.
            An empty string disables the cache (None is passed to FinancialBot).
        embedding_model_device (str): Device to use for the embedding model.
        debug (bool): Flag to indicate whether to run the bot in debug mode or not.

    Returns:
        FinancialBot: An instance of the FinancialBot class.
    """
    from financial_bot import initialize

    # Be sure to initialize the environment variables before importing any other modules.
    # The import order below is intentional — do not hoist these imports to the top of
    # the file or above initialize().
    # initialize(logging_config_path=logging_config_path, env_file_path=env_file_path)
    initialize(logging_config_path=logging_config_path)

    from financial_bot import utils
    from financial_bot.langchain_bot import FinancialBot

    # Log a resource snapshot so deployment issues (missing GPU, low RAM) are visible.
    logger.info("#" * 100)
    utils.log_available_gpu_memory()
    utils.log_available_ram()
    logger.info("#" * 100)

    bot = FinancialBot(
        model_cache_dir=Path(model_cache_dir) if model_cache_dir else None,
        embedding_model_device=embedding_model_device,
        streaming=True,  # streaming enables incremental answers in the chat UI
        debug=debug,
    )

    return bot
# Build the bot once at import time so Gradio's predict() can reuse it across requests.
bot = load_bot(
    # env_file_path=args.env_file_path,
    logging_config_path=args.logging_config_path,
    model_cache_dir=args.model_cache_dir,
    embedding_model_device=args.embedding_model_device,
    debug=args.debug,
)
# === Gradio Interface ===
def predict(message: str, history: List[List[str]], about_me: str) -> str:
    """
    Generate the bot's reply for the Gradio chat UI.

    Args:
        message (str): The user's current question.
        history (List[List[str]]): Previous conversation turns from the UI.
        about_me (str): Free-text description of the user, used as context.

    Yields:
        str: The generated answer (partial answers when streaming is enabled).
    """
    answer_kwargs = {
        "about_me": about_me,
        "question": message,
        "to_load_history": history,
    }

    if not bot.is_streaming:
        # Blocking path: compute the whole answer in one call.
        yield bot.answer(**answer_kwargs)
        return

    # Streaming path: run generation on a worker thread and relay partial
    # answers to the UI as they become available.
    worker = Thread(target=bot.answer, kwargs=answer_kwargs)
    worker.start()
    yield from bot.stream_answer()
# Chat UI definition. predict() above is the inference callback.
# NOTE(review): `retry_btn`, `undo_btn` and `clear_btn` were removed from
# gr.ChatInterface in Gradio 5.x; on a Gradio >= 5 runtime this constructor
# raises a TypeError at import time (a likely cause of the Space's
# "Runtime error" banner). Confirm the pinned gradio version — pin to 4.x
# or drop these three keyword arguments.
demo = gr.ChatInterface(
    predict,
    textbox=gr.Textbox(
        placeholder="Ask me a financial question",
        label="Financial Question",
        container=False,
        scale=7,
    ),
    # Extra input surfaced below the chat box; passed to predict() as `about_me`.
    additional_inputs=[
        gr.Textbox(
            "I am a 30 year old graphic designer. I want to invest in something with potential for high returns.",
            label="About Me",
        )
    ],
    title="Friendly Financial Bot 🤑",
    description="Ask me any financial or crypto market questions, and I will do my best to provide useful insight. My advice is based on current \
finance news, stored as embeddings in a **Qdrant** vector db. I run on a 4bit quantized **Mistral-7B-Instruct-v0.2** model with a QLoRa \
adapter fine-tuned for providing financial guidance. Some sample questions and additional background scenarios are below. \
**Advice is strictly for demonstration purposes**",
    theme="soft",
    # Each example is [question, about_me] matching predict()'s inputs.
    examples=[
        [
            "How are gene therapy stocks performing?",
            "I am a risk-averse 40 year old and would like to avoid risky investments.",
        ],
        [
            "How is NVDA performing, and is it a wise investment?",
            "I'm a 45 year old interested in cryptocurrency and AI.",
        ],
        [
            "Do you think investing in Boeing is a good idea right now?",
            "I'm a 31 year old pilot. I'm curious about the potential of investing in certain airlines.",
        ],
        [
            "What's your opinion on investing in the Chinese stock market?",
            "I am a risk-averse 40 year old and would like to avoid risky investments.",
        ],
    ],
    # Don't pre-run examples at startup — each run would invoke the LLM.
    cache_examples=False,
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)
if __name__ == "__main__":
    # queue() is required so the generator-based predict() can stream partial answers.
    demo.queue().launch()