# Provenance (from the hosting page): uploaded by chinmayjha
# Commit b27eb78: "Deploy complete Second Brain AI Assistant with custom UI"
# File size: 2.68 kB
from typing import Literal, Union
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers import (
MongoDBAtlasHybridSearchRetriever,
MongoDBAtlasParentDocumentRetriever,
)
from loguru import logger
from second_brain_online.config import settings
from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model
from .splitters import get_splitter
# Names of the retrieval strategies that `get_retriever` accepts.
RetrieverType = Literal["contextual", "parent"]

# Any of the concrete retriever classes this module can return.
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
]
def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
) -> RetrieverModel:
    """Build a MongoDB Atlas retriever of the requested kind.

    Args:
        embedding_model_id: Identifier forwarded to ``get_embedding_model``.
        embedding_model_type: Embedding backend forwarded to
            ``get_embedding_model``.
        retriever_type: ``"contextual"`` builds the hybrid search retriever,
            ``"parent"`` builds the parent-document retriever.
        k: Number of top results the retriever is configured to return.
        device: Device string forwarded to ``get_embedding_model``.

    Returns:
        The configured retriever instance.

    Raises:
        ValueError: If ``retriever_type`` is not one of the supported values.
    """
    # Reject unknown retriever types up front so a typo fails immediately
    # rather than after the embedding model has already been loaded.
    if retriever_type not in ("contextual", "parent"):
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )

    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )

    if retriever_type == "contextual":
        return get_hybrid_search_retriever(embedding_model, k)
    return get_parent_document_retriever(embedding_model, k)
def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel,
    k: int = 3,
    *,
    vector_penalty: float = 50,
    fulltext_penalty: float = 50,
) -> MongoDBAtlasHybridSearchRetriever:
    """Create a hybrid (vector + full-text) retriever over the configured collection.

    The vector store is opened from ``settings.MONGODB_URI`` on the namespace
    ``<MONGODB_DATABASE_NAME>.<MONGODB_COLLECTION_NAME>``. Documents are read
    from the ``"chunk"`` field and embeddings from the ``"embedding"`` field.

    Args:
        embedding_model: Embedding model handed to the vector store.
        k: Number of top results to return (passed as ``top_k``). Defaults to
            3, matching ``get_parent_document_retriever``.
        vector_penalty: Forwarded unchanged to
            ``MongoDBAtlasHybridSearchRetriever``.
        fulltext_penalty: Forwarded unchanged to
            ``MongoDBAtlasHybridSearchRetriever``.

    Returns:
        The configured hybrid search retriever.
    """
    namespace = f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}"
    vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=settings.MONGODB_URI,
        embedding=embedding_model,
        namespace=namespace,
        text_key="chunk",
        embedding_key="embedding",
        relevance_score_fn="dotProduct",
    )

    # "chunk_text_search" is the index name handed to the retriever as
    # `search_index_name`; it has to exist on the collection.
    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=vectorstore,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=vector_penalty,
        fulltext_penalty=fulltext_penalty,
    )
def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Create a parent-document retriever over the configured collection.

    Args:
        embedding_model: Embedding model handed to the retriever.
        k: Number of results requested, passed as ``search_kwargs={"k": k}``.

    Returns:
        The retriever built by
        ``MongoDBAtlasParentDocumentRetriever.from_connection_string``.
    """
    # Child and parent splitters come from `get_splitter(200)` and
    # `get_splitter(800)`; the unit of those sizes is defined by
    # `get_splitter` (not visible here).
    options = {
        "connection_string": settings.MONGODB_URI,
        "embedding_model": embedding_model,
        "child_splitter": get_splitter(200),
        "parent_splitter": get_splitter(800),
        "database_name": settings.MONGODB_DATABASE_NAME,
        "collection_name": settings.MONGODB_COLLECTION_NAME,
        "text_key": "chunk",
        "search_kwargs": {"k": k},
    }
    return MongoDBAtlasParentDocumentRetriever.from_connection_string(**options)