SCR_Course_ChatBot / scripts /router_chain.py
MaryamKarimi080's picture
Update scripts/router_chain.py
dfda80f verified
raw
history blame
1.56 kB
from typing import Dict, Any
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from scripts.rag_chat import build_general_qa_chain
def build_router_chain(model_name=None):
    """Build a keyword-based router over a RAG chain and two direct-LLM prompts.

    Args:
        model_name: Chat model name. It is forwarded to
            ``build_general_qa_chain`` and used for the direct ``ChatOpenAI``
            calls; when falsy, the direct calls fall back to ``"gpt-4o-mini"``.

    Returns:
        A ``Router`` instance exposing ``invoke(input_dict)``.
    """
    general_qa = build_general_qa_chain(model_name=model_name)
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)

    # The prompt | llm pipelines do not depend on the incoming question, so
    # build them once here instead of on every invoke() call.
    code_chain = (
        ChatPromptTemplate.from_template(
            "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:"
        )
        | llm
    )
    summary_chain = (
        ChatPromptTemplate.from_template(
            "Provide a concise summary about: {input}\nSummary:"
        )
        | llm
    )

    class Router:
        """Route a question by keyword to a coding prompt, a summary prompt,
        a canned calculation reply, or the RAG chain (checked in that order)."""

        # Plain substring checks against the lower-cased question.
        # NOTE(review): substring matching also fires inside longer words
        # (e.g. "decoder" contains "code"); consider word-boundary matching
        # if misroutes show up.
        _CODE_KEYWORDS = ("code", "program", "debug")
        _SUMMARY_KEYWORDS = ("summarize", "summary")

        def invoke(self, input_dict: Dict[str, Any]):
            """Answer ``input_dict["input"]`` via the first matching route.

            Args:
                input_dict: Mapping whose ``"input"`` entry holds the user's
                    question. A missing or ``None`` entry is treated as an
                    empty question; other non-string values go through ``str``.

            Returns:
                ``{"result": <text>}`` for the empty-input, coding, summary and
                calculation routes. The RAG route returns whatever
                ``general_qa`` returns (presumably a dict with a ``"result"``
                key -- confirm in ``scripts.rag_chat``).
            """
            raw = input_dict.get("input")
            # Read the question once and reuse it in every branch, so a missing
            # key or a None value cannot crash a later direct index or .lower().
            query = "" if raw is None else str(raw)
            if not query.strip():
                # Nothing to route; avoid a pointless retrieval + LLM round trip.
                return {"result": "Please enter a question."}
            text = query.lower()

            if any(word in text for word in self._CODE_KEYWORDS):
                return {"result": code_chain.invoke({"input": query}).content}
            if any(word in text for word in self._SUMMARY_KEYWORDS):
                return {"result": summary_chain.invoke({"input": query}).content}
            # NOTE(review): any digit routes here, so questions such as
            # "what is in week 3?" never reach the RAG chain -- confirm intent.
            if "calculate" in text or any(ch.isdigit() for ch in text):
                return {"result": "For calculations, please ask a specific calculation or provide more context."}
            # Everything else goes to the retrieval-augmented QA chain.
            return general_qa({"query": query})

    return Router()