from dotenv import load_dotenv load_dotenv() import os import asyncio import logging from langchain_community.document_loaders import TextLoader from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_huggingface import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain_openai import ChatOpenAI from langchain_community.llms import CTransformers from langchain_core.prompts import PromptTemplate from transformers import pipeline from langchain_huggingface import HuggingFacePipeline from rich import print as rprint from worker import scrape_website logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN") DEFAULT_MODEL = "TheBloke/Llama-2-7B-Chat-GGML" EMBEDDING_MODEL = "BAAI/bge-small-en" # -------------------- Document Preparation -------------------- async def prepare_document(url: str | list[str]): if isinstance(url, str): folder = f"{url[8:].replace('.', '-').split('/')[0]}" cache_path = os.path.join("cache", folder, "pages") else: folder = f"{url[0][8:].replace('.', '-').split('/')[0]}" cache_path = os.path.join("cache", f"list_{folder}", "pages") os.makedirs(cache_path, exist_ok=True) if not os.path.exists(f"{cache_path}/page_1.txt"): logging.info("Document not found. Scraping website...") await scrape_website(url, cache_path) logging.info("Scraping completed.") return cache_path # -------------------- Embedding -------------------- def get_embedding_model(embedding_model_name="", api_key=""): # Use OpenAI if api_key provided or model name indicates OpenAI if api_key or "openai" in embedding_model_name.lower(): if not api_key: raise ValueError("OpenAI API key required for OpenAI embeddings") from langchain_openai import OpenAIEmbeddings return OpenAIEmbeddings(model="text-embedding-3-small", api_key=api_key) # Use HuggingFace otherwise else: # Ensure HF token is set in env for this thread HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN return HuggingFaceEmbeddings(model_name=embedding_model_name or EMBEDDING_MODEL) # -------------------- Process & Build Vector Store -------------------- def process_documents(file_path: str, embedding_model, chunk_size=500, chunk_overlap=100): try: cache_path = os.path.dirname(file_path) faiss_path = f"{cache_path}/faiss_index_store" if os.path.exists(faiss_path): logging.info("FAISS index exists. Skipping rebuild.") return documents = [] for file in os.listdir(f"{cache_path}/pages"): doc_loader = TextLoader(os.path.join(cache_path, "pages", file), encoding="utf-8") documents.extend(doc_loader.load()) logging.info(f"Loaded {len(documents)} pages") text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) chunks = text_splitter.split_documents(documents) vector_db = FAISS.from_documents(chunks, embedding_model) vector_db.save_local(faiss_path) logging.info("FAISS store saved successfully") except Exception as e: logging.error(f"Error in document processing: {e}") # -------------------- Load Retriever -------------------- async def load_retriever(file_path: str, embedding_model_name="", api_key=""): cache_path = os.path.dirname(file_path) embedding_model = get_embedding_model(embedding_model_name, api_key) faiss_path = f"{cache_path}/faiss_index_store" if not os.path.exists(faiss_path): logging.warning("FAISS index missing. Rebuilding...") process_documents(file_path, embedding_model) vector_db = FAISS.load_local(faiss_path, embedding_model, allow_dangerous_deserialization=True) return vector_db.as_retriever(search_kwargs={"k": 3}) # -------------------- Build Custom QA Pipeline -------------------- async def build_pipeline(url: str | list, llm_model="", embedding_model="", api_key=""): # Force default model if llm_model is empty or 'default' if not llm_model or llm_model.lower() == "default": llm_model = DEFAULT_MODEL logging.info(f"[LLM] Using model: {llm_model}") file_path = await prepare_document(url) retriever = await load_retriever(file_path, embedding_model, api_key) llm_model_lower = llm_model.lower() # OpenAI LLM if "openai" in llm_model_lower: llm = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=api_key) # GGML model elif llm_model_lower.endswith("-ggml"): llm = CTransformers(model=llm_model, model_type="llama", config={"context_length": 4096}) # Hugging Face PyTorch model else: try: hf_pipeline = pipeline( "text-generation", model=llm_model, use_auth_token=HUGGINGFACEHUB_API_TOKEN ) llm = HuggingFacePipeline(pipeline=hf_pipeline) except Exception as e: logging.error(f"Failed to load Hugging Face model '{llm_model}'. Error: {e}") raise RuntimeError(f"Cannot load Hugging Face model: {e}") prompt = PromptTemplate( input_variables=["context", "question"], template="You are a helpful assistant. Use the following context to answer.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:" ) return llm, retriever, prompt class Chatbot: def __init__(self, url: str | list, llm_model="", embedding_model="", api_key=""): self.url = url self.llm_model = llm_model self.embedding_model = embedding_model self.api_key = api_key async def initialize(self): self.llm, self.retriever, self.prompt = await build_pipeline( self.url, self.llm_model, self.embedding_model, self.api_key ) async def query(self, question: str): # Use async method if available if hasattr(self.retriever, "aretrieve"): docs = await self.retriever.aretrieve(question) else: # fallback: call the private method with run_manager=None docs = await asyncio.to_thread(self.retriever._get_relevant_documents, question, run_manager=None) context = "\n\n".join([d.page_content for d in docs]) prompt_text = self.prompt.format(context=context, question=question) response = await asyncio.to_thread(self.llm.invoke, prompt_text) return response # -------------------- Example Runner -------------------- async def main(): url = input("Enter URL: ").strip() query = input("Enter your question: ").strip() bot = Chatbot([url]) await bot.initialize() answer = await bot.query(query) rprint(f"\n[bold cyan]=== Answer ===[/bold cyan]\n{answer}") if __name__ == "__main__": asyncio.run(main())