import re
from typing import Any, Dict

from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser

from scripts.rag_chat import build_general_qa_chain

# Categories the router prompt asks the LLM to choose from.
_CATEGORIES = ("code", "summarize", "calculate", "general")

# This prompt asks the LLM to choose which "mode" to use.
_ROUTER_TEMPLATE = """
You are a routing assistant for a chatbot.
Classify the following user request into one of these categories:
- "code" for programming or debugging
- "summarize" for summary requests
- "calculate" for math or numeric calculations
- "general" for general Q&A using course files

Return ONLY the category word.

User request: {input}
"""


def _parse_category(raw: str) -> str:
    """Map the router LLM's raw reply onto one of ``_CATEGORIES``.

    The prompt lists the categories wrapped in double quotes, so the model
    may answer ``"code"`` or ``Code.`` rather than the bare word. A plain
    ``strip().lower()`` would miss every branch for such replies and fall
    through to general Q&A. Instead, return the first word of the reply
    that names a known category, or ``"general"`` if none does.
    """
    for word in re.findall(r"[a-z]+", raw.lower()):
        if word in _CATEGORIES:
            return word
    return "general"


def build_router_chain(model_name=None):
    """Build a router that dispatches each user request to a specialised handler.

    A temperature-0 LLM first classifies the request as "code", "summarize",
    "calculate" or "general"; the request is then answered by the matching
    handler.

    Args:
        model_name: Chat model name. Passed through to
            ``build_general_qa_chain`` and used for the router/handler LLM,
            which falls back to ``"gpt-4o-mini"`` when this is ``None``.

    Returns:
        An object with ``invoke(input_dict)``, where ``input_dict["input"]``
        is the user's request. The "code", "summarize" and "calculate" routes
        return ``{"result": str}``. The "general" route returns whatever the
        general QA chain returns.
    """
    general_qa = build_general_qa_chain(model_name=model_name)
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)
    parser = StrOutputParser()

    router_chain = ChatPromptTemplate.from_template(_ROUTER_TEMPLATE) | llm | parser

    # The per-category chains do not depend on the request, so build them once
    # here instead of on every invoke().
    code_chain = (
        ChatPromptTemplate.from_template(
            "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:"
        )
        | llm
        | parser
    )
    calculate_chain = (
        ChatPromptTemplate.from_template(
            "Solve the following calculation step-by-step:\n{input}"
        )
        | llm
        | parser
    )

    def _summarize(query: str) -> str:
        """Summarise the course documents retrieved for *query*, listing their sources."""
        # Use the RAG chain for retrieval; only its source documents are used here.
        rag_result = general_qa({"query": query})
        source_docs = rag_result.get("source_documents", []) or []

        texts = [doc.page_content for doc in source_docs if (doc.page_content or "").strip()]
        if not texts:
            # Avoid asking the summarizer to summarise an empty string.
            return "No relevant course material was found to summarize."
        combined_text = "\n\n".join(texts)

        # Imported on first use so the summarizer module is only loaded when a
        # summary is actually requested.
        from scripts.summarizer import get_summarizer

        summary = get_summarizer().run(combined_text)

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # listed sources are stable between runs (a set's order is not).
        sources = list(
            dict.fromkeys(str(doc.metadata.get("source", "unknown")) for doc in source_docs)
        )
        if sources:
            # \N{BOOKS} is the 📚 emoji, escaped so file re-encoding cannot garble it.
            summary += f"\n\n\N{BOOKS} Sources: {', '.join(sources)}"
        return summary

    class Router:
        """Classify a request with the router LLM and dispatch it to a handler."""

        def invoke(self, input_dict: Dict[str, Any]):
            query = input_dict["input"]
            category = _parse_category(router_chain.invoke({"input": query}))
            print(f"[ROUTER] User query routed to category: {category}")

            if category == "code":
                return {"result": code_chain.invoke({"input": query})}
            if category == "summarize":
                return {"result": _summarize(query)}
            if category == "calculate":
                return {"result": calculate_chain.invoke({"input": query})}
            # "general" (also the fallback for unrecognised replies).
            return general_qa({"query": query})

    return Router()