Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| from contextlib import contextmanager | |
| from langchain.schema import Document | |
| from langgraph.graph import END, StateGraph | |
| from langchain_core.runnables.graph import MermaidDrawMethod | |
| from tomlkit import document | |
| from typing_extensions import TypedDict | |
| from typing import List | |
| from IPython.display import display, HTML, Image | |
| from celsius_csrd_chatbot.chains.esrs_categorization import ( | |
| make_esrs_categorization_node, | |
| ) | |
| from celsius_csrd_chatbot.chains.esrs_intent import ( | |
| make_esrs_intent_node, | |
| ) | |
| from celsius_csrd_chatbot.chains.retriever import make_retriever_node | |
| from celsius_csrd_chatbot.chains.answer_rag import make_rag_node | |
class GraphState(TypedDict):
    """Typed schema of the state shared by the nodes of the ESRS workflow.

    ``make_graph_agent`` builds its ``StateGraph`` from this class, so these
    four keys are the fields that nodes and routing functions can read.
    """

    # The user's question.
    query: str
    # Routing label inspected by ``route_intent``: "none" and "wrong_esrs"
    # select dedicated branches; any other value goes straight to retrieval.
    esrs_type: str
    # Generated reply (presumably written by the RAG answer nodes — confirm).
    answer: str
    # Documents gathered for answering (presumably set by the retriever node).
    documents: List[Document]


def route_intent(state):
    """Choose which node runs after ESRS categorization.

    Args:
        state: Graph state mapping; only its ``"esrs_type"`` entry is read.

    Returns:
        ``"intent_esrs"`` when the category is ``"none"``,
        ``"answer_rag_wrong"`` when it is ``"wrong_esrs"``, and
        ``"retrieve_documents"`` for every other value.

    Raises:
        KeyError: if ``state`` has no ``"esrs_type"`` entry.
    """
    category = state["esrs_type"]
    if category == "none":
        # No category detected: go through the intent node before retrieval.
        return "intent_esrs"
    if category == "wrong_esrs":
        return "answer_rag_wrong"
    return "retrieve_documents"


def make_graph_agent(llm, vectorstore):
    """Build and compile the ESRS question-answering workflow.

    Topology:
      * ``categorize_esrs`` is the entry node; ``route_intent`` then sends
        the run to ``intent_esrs``, ``retrieve_documents`` or
        ``answer_rag_wrong``.
      * ``intent_esrs`` -> ``retrieve_documents`` -> ``answer_rag`` -> END
      * ``answer_rag_wrong`` -> END

    Args:
        llm: Language model handed to the intent node factory and to both
            RAG answer node factories.
        vectorstore: Store handed to ``make_retriever_node``.

    Returns:
        The compiled graph produced by ``workflow.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Build the node callables.
    categorize_esrs = make_esrs_categorization_node()
    intent_esrs = make_esrs_intent_node(llm)
    retrieve_documents = make_retriever_node(vectorstore)
    answer_rag = make_rag_node(llm, wrong_esrs=False)
    answer_rag_wrong = make_rag_node(llm, wrong_esrs=True)

    # Register the nodes.
    workflow.add_node("categorize_esrs", categorize_esrs)
    workflow.add_node("intent_esrs", intent_esrs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_wrong", answer_rag_wrong)

    workflow.set_entry_point("categorize_esrs")

    # Declare every destination route_intent can return. Without this mapping
    # the graph cannot know which nodes the branch reaches, and rendered
    # diagrams show conditional edges from categorize_esrs to every node.
    # The mapping is passed positionally so it works with both the older
    # (`conditional_edge_mapping`) and newer (`path_map`) parameter names.
    workflow.add_conditional_edges(
        "categorize_esrs",
        route_intent,
        {
            "intent_esrs": "intent_esrs",
            "retrieve_documents": "retrieve_documents",
            "answer_rag_wrong": "answer_rag_wrong",
        },
    )

    # Fixed edges.
    workflow.add_edge("intent_esrs", "retrieve_documents")
    workflow.add_edge("retrieve_documents", "answer_rag")
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_wrong", END)

    return workflow.compile()


def display_graph(app):
    """Render the compiled workflow as a Mermaid PNG and show it inline.

    Args:
        app: Compiled graph (e.g. the return value of ``make_graph_agent``);
            its ``get_graph(xray=True)`` result is drawn.

    Note:
        Drawing uses ``MermaidDrawMethod.API``. This presumably calls a remote
        Mermaid rendering service and needs network access — confirm. Output
        goes through IPython's ``display``, so it is intended for a notebook
        or IPython session.
    """
    graph_view = app.get_graph(xray=True)
    png_data = graph_view.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_data))