"""Factory functions that build MongoDB Atlas retrievers for the RAG pipeline."""

from typing import Literal, Union

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers import (
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
)
from loguru import logger

from second_brain_online.config import settings

from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model
from .splitters import get_splitter

# Names of the retrieval strategies accepted by `get_retriever`.
RetrieverType = Literal["contextual", "parent"]
# Concrete retriever classes that `get_retriever` can return.
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever, MongoDBAtlasParentDocumentRetriever
]


def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
) -> RetrieverModel:
    """Build the retriever selected by ``retriever_type``.

    Args:
        embedding_model_id: Identifier forwarded to ``get_embedding_model``.
        embedding_model_type: Embedding backend forwarded to
            ``get_embedding_model``.
        retriever_type: ``"contextual"`` selects the hybrid search retriever,
            ``"parent"`` selects the parent document retriever.
        k: Number of top results the retriever is configured to return.
        device: Device forwarded to ``get_embedding_model``.

    Returns:
        The configured retriever instance.

    Raises:
        ValueError: If ``retriever_type`` is not a supported value.
    """
    # Resolve the builder before loading the embedding model so that an
    # unsupported retriever type fails immediately instead of after a
    # potentially expensive model load.
    if retriever_type == "contextual":
        build_retriever = get_hybrid_search_retriever
    elif retriever_type == "parent":
        build_retriever = get_parent_document_retriever
    else:
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )

    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )

    return build_retriever(embedding_model, k)


def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasHybridSearchRetriever:
    """Build a hybrid (vector + full-text) retriever over the configured collection.

    The vector store connects with ``settings.MONGODB_URI`` to the namespace
    ``<MONGODB_DATABASE_NAME>.<MONGODB_COLLECTION_NAME>``, reads text from the
    ``chunk`` field and embeddings from the ``embedding`` field.

    Args:
        embedding_model: Embedding model used by the vector store.
        k: Value passed as the retriever's ``top_k``. Defaults to 3, matching
            ``get_parent_document_retriever`` and ``get_retriever``.

    Returns:
        The configured hybrid search retriever.
    """
    vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=settings.MONGODB_URI,
        embedding=embedding_model,
        namespace=f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}",
        text_key="chunk",
        embedding_key="embedding",
        relevance_score_fn="dotProduct",
    )

    # The vector and full-text rankings get the same penalty value.
    # NOTE(review): "chunk_text_search" must be an existing Atlas Search index
    # on the collection - confirm against the ingestion/index setup.
    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=vectorstore,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=50,
        fulltext_penalty=50,
    )


def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Build a parent document retriever over the configured collection.

    Connects with ``settings.MONGODB_URI`` to ``settings.MONGODB_DATABASE_NAME``
    / ``settings.MONGODB_COLLECTION_NAME`` and reads text from the ``chunk``
    field.

    Args:
        embedding_model: Embedding model used by the retriever.
        k: Number of results, passed through ``search_kwargs``.

    Returns:
        The configured parent document retriever.
    """
    # Child and parent splitters come from get_splitter(200) and
    # get_splitter(800); presumably these are chunk sizes - see get_splitter
    # for the unit.
    return MongoDBAtlasParentDocumentRetriever.from_connection_string(
        connection_string=settings.MONGODB_URI,
        embedding_model=embedding_model,
        child_splitter=get_splitter(200),
        parent_splitter=get_splitter(800),
        database_name=settings.MONGODB_DATABASE_NAME,
        collection_name=settings.MONGODB_COLLECTION_NAME,
        text_key="chunk",
        search_kwargs={"k": k},
    )