Spaces:
Sleeping
Sleeping
File size: 2,675 Bytes
b27eb78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import Literal, Union
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers import (
MongoDBAtlasHybridSearchRetriever,
MongoDBAtlasParentDocumentRetriever,
)
from loguru import logger
from second_brain_online.config import settings
from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model
from .splitters import get_splitter
# Union of the concrete retriever classes that ``get_retriever`` can return.
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
]

# Retrieval strategies accepted by ``get_retriever``:
#   "contextual" -> MongoDBAtlasHybridSearchRetriever (vector + full-text search)
#   "parent"     -> MongoDBAtlasParentDocumentRetriever
RetrieverType = Literal["contextual", "parent"]
def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
) -> RetrieverModel:
    """Build a MongoDB Atlas retriever backed by the requested embedding model.

    Args:
        embedding_model_id: Embedding model identifier, forwarded to
            ``get_embedding_model``.
        embedding_model_type: Embedding backend, forwarded to
            ``get_embedding_model``.
        retriever_type: ``"contextual"`` for a hybrid (vector + full-text)
            search retriever, ``"parent"`` for a parent-document retriever.
        k: Number of top results the retriever should return.
        device: Device string forwarded to ``get_embedding_model``.

    Returns:
        The configured retriever instance.

    Raises:
        ValueError: If ``retriever_type`` is not a supported value.
    """
    # Validate the cheap argument before loading the embedding model: model
    # loading may be expensive, and an invalid type would make it wasted work.
    if retriever_type not in ("contextual", "parent"):
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )

    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )

    if retriever_type == "contextual":
        return get_hybrid_search_retriever(embedding_model, k)
    # Only "parent" remains after the validation above.
    return get_parent_document_retriever(embedding_model, k)
def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel, k: int
) -> MongoDBAtlasHybridSearchRetriever:
    """Create a hybrid (vector + full-text) search retriever over the chunk collection.

    The vector store targets ``settings.MONGODB_DATABASE_NAME`` /
    ``settings.MONGODB_COLLECTION_NAME``, reading text from the ``chunk``
    field and vectors from the ``embedding`` field, scored with
    ``dotProduct``. Full-text search uses the ``chunk_text_search`` index.

    Args:
        embedding_model: Embedding model handed to the vector store.
        k: Number of results the retriever returns (``top_k``).

    Returns:
        The configured ``MongoDBAtlasHybridSearchRetriever``.
    """
    namespace = f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}"
    vector_store = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=settings.MONGODB_URI,
        namespace=namespace,
        embedding=embedding_model,
        embedding_key="embedding",
        text_key="chunk",
        relevance_score_fn="dotProduct",
    )

    # Identical penalties for both legs, so the vector and full-text rankings
    # are treated symmetrically when the results are fused.
    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=vector_store,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=50,
        fulltext_penalty=50,
    )
def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Create a parent-document retriever over the configured MongoDB collection.

    Splitters come from ``get_splitter``: size 200 for child chunks and
    size 800 for parent chunks (units are whatever ``get_splitter`` uses).
    Document text is read from the ``chunk`` field.

    Args:
        embedding_model: Embedding model handed to the retriever.
        k: Number of results to return, passed via ``search_kwargs``.

    Returns:
        The configured ``MongoDBAtlasParentDocumentRetriever``.
    """
    child_splitter = get_splitter(200)
    parent_splitter = get_splitter(800)
    search_kwargs = {"k": k}

    return MongoDBAtlasParentDocumentRetriever.from_connection_string(
        connection_string=settings.MONGODB_URI,
        database_name=settings.MONGODB_DATABASE_NAME,
        collection_name=settings.MONGODB_COLLECTION_NAME,
        embedding_model=embedding_model,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        text_key="chunk",
        search_kwargs=search_kwargs,
    )
|