File size: 3,745 Bytes
becc8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from dotenv import load_dotenv
from operator import itemgetter
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory

def create_rag_chain(base_retriever, get_session_history_func, embedding_model, store):
    """
    Build the conversational RAG pipeline and expose its stages for inspection.

    Pipeline: the follow-up question is rewritten into a standalone query using
    the chat history. A HyDE-style hypothetical answer is generated from that
    query and used to search ``base_retriever``. The matching chunks are mapped
    to their parent documents via ``store``. The parents' text becomes the
    context of the final answer prompt.

    Args:
        base_retriever: Runnable retriever invoked with the hypothetical-answer
            text. Its documents are read for a ``"doc_id"`` metadata key that
            is looked up in ``store``. (Assumes doc_id points at a parent
            document; confirm against the indexing code.)
        get_session_history_func: Session-history factory handed to
            ``RunnableWithMessageHistory``.
        embedding_model: Not used by this function; kept for interface
            compatibility.
        store: Object with an ``mget(keys)`` method returning the parent
            documents. Presumably a LangChain ``BaseStore``, whose ``mget`` may
            yield ``None`` for unknown keys; such entries are dropped.

    Returns:
        dict with keys ``"rewriter"``, ``"hyde"``, ``"base_retriever"``,
        ``"parent_fetcher"`` (the individual stages) and ``"final_chain"``.
        ``final_chain`` is the history-aware runnable, invoked with
        ``{"question": ...}``.

    Raises:
        ValueError: if ``GROQ_API_KEY`` is unset or still the placeholder value.
    """
    load_dotenv()
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key or api_key == "your_groq_api_key_here":
        raise ValueError("GROQ_API_KEY not found or not configured properly.")

    llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=0.1)

    # 1. HyDE-like Document Generation Chain
    # In the main chain this receives the rewriter's plain-string output.
    # LangChain maps a non-dict input onto a prompt's single input variable
    # ({question}).
    hyde_template = """As a document expert, write a concise, fact-based paragraph that directly answers the user's question. This will be used for a database search.
    Question: {question}
    Hypothetical Answer:"""
    hyde_prompt = ChatPromptTemplate.from_template(hyde_template)
    hyde_chain = hyde_prompt | llm | StrOutputParser()

    # 2. Query Rewriting Chain
    # The history arrives only via the MessagesPlaceholder and the question only
    # via the human turn. Interpolating {chat_history} into the system text as
    # well would repeat the conversation as the repr of the message list.
    rewrite_system = (
        "Given the following conversation and a follow-up question, rephrase the "
        "follow-up question to be a standalone question that is optimized for a "
        "vector database. Respond with only the standalone question."
    )
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_system),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "Reformulate this question as a standalone query: {question}"),
    ])
    query_rewriter_chain = rewrite_prompt | llm | StrOutputParser()

    # 3. Parent Document Fetching Chain
    def get_parents(docs):
        """Map retrieved chunks to their parent documents, in retrieval order."""
        # dict.fromkeys de-duplicates while keeping the retriever's ranking
        # order (a set would scramble it). Chunks without a doc_id are skipped
        # so None is never used as a store key.
        parent_ids = list(dict.fromkeys(
            doc_id
            for doc_id in (d.metadata.get("doc_id") for d in docs)
            if doc_id is not None
        ))
        if not parent_ids:
            return []
        # Drop ids the store could not resolve so the prompt never sees None.
        return [parent for parent in store.mget(parent_ids) if parent is not None]

    parent_fetcher_chain = RunnableLambda(get_parents)

    def format_docs(docs):
        """Join parent documents into plain text for the {context} slot."""
        # Prefer page_content (LangChain Document); fall back to str() otherwise.
        return "\n\n".join(
            doc.page_content if hasattr(doc, "page_content") else str(doc)
            for doc in docs
        )

    # 4. Main Conversational RAG Chain
    rag_template = """You are CogniChat, an expert document analysis assistant. Your task is to answer the user's question based *only* on the provided context.

    **Instructions:**
    1. Read the context carefully.
    2. If the answer is in the context, provide a clear and concise answer.
    3. If the answer is not in the context, you *must* state that you cannot find the information in the provided documents. Do not use any external knowledge.
    4. Where appropriate, use formatting like lists or bold text to improve readability.

    **Context:**
    {context}
    """
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    # The original (non-rewritten) question and chat history pass through
    # unchanged. Only "context" is added, from the retrieval sub-pipeline.
    conversational_rag_chain = (
        RunnablePassthrough.assign(
            context=query_rewriter_chain
            | hyde_chain
            | base_retriever
            | parent_fetcher_chain
            | RunnableLambda(format_docs)
        )
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    # 5. Final Chain with History (Simplified)
    final_chain = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    print("\n✅ RAG chain and components successfully built.")

    return {
        "rewriter": query_rewriter_chain,
        "hyde": hyde_chain,
        "base_retriever": base_retriever,
        "parent_fetcher": parent_fetcher_chain,
        "final_chain": final_chain
    }