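"""Intent router for the course chatbot.

Classifies each user request with an LLM and dispatches it to a coding helper,
a RAG-backed summarizer, a step-by-step calculation prompt, or the general
RAG question-answering chain built in scripts.rag_chat.
"""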
from typing import Dict, Any

from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser

from scripts.rag_chat import build_general_qa_chain


def build_router_chain(model_name=None):
    """Build a Router that classifies a user request and runs the matching chain."""
    general_qa = build_general_qa_chain(model_name=model_name)
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)

    # This prompt asks the LLM to choose which "mode" to use.
    router_prompt = ChatPromptTemplate.from_template("""
You are a routing assistant for a chatbot.
Classify the following user request into one of these categories:
- "code" for programming or debugging
- "summarize" for summary requests
- "calculate" for math or numeric calculations
- "general" for general Q&A using course files
Return ONLY the category word.
User request: {input}
""")

    router_chain = router_prompt | llm | StrOutputParser()
    class Router:
        """Small dispatcher: classify the request, then run the matching sub-chain."""

        def invoke(self, input_dict: Dict[str, Any]):
            category = router_chain.invoke({"input": input_dict["input"]}).strip().lower()
            print(f"[ROUTER] User query routed to category: {category}")

            if category == "code":
                prompt = ChatPromptTemplate.from_template(
                    "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:"
                )
                chain = prompt | llm | StrOutputParser()
                return {"result": chain.invoke({"input": input_dict["input"]})}
            # A simpler variant (commented out): summarize the request text
            # directly, without retrieving course documents first.
            # elif category == "summarize":
            #     prompt = ChatPromptTemplate.from_template(
            #         "Provide a concise summary about: {input}\nSummary:"
            #     )
            #     chain = prompt | llm | StrOutputParser()
            #     return {"result": chain.invoke({"input": input_dict["input"]})}
            elif category == "summarize":
                # 1. Use RAG to retrieve relevant docs.
                rag_result = general_qa({"query": input_dict["input"]})
                # 2. Extract docs and prepare text.
                source_docs = rag_result.get("source_documents", [])
                combined_text = "\n\n".join([doc.page_content for doc in source_docs])
                # 3. Run the summarizer chain on the retrieved text.
                from scripts.summarizer import get_summarizer
                summarizer_chain = get_summarizer()
                summary = summarizer_chain.run(combined_text)
                # 4. Add sources, if any.
                sources = list({str(doc.metadata.get("source", "unknown")) for doc in source_docs})
                if sources:
                    summary += f"\n\n📚 Sources: {', '.join(sources)}"
                return {"result": summary}
            elif category == "calculate":
                prompt = ChatPromptTemplate.from_template(
                    "Solve the following calculation step-by-step:\n{input}"
                )
                chain = prompt | llm | StrOutputParser()
                return {"result": chain.invoke({"input": input_dict["input"]})}
            else:  # "general": fall through to the RAG question-answering chain
                return general_qa({"query": input_dict["input"]})

    return Router()
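

# Illustrative usage sketch only, not part of the module above: it assumes an
# OPENAI_API_KEY is configured and that scripts.rag_chat / scripts.summarizer
# are importable (they provide build_general_qa_chain and get_summarizer).
# The example questions are hypothetical and simply exercise the four routes.
if __name__ == "__main__":
    router = build_router_chain()
    for question in [
        "Why does my Python loop raise an IndexError?",    # should route to "code"
        "Summarize the lecture notes on embeddings.",      # should route to "summarize"
        "What is 15% of 240?",                              # should route to "calculate"
        "What do the course files say about vector stores?" # should route to "general"
    ]:
        print(router.invoke({"input": question}))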