File size: 2,675 Bytes
b27eb78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import Literal, Union

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers import (
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
)
from loguru import logger

from second_brain_online.config import settings

from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model
from .splitters import get_splitter

# Type aliases used by the retriever factory functions in this module:
# - RetrieverType: the retriever kinds that `get_retriever` accepts.
# - RetrieverModel: the concrete retriever classes `get_retriever` can return.
RetrieverType = Literal["contextual", "parent"]
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever, MongoDBAtlasParentDocumentRetriever
]


def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
) -> RetrieverModel:
    """Build the retriever selected by ``retriever_type``.

    Args:
        embedding_model_id: Identifier forwarded to ``get_embedding_model``.
        embedding_model_type: Embedding backend forwarded to
            ``get_embedding_model``.
        retriever_type: ``"contextual"`` builds a hybrid search retriever,
            ``"parent"`` builds a parent document retriever.
        k: Number of top results the retriever is configured to return.
        device: Device string forwarded to ``get_embedding_model``.

    Returns:
        The configured retriever instance.

    Raises:
        ValueError: If ``retriever_type`` is not a supported value.
    """
    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )

    # Resolve the builder before loading the embedding model, so an invalid
    # retriever type fails fast instead of paying for model loading first.
    if retriever_type == "contextual":
        build_retriever = get_hybrid_search_retriever
    elif retriever_type == "parent":
        build_retriever = get_parent_document_retriever
    else:
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )

    return build_retriever(embedding_model, k)


def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel, k: int
) -> MongoDBAtlasHybridSearchRetriever:
    """Create a hybrid search retriever over the configured MongoDB collection.

    The vector store is opened from ``settings.MONGODB_URI`` against the
    ``<database>.<collection>`` namespace taken from ``settings``. It reads
    text from the ``chunk`` field and embeddings from the ``embedding`` field.

    Args:
        embedding_model: Embedding model handed to the vector store.
        k: Value passed to the retriever as ``top_k``.

    Returns:
        A ``MongoDBAtlasHybridSearchRetriever`` bound to the vector store.
    """
    mongodb_uri = settings.MONGODB_URI
    collection_namespace = (
        f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}"
    )
    store = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=mongodb_uri,
        embedding=embedding_model,
        namespace=collection_namespace,
        text_key="chunk",
        embedding_key="embedding",
        relevance_score_fn="dotProduct",
    )

    # Vector and full-text results are given the same penalty value.
    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=store,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=50,
        fulltext_penalty=50,
    )


def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Create a parent document retriever over the configured MongoDB collection.

    Connection string, database name and collection name all come from
    ``settings``. Child and parent splitters are obtained from
    ``get_splitter`` with arguments 200 and 800 respectively (presumably
    chunk sizes; confirm against ``get_splitter``).

    Args:
        embedding_model: Embedding model handed to the retriever.
        k: Passed to the retriever as ``search_kwargs={"k": k}``.

    Returns:
        A ``MongoDBAtlasParentDocumentRetriever`` built from the connection
        string.
    """
    mongodb_uri = settings.MONGODB_URI
    child = get_splitter(200)
    parent = get_splitter(800)

    return MongoDBAtlasParentDocumentRetriever.from_connection_string(
        connection_string=mongodb_uri,
        embedding_model=embedding_model,
        child_splitter=child,
        parent_splitter=parent,
        database_name=settings.MONGODB_DATABASE_NAME,
        collection_name=settings.MONGODB_COLLECTION_NAME,
        text_key="chunk",
        search_kwargs={"k": k},
    )