|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from operator import itemgetter |
|
|
from langchain_groq import ChatGroq |
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
|
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_core.runnables.history import RunnableWithMessageHistory |
|
|
|
|
|
def create_rag_chain(base_retriever, get_session_history_func, embedding_model, store):
    """Build the conversational RAG pipeline and return its components.

    Pipeline: rewrite the user's question (using chat history) into a
    standalone query -> expand it into a HyDE hypothetical answer -> retrieve
    child chunks -> swap them for their parent documents -> answer strictly
    from that context, with per-session message history.

    Args:
        base_retriever: Retriever queried with the HyDE text; returns child
            chunks whose metadata carries a "doc_id" pointing at a parent.
        get_session_history_func: Callable taking a session id and returning
            a chat-message-history object (consumed by
            RunnableWithMessageHistory).
        embedding_model: Unused here; accepted to keep the call signature
            stable for existing callers.
        store: Key/value docstore; ``mget(ids)`` returns parent documents
            (``None`` for missing keys).

    Returns:
        dict with the inspectable components: "rewriter", "hyde",
        "base_retriever", "parent_fetcher", and the runnable "final_chain".

    Raises:
        ValueError: if GROQ_API_KEY is missing or left at its placeholder.
    """
    load_dotenv()
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key or api_key == "your_groq_api_key_here":
        raise ValueError("GROQ_API_KEY not found or not configured properly.")

    # Low temperature keeps answers grounded in the retrieved context.
    llm = ChatGroq(model_name="llama-3.1-8b-instant", api_key=api_key, temperature=0.1)

    # --- HyDE: search with a hypothetical answer instead of the raw question ---
    hyde_template = """As a document expert, write a concise, fact-based paragraph that directly answers the user's question. This will be used for a database search.

Question: {question}

Hypothetical Answer:"""
    hyde_prompt = ChatPromptTemplate.from_template(hyde_template)
    hyde_chain = hyde_prompt | llm | StrOutputParser()

    # --- Query rewriter: fold chat history into a standalone question ---
    # FIX: the system text previously embedded {chat_history} and {question}
    # as template placeholders while the MessagesPlaceholder and the human
    # turn below already supply them — history was stringified into the
    # system message and the question appeared twice. The system message is
    # now instruction-only.
    rewrite_system = (
        "Given the following conversation and a follow-up question, rephrase "
        "the follow-up question to be a standalone question that is optimized "
        "for a vector database."
    )
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_system),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "Reformulate this question as a standalone query: {question}")
    ])
    query_rewriter_chain = rewrite_prompt | llm | StrOutputParser()

    def get_parents(docs):
        """Map retrieved child chunks to their unique parent documents."""
        parent_ids = {d.metadata.get("doc_id") for d in docs}
        # FIX: drop chunks without a doc_id and parents missing from the
        # store, so None never reaches mget() or leaks into the context.
        parent_ids.discard(None)
        return [parent for parent in store.mget(list(parent_ids)) if parent is not None]

    parent_fetcher_chain = RunnableLambda(get_parents)

    # --- Answering prompt: context-only, no external knowledge ---
    rag_template = """You are CogniChat, an expert document analysis assistant. Your task is to answer the user's question based *only* on the provided context.

**Instructions:**
1. Read the context carefully.
2. If the answer is in the context, provide a clear and concise answer.
3. If the answer is not in the context, you *must* state that you cannot find the information in the provided documents. Do not use any external knowledge.
4. Where appropriate, use formatting like lists or bold text to improve readability.

**Context:**
{context}
"""
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    # FIX: the rewriter emits a plain string; wrap it into the
    # {"question": ...} mapping hyde_prompt expects instead of relying on
    # implicit single-variable coercion (which warns in recent
    # langchain-core). The dict literal is coerced to a RunnableParallel.
    conversational_rag_chain = (
        RunnablePassthrough.assign(
            context=query_rewriter_chain
            | {"question": RunnablePassthrough()}
            | hyde_chain
            | base_retriever
            | parent_fetcher_chain
        )
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    # Wire per-session history: "question" is the user turn appended to
    # history; "chat_history" is where past turns are injected.
    final_chain = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    print("\n✅ RAG chain and components successfully built.")

    return {
        "rewriter": query_rewriter_chain,
        "hyde": hyde_chain,
        "base_retriever": base_retriever,
        "parent_fetcher": parent_fetcher_chain,
        "final_chain": final_chain
    }