import os
import time
from operator import itemgetter
from typing import List, Optional

from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.documents import Document

from query_expansion import expand_query_simple

# Pick up GROQ_API_KEY_1..4 (and anything else) from a local .env file, if present.
load_dotenv()


class GroqAPIKeyManager:
    """Manages multiple Groq API keys with automatic rotation and fallback."""

    def __init__(self, api_keys: List[str]):
        """
        Initialize with a list of API keys.

        Args:
            api_keys: List of Groq API keys to use.
        """
        # Drop empty entries and the placeholder value that ships in .env templates.
        self.api_keys = [key for key in api_keys if key and key != "your_groq_api_key_here"]
        if not self.api_keys:
            raise ValueError("No valid API keys provided!")

        self.current_index = 0
        self.failed_keys = set()
        self.success_count = {key: 0 for key in self.api_keys}
        self.failure_count = {key: 0 for key in self.api_keys}

        print(f"🔑 API Key Manager: Loaded {len(self.api_keys)} API keys")

    def get_current_key(self) -> str:
        """Get the current API key."""
        return self.api_keys[self.current_index]

    def mark_success(self, api_key: str):
        """Mark an API key as successful and clear any earlier failure status."""
        if api_key in self.success_count:
            self.success_count[api_key] += 1

        if api_key in self.failed_keys:
            self.failed_keys.remove(api_key)
            print(f"   ✅ API Key #{self.api_keys.index(api_key) + 1} recovered!")

    def mark_failure(self, api_key: str):
        """Mark an API key as failed."""
        if api_key in self.failure_count:
            self.failure_count[api_key] += 1
            self.failed_keys.add(api_key)

    def rotate_to_next_key(self) -> bool:
        """
        Rotate to the next available API key.

        Returns:
            True if a new key is available, False if all keys failed.
        """
        attempts = 0

        while attempts < len(self.api_keys):
            self.current_index = (self.current_index + 1) % len(self.api_keys)
            attempts += 1
            current_key = self.api_keys[self.current_index]

            # Every key has been tried: retry with the current one anyway,
            # since a rate-limited key may have recovered by now.
            if attempts >= len(self.api_keys):
                print(f"   ⚠️ All keys attempted, retrying with key #{self.current_index + 1}")
                return True

            # Otherwise prefer the next key that has not failed yet.
            if current_key not in self.failed_keys:
                print(f"   🔄 Switching to API Key #{self.current_index + 1}")
                return True

        return False

    def get_statistics(self) -> str:
        """Get a per-key summary of API key usage."""
        stats = []
        for i, key in enumerate(self.api_keys):
            success = self.success_count[key]
            failure = self.failure_count[key]
            status = "❌ FAILED" if key in self.failed_keys else "✅ ACTIVE"
            # Mask the key so full credentials never reach the logs.
            masked_key = key[:8] + "..." + key[-4:] if len(key) > 12 else "***"
            stats.append(f"   Key #{i+1} ({masked_key}): {success} success, {failure} failures [{status}]")
        return "\n".join(stats)
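
# Minimal usage sketch for GroqAPIKeyManager. The key strings below are
# placeholders, not real credentials:
#
#   manager = GroqAPIKeyManager(["gsk_key_one", "gsk_key_two"])
#   key = manager.get_current_key()   # -> "gsk_key_one"
#   manager.mark_failure(key)         # record a failure for that key
#   manager.rotate_to_next_key()      # -> True, switches to "gsk_key_two"
#   print(manager.get_statistics())   # per-key success/failure summary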


def load_api_keys_from_hf_secrets() -> List[str]:
    """
    Load API keys from Hugging Face Spaces secrets.

    In your Hugging Face Space settings, add these secrets:
    - GROQ_API_KEY_1
    - GROQ_API_KEY_2
    - GROQ_API_KEY_3
    - GROQ_API_KEY_4

    Returns:
        List of API keys retrieved from HF secrets.
    """
    api_keys = []
    secret_names = ['GROQ_API_KEY_1', 'GROQ_API_KEY_2', 'GROQ_API_KEY_3', 'GROQ_API_KEY_4']

    print("🔑 Loading API keys from Hugging Face Secrets...")

    for secret_name in secret_names:
        try:
            # HF Spaces exposes secrets to the app as environment variables.
            api_key = os.getenv(secret_name)

            if api_key and api_key.strip() and api_key != "your_groq_api_key_here":
                api_keys.append(api_key.strip())
                print(f"   ✅ Loaded: {secret_name}")
            else:
                print(f"   ⚠️ Not found or empty: {secret_name}")
        except Exception as e:
            print(f"   ❌ Error loading {secret_name}: {str(e)}")

    return api_keys
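
# For local development the same names can live in a .env file, picked up by
# load_dotenv() at the top of this module (values are illustrative):
#
#   GROQ_API_KEY_1=gsk_...
#   GROQ_API_KEY_2=gsk_...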


def create_llm_with_fallback(
    api_key_manager: GroqAPIKeyManager,
    model_name: str,
    temperature: float,
    max_retries: int = 3
) -> ChatGroq:
    """
    Create a ChatGroq LLM with automatic API key fallback.

    Args:
        api_key_manager: Manager handling multiple API keys.
        model_name: Name of the model to use.
        temperature: Temperature setting.
        max_retries: Maximum number of retry attempts.

    Returns:
        ChatGroq instance.
    """
    for attempt in range(max_retries):
        current_key = api_key_manager.get_current_key()

        try:
            llm = ChatGroq(
                model_name=model_name,
                api_key=current_key,
                temperature=temperature
            )

            # Issue a minimal request so bad keys fail here, not mid-conversation.
            llm.invoke("test")
            api_key_manager.mark_success(current_key)
            return llm

        except Exception as e:
            error_msg = str(e).lower()
            api_key_manager.mark_failure(current_key)

            # Classify the failure from the error text for clearer logs.
            if "rate" in error_msg or "limit" in error_msg:
                print(f"   ⚠️ Rate limit hit on API Key #{api_key_manager.current_index + 1}")
            elif "auth" in error_msg or "api" in error_msg:
                print(f"   ❌ Authentication failed on API Key #{api_key_manager.current_index + 1}")
            else:
                print(f"   ❌ Error with API Key #{api_key_manager.current_index + 1}: {str(e)[:50]}")

            if attempt < max_retries - 1:
                if api_key_manager.rotate_to_next_key():
                    print(f"   🔄 Retrying with next API key (Attempt {attempt + 2}/{max_retries})...")
                    time.sleep(1)
                else:
                    raise ValueError("All API keys failed!")
            else:
                raise ValueError(f"Failed to initialize LLM after {max_retries} attempts")

    raise ValueError("Failed to create LLM with any available API key")
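
# Illustrative call (the model name matches create_rag_chain's default below):
#
#   manager = GroqAPIKeyManager(load_api_keys_from_hf_secrets())
#   llm = create_llm_with_fallback(manager, "moonshotai/kimi-k2-instruct", 0.2)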


def create_multi_query_retriever(base_retriever, llm, strategy: str = "balanced"):
    """Wraps a base retriever with query expansion capabilities."""
    def multi_query_retrieve(query: str) -> List[Document]:
        """Retrieves documents using expanded query variations."""
        query_variations = expand_query_simple(query, strategy=strategy, llm=llm)
        all_docs = []
        seen_content = set()
        for i, query_var in enumerate(query_variations):
            try:
                docs = base_retriever.invoke(query_var)
                # Deduplicate across variations by hashing the page content.
                for doc in docs:
                    content_hash = hash(doc.page_content)
                    if content_hash not in seen_content:
                        seen_content.add(content_hash)
                        all_docs.append(doc)
            except Exception as e:
                print(f"   ❌ Query Expansion Error (Query {i+1}): {str(e)[:50]}")
                continue
        print(f"   🔍 Query Expansion: Retrieved {len(all_docs)} unique documents.")
        return all_docs
    return multi_query_retrieve
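
# Illustrative behaviour (the actual variations depend on query_expansion):
#
#   retrieve = create_multi_query_retriever(base_retriever, llm)
#   docs = retrieve("transformer attention")  # runs several phrasings of the
#                                             # query, then deduplicates results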


def get_system_prompt(temperature: float) -> str:
    """
    Returns a system prompt dynamically based on temperature setting.

    Temperature ranges:
    - 0.0-0.4: Highly factual, structured, conservative
    - 0.4-0.8: Balanced approach with moderate creativity
    - 0.8-1.0: Creative, engaging, storytelling mode
    """
    if temperature <= 0.4:
        return """You are CogniChat, an expert document analysis assistant specializing in comprehensive and well-structured answers.

RESPONSE GUIDELINES:

**Structure & Formatting:**
- Start with a direct answer to the question
- Use **bold** for key terms, important concepts, and technical terminology
- Use bullet points (•) for lists, features, or multiple items
- Use numbered lists (1., 2., 3.) for steps, procedures, or sequential information
- Use ### Headers to organize different sections or topics
- Add blank lines between sections for readability

**Source Citation:**
- Cite sources in the format: [Source: filename, Page: X]
- Place citations at the end of your final answer only
- Do not cite sources within the body of your answer
- Multiple sources: [Source: doc1.pdf, Page: 3; doc2.pdf, Page: 7]

**Completeness:**
- Provide thorough, detailed answers using ALL relevant information from the context
- Summarize and elaborate each point for clarity
- If the question has multiple parts, address each part clearly

**Accuracy:**
- ONLY use information from the provided context documents below
- If information is incomplete, state what IS available and what ISN'T
- If the answer isn't in the context, clearly state: "I cannot find this information in the uploaded documents"
- Never make assumptions or add information not in the context

---

{context}

---

Now answer the following question comprehensively using the context above:"""

    elif temperature <= 0.8:
        return """You are CogniChat, an intelligent document analysis assistant that combines accuracy with engaging communication.

RESPONSE GUIDELINES:

**Communication Style:**
- Present information in a clear, engaging manner
- Use **bold** for emphasis on important concepts
- Balance structure with natural flow
- Make complex topics accessible and interesting

**Content Approach:**
- Ground your response firmly in the provided context
- Add helpful explanations and connections between concepts
- Use analogies or examples when they help clarify ideas (but keep them brief)
- Organize information logically with headers (###) and lists where appropriate

**Source Attribution:**
- Cite sources at the end: [Source: filename, Page: X]
- Be transparent about what the documents do and don't contain

**Accuracy:**
- Base your answer on the context documents provided
- If information is partial, explain what's available
- Acknowledge gaps: "The documents don't cover this aspect"

---

{context}

---

Now answer the following question in an engaging yet accurate way:"""

    else:
        return """You are CogniChat, a creative document analyst who makes complex information clear, memorable, and engaging.

🎯 YOUR CORE MISSION: **CLARITY FIRST, CREATIVITY SECOND**

Make information easier to understand, not harder. Your creativity should illuminate, not obscure.

**CREATIVE CLARITY PRINCIPLES:**

1. **Simplify, Don't Complicate**
   - Break down complex concepts into simple, digestible parts
   - Use everyday language alongside technical terms
   - Explain jargon immediately in plain English
   - Short sentences for complex ideas, varied length for rhythm

2. **Smart Use of Examples & Analogies** (Use Sparingly!)
   - Only use analogies when they genuinely make something clearer
   - Keep analogies simple and relatable (everyday objects/experiences)
   - Never use metaphors that require explanation themselves
   - If you can explain it directly in simple terms, do that instead

3. **Engaging Structure**
   - Start with the core answer in one clear sentence
   - Use **bold** to highlight key takeaways
   - Break information into logical chunks with ### headers
   - Use bullet points for clarity, not decoration
   - Add brief transition phrases to connect ideas smoothly

4. **Conversational Yet Precise**
   - Write like you're explaining to a smart friend
   - Use "you" and active voice to engage readers
   - Ask occasional rhetorical questions only if they aid understanding
   - Vary sentence length to maintain interest
   - Use emojis sparingly (1-2 max) and only where they add clarity

5. **Visual Clarity**
   - Strategic use of formatting: **bold** for key terms, *italics* for emphasis
   - White space between sections for easy scanning
   - Progressive disclosure: simple concepts first, details after
   - Numbered lists for sequences, bullets for related items

**WHAT TO AVOID:**
- ❌ Flowery or overly descriptive language
- ❌ Complex metaphors that need their own explanation
- ❌ Long narrative storytelling that buries the facts
- ❌ Multiple rhetorical questions in a row
- ❌ Overuse of emojis or exclamation points
- ❌ Making simple things sound complicated

**ACCURACY BOUNDARIES:**
- ✅ Creative explanation and presentation of facts
- ✅ Simple, helpful examples from common knowledge
- ✅ Reorganizing information for better understanding
- ❌ Never invent facts not in the documents
- ❌ Don't contradict source material
- ❌ If info is missing, say so clearly and briefly

**Source Attribution:**
- End with: [Source: filename, Page: X]
- Keep it simple and clear

---

{context}

---

Now, explain the answer clearly and engagingly. Remember: if your grandmother couldn't understand it, simplify more:"""
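
# Illustrative mapping from temperature to the prompt variant returned above:
#
#   get_system_prompt(0.2)  # factual / structured
#   get_system_prompt(0.6)  # balanced / engaging
#   get_system_prompt(0.9)  # creative / storytelling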


def create_rag_chain(
    retriever,
    get_session_history_func,
    enable_query_expansion=True,
    expansion_strategy="balanced",
    model_name: str = "moonshotai/kimi-k2-instruct",
    temperature: float = 0.2,
    api_keys: Optional[List[str]] = None
):
    """
    Creates an advanced RAG chain with temperature-adaptive prompting and API key rotation.

    Args:
        retriever: Document retriever.
        get_session_history_func: Function to get session history.
        enable_query_expansion: Whether to enable query expansion.
        expansion_strategy: Strategy for query expansion.
        model_name: Name of the LLM model.
        temperature: Temperature setting (0.0-1.0).
        api_keys: Optional list of API keys. If None, loads from the environment.

    Returns:
        Tuple of (chain with message history, GroqAPIKeyManager instance).
    """
    if api_keys is None:
        api_keys = load_api_keys_from_hf_secrets()

    if not api_keys:
        raise ValueError(
            "No valid API keys found! Please set GROQ_API_KEY_1, GROQ_API_KEY_2, "
            "GROQ_API_KEY_3 or GROQ_API_KEY_4 in your .env file or HF Space secrets"
        )

    api_key_manager = GroqAPIKeyManager(api_keys)

    print(f"⚙️ RAG: Initializing LLM - Model: {model_name}, Temp: {temperature}")

    if temperature <= 0.4:
        creativity_mode = "FACTUAL & STRUCTURED"
    elif temperature <= 0.8:
        creativity_mode = "BALANCED & ENGAGING"
    else:
        creativity_mode = "CREATIVE & STORYTELLING"
    print(f"🎨 Creativity Mode: {creativity_mode}")

    llm = create_llm_with_fallback(api_key_manager, model_name, temperature)
    print(f"✅ LLM initialized with API Key #{api_key_manager.current_index + 1}")

    if enable_query_expansion:
        print(f"✨ RAG: Query Expansion ENABLED (Strategy: {expansion_strategy})")
        enhanced_retriever = create_multi_query_retriever(
            base_retriever=retriever,
            llm=llm,
            strategy=expansion_strategy
        )
    else:
        enhanced_retriever = retriever

    # The chat history and question are supplied as messages by the
    # placeholder and human turn below, so the system text does not
    # repeat them as template variables.
    rewrite_template = """You are an expert at optimizing search queries for document retrieval.

Given the conversation history and a follow-up question, create a comprehensive standalone question that:
1. Incorporates all relevant context from the chat history
2. Expands abbreviations and resolves all pronouns (it, they, this, that, etc.)
3. Includes key technical terms and concepts that would help find relevant documents
4. Maintains the original intent, specificity, and detail level
5. If the question asks for comparison or multiple items, ensure all items are in the query

Respond with only the optimized standalone question."""

    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}")
    ])
    query_rewriter = rewrite_prompt | llm | StrOutputParser()
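
    # Example rewrite (illustrative): if the history discusses "LoRA fine-tuning"
    # and the follow-up is "what are its memory savings?", the rewriter should
    # produce something like "What are the memory savings of LoRA fine-tuning?".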

    def format_docs(docs):
        """Format retrieved documents with clear structure and metadata."""
        if not docs:
            return "No relevant documents found in the knowledge base."

        formatted_parts = []
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get('source', 'Unknown Document')
            page = doc.metadata.get('page', 'N/A')
            rerank_score = doc.metadata.get('rerank_score')
            content = doc.page_content.strip()

            doc_header = f"{'='*60}\nDOCUMENT {i}\n{'='*60}"
            metadata_line = f"Source: {source} | Page: {page}"
            # Explicit None check so a legitimate 0.0 score is still shown.
            if rerank_score is not None:
                metadata_line += f" | Relevance: {rerank_score:.3f}"

            formatted_parts.append(
                f"{doc_header}\n"
                f"{metadata_line}\n"
                f"{'-'*60}\n"
                f"{content}\n"
            )
        return f"RETRIEVED CONTEXT ({len(docs)} documents):\n\n" + "\n".join(formatted_parts)

    rag_template = get_system_prompt(temperature)

    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    # Stage 1: gather the inputs the query rewriter needs.
    rewriter_input = RunnableParallel({
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    })

    # Stage 2: rewritten query -> retriever -> formatted context string.
    retrieval_chain = rewriter_input | query_rewriter | enhanced_retriever | format_docs

    # Stage 3: answer generation over the retrieved context.
    conversational_rag_chain = RunnableParallel({
        "context": retrieval_chain,
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    }) | rag_prompt | llm | StrOutputParser()
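
    # Per-turn data flow, end to end:
    #   {"question", "chat_history"} -> query_rewriter -> standalone query
    #   -> enhanced_retriever -> format_docs -> {context}
    #   -> rag_prompt -> llm -> answer string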

    chain_with_memory = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    print("✅ RAG: Chain created successfully.")
    print("\n" + api_key_manager.get_statistics())

    return chain_with_memory, api_key_manager
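

if __name__ == "__main__":
    # Smoke-test sketch, not a full app: it assumes you already have a vector
    # store to retrieve from ("my_vectorstore" below is hypothetical), and that
    # a recent langchain_core providing InMemoryChatMessageHistory is installed.
    from langchain_core.chat_history import InMemoryChatMessageHistory

    _histories = {}

    def get_session_history(session_id: str):
        """Return (and lazily create) the message history for a session id."""
        if session_id not in _histories:
            _histories[session_id] = InMemoryChatMessageHistory()
        return _histories[session_id]

    # chain, manager = create_rag_chain(
    #     retriever=my_vectorstore.as_retriever(),  # hypothetical retriever
    #     get_session_history_func=get_session_history,
    # )
    # answer = chain.invoke(
    #     {"question": "What does the document say about X?"},
    #     config={"configurable": {"session_id": "demo"}},
    # )
    # print(answer)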