# Cogni-Chat-document-reader-v2 / rag_processor.py
# Provenance: Hugging Face repo snapshot — author: riteshraut, branch: "fix/new update",
# commit becc8f7. ("raw", "history blame", "3.75 kB" were site-UI residue, not code.)
import os
from dotenv import load_dotenv
from operator import itemgetter
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
def create_rag_chain(base_retriever, get_session_history_func, embedding_model, store):
    """Build the CogniChat conversational RAG chain and expose its sub-chains.

    Args:
        base_retriever: Vector-store retriever returning child chunks whose
            metadata carries a "doc_id" key pointing at the parent document.
        get_session_history_func: Callable(session_id) -> message history,
            consumed by RunnableWithMessageHistory.
        embedding_model: Unused here; kept in the signature for backward
            compatibility with existing callers.
        store: Docstore supporting mget(list[str]); maps doc_id -> parent doc.

    Returns:
        dict with the inspectable components ("rewriter", "hyde",
        "base_retriever", "parent_fetcher") and "final_chain", the
        history-aware runnable (invoke with {"question": ...} plus a
        session_id in the config).

    Raises:
        ValueError: if GROQ_API_KEY is absent or left at its placeholder.
    """
    load_dotenv()
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key or api_key == "your_groq_api_key_here":
        raise ValueError("GROQ_API_KEY not found or not configured properly.")

    # Low temperature: rewriting/answering should stay deterministic and factual.
    llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=0.1)

    # 1. HyDE-like document generation: a hypothetical answer usually embeds
    # closer to the relevant passages than the raw question does.
    hyde_template = """As a document expert, write a concise, fact-based paragraph that directly answers the user's question. This will be used for a database search.
Question: {question}
Hypothetical Answer:"""
    hyde_prompt = ChatPromptTemplate.from_template(hyde_template)
    hyde_chain = hyde_prompt | llm | StrOutputParser()

    # 2. Query rewriting: condense a follow-up into a standalone question.
    # FIX: the previous system template embedded {chat_history} and {question}
    # while a MessagesPlaceholder and a human message supplied them as well,
    # so the history was injected twice (once as a raw list repr). The system
    # message is now variable-free; history flows in only via the placeholder.
    rewrite_system = (
        "Given the following conversation and a follow-up question, rephrase "
        "the follow-up question to be a standalone question that is optimized "
        "for a vector database."
    )
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_system),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "Reformulate this question as a standalone query: {question}"),
    ])
    query_rewriter_chain = rewrite_prompt | llm | StrOutputParser()

    # 3. Parent document fetching: map retrieved child chunks back to their
    # full parent documents through the "doc_id" metadata key.
    def get_parents(docs):
        """Return the de-duplicated parent docs for a list of child chunks."""
        parent_ids = {d.metadata.get("doc_id") for d in docs}
        parent_ids.discard(None)  # chunks without a doc_id cannot be resolved
        # FIX: mget() yields None for unknown keys — drop those rather than
        # passing None objects downstream into the prompt context.
        return [doc for doc in store.mget(list(parent_ids)) if doc is not None]

    parent_fetcher_chain = RunnableLambda(get_parents)

    def format_context(parents):
        """Render parent docs as plain text for the {context} slot.

        FIX: previously the raw list of Document objects was interpolated,
        leaking object reprs and metadata noise into the prompt.
        """
        return "\n\n".join(getattr(p, "page_content", str(p)) for p in parents)

    # 4. Main conversational RAG chain: grounded, context-only answering.
    rag_template = """You are CogniChat, an expert document analysis assistant. Your task is to answer the user's question based *only* on the provided context.
**Instructions:**
1. Read the context carefully.
2. If the answer is in the context, provide a clear and concise answer.
3. If the answer is not in the context, you *must* state that you cannot find the information in the provided documents. Do not use any external knowledge.
4. Where appropriate, use formatting like lists or bold text to improve readability.
**Context:**
{context}
"""
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    # Pipeline: rewrite -> hypothetical answer -> retrieve children ->
    # fetch parents -> flatten to text; the original question and history
    # pass through untouched for the final answering prompt.
    conversational_rag_chain = (
        RunnablePassthrough.assign(
            context=query_rewriter_chain
            | hyde_chain
            | base_retriever
            | parent_fetcher_chain
            | RunnableLambda(format_context)
        )
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    # 5. Wrap with per-session history management.
    final_chain = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    print("\n✅ RAG chain and components successfully built.")
    return {
        "rewriter": query_rewriter_chain,
        "hyde": hyde_chain,
        "base_retriever": base_retriever,
        "parent_fetcher": parent_fetcher_chain,
        "final_chain": final_chain,
    }