Spaces:
Paused
Paused
| from typing import Dict, TypedDict, Annotated, Sequence | |
| from langgraph.graph import Graph, StateGraph, END | |
| from langgraph.prebuilt import ToolExecutor | |
| from langchain.schema import StrOutputParser | |
| from langchain.schema.runnable import RunnablePassthrough | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| import models | |
| import prompts | |
| from helper_functions import format_docs | |
| from operator import itemgetter | |
# Define the state structure
class State(TypedDict):
    """Shared state passed between every node of the research and writing graphs."""

    # Conversation history; the research nodes read the last entry as the topic.
    messages: Sequence[str]
    # Research outputs keyed by source. "qdrant_results" holds the mapping
    # returned by qdrant_research_chain ({"response": ..., "context": ...}) and
    # "web_search_results" holds web_search_chain's output, so the values are
    # not all strings.
    research_data: Dict[str, object]
    # Presumably the in-progress draft produced by the writing team.
    draft_post: str
    # Presumably the approved, finished post.
    final_post: str
# Research Agent Pieces
# Input: {"topic": str}. Output: {"response": str, "context": <retriever output>}.
# The assign step retrieves documents for the topic and keeps "topic" alongside
# them. It is a Runnable, so the `|` below builds a chain; two bare dict literals
# joined by `|` would be merged by Python's dict union instead.
qdrant_research_chain = (
    RunnablePassthrough.assign(
        context=itemgetter("topic") | models.compression_retriever
    )
    # NOTE(review): format_docs is imported but appears unused in this file; if
    # the retrieved documents should be flattened to text before prompting,
    # pipe "context" through it here — confirm intent.
    | {
        "response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(),
        "context": itemgetter("context"),
    }
)
# Web Search Agent Pieces
tavily_tool = TavilySearchResults(max_results=5)

# Input: {"topic": ..., "qdrant_results": ...}. Output: summary string.
# The LLM turns topic + prior research into a search query, Tavily runs it, and
# the summarize prompt sees topic, qdrant_results and search_results together.
# The search is wrapped in RunnablePassthrough.assign because the tool's output
# is the search results, not the input dict. Mapping that output with
# itemgetter("topic") / itemgetter("qdrant_results") would fail.
web_search_chain = (
    {
        "topic": itemgetter("topic"),
        "qdrant_results": itemgetter("qdrant_results"),
    }
    | RunnablePassthrough.assign(
        search_results=prompts.search_query_prompt
        | models.gpt4o_mini
        | StrOutputParser()
        | tavily_tool
    )
    | prompts.summarize_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)
def query_qdrant(state: State) -> State:
    """Run the Qdrant research chain on the latest message and record the result.

    The last entry of ``state["messages"]`` is used as the topic. The chain's
    output is stored under ``state["research_data"]["qdrant_results"]``.

    Args:
        state: Current graph state; ``messages`` must be non-empty.

    Returns:
        The same state object, with ``research_data`` updated.
    """
    topic = state["messages"][-1]
    result = qdrant_research_chain.invoke({"topic": topic})
    # research_data may be missing if the caller's initial state omitted it.
    # Copy it so the previously stored mapping is not mutated in place.
    research_data = dict(state.get("research_data") or {})
    research_data["qdrant_results"] = result
    state["research_data"] = research_data
    return state
def web_search(state: State) -> State:
    """Run the web-search chain for the latest message and record its summary.

    The last entry of ``state["messages"]`` is the topic. Any earlier Qdrant
    results are passed to the chain as context (or a placeholder string if none
    exist). The output is stored under
    ``state["research_data"]["web_search_results"]``.

    Args:
        state: Current graph state; ``messages`` must be non-empty.

    Returns:
        The same state object, with ``research_data`` updated.
    """
    topic = state["messages"][-1]
    # Tolerate a missing research_data, and work on a copy rather than mutating
    # the stored mapping in place.
    research_data = dict(state.get("research_data") or {})
    qdrant_results = research_data.get("qdrant_results", "No previous results available.")
    result = web_search_chain.invoke({
        "topic": topic,
        "qdrant_results": qdrant_results,
    })
    research_data["web_search_results"] = result
    state["research_data"] = research_data
    return state
def research_supervisor(state):
    """Research-team supervisor node — placeholder.

    Hands *state* back untouched. Which research node runs next is decided by
    the research graph's edges, not by this function.

    TODO: implement research supervision logic.
    """
    return state
def post_creation(state):
    """Post-creation node — placeholder.

    Hands *state* back untouched; no draft is produced yet.

    TODO: implement post creation logic.
    """
    return state
def copy_editing(state):
    """Copy-editing node — placeholder.

    Hands *state* back untouched.

    TODO: implement copy editing logic.
    """
    return state
def voice_editing(state):
    """Voice-editing node — placeholder.

    Hands *state* back untouched.

    TODO: implement voice editing logic.
    """
    return state
def post_review(state):
    """Post-review node — placeholder.

    Hands *state* back untouched.

    TODO: implement post review logic.
    """
    return state
def writing_supervisor(state):
    """Writing-team supervisor node — placeholder.

    Hands *state* back untouched. What runs next is decided by the writing
    graph's edges, not by this function.

    TODO: implement writing supervision logic.
    """
    return state
def overall_supervisor(state):
    """Top-level supervisor node — placeholder.

    Hands *state* back untouched. Which team runs next is decided by the
    overall graph's edges, not by this function.

    TODO: implement overall supervision logic.
    """
    return state
# Create the research team graph
def _route_research(state):
    """Pick the node to run after ``research_supervisor``.

    Qdrant is queried first because ``web_search`` feeds the Qdrant results
    into its search prompt. Web search runs next, and the subgraph ends once
    both results are present in ``research_data``.
    """
    research_data = state.get("research_data") or {}
    if "qdrant_results" not in research_data:
        return "query_qdrant"
    if "web_search_results" not in research_data:
        return "web_search"
    return END


research_graph = StateGraph(State)
research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)
research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
# A single conditional edge replaces three unconditional ones. Unconditional
# edges to both workers and END fire all targets on every pass: the workers
# would run in parallel writing the same state keys, and the supervisor loop
# would repeat until LangGraph's recursion limit.
research_graph.add_conditional_edges(
    "research_supervisor",
    _route_research,
    {"query_qdrant": "query_qdrant", "web_search": "web_search", END: END},
)
research_graph.set_entry_point("research_supervisor")
# Create the writing team graph
def _route_writing(state):
    """Pick the node to run after ``writing_supervisor``.

    Ends the subgraph once ``final_post`` is set. Otherwise it starts another
    create -> copy-edit -> voice-edit -> review pass.
    """
    return END if state.get("final_post") else "post_creation"


writing_graph = StateGraph(State)
writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("copy_editing", copy_editing)
writing_graph.add_node("voice_editing", voice_editing)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)
writing_graph.add_edge("post_creation", "copy_editing")
writing_graph.add_edge("copy_editing", "voice_editing")
writing_graph.add_edge("voice_editing", "post_review")
writing_graph.add_edge("post_review", "writing_supervisor")
# Conditional instead of two unconditional edges (to post_creation and END),
# which re-ran the editing pipeline forever.
# NOTE(review): presumably post_review is meant to set final_post once a draft
# is approved. While the nodes are placeholders, nothing sets it, so this still
# loops until the recursion limit.
writing_graph.add_conditional_edges(
    "writing_supervisor",
    _route_writing,
    {"post_creation": "post_creation", END: END},
)
writing_graph.set_entry_point("writing_supervisor")
# Create the overall graph
def _route_overall(state):
    """Pick the team to run after ``overall_supervisor``.

    Research runs until both the Qdrant and web-search results exist. Writing
    runs until ``final_post`` is set. After that the graph ends.
    """
    research_data = state.get("research_data") or {}
    if "qdrant_results" not in research_data or "web_search_results" not in research_data:
        return "research_team"
    if not state.get("final_post"):
        return "writing_team"
    return END


overall_graph = StateGraph(State)
# Add the research and writing team graphs as nodes. A StateGraph is only a
# builder; it must be compiled into a runnable before it can serve as a node.
overall_graph.add_node("research_team", research_graph.compile())
overall_graph.add_node("writing_team", writing_graph.compile())
# Add the overall supervisor node
overall_graph.add_node("overall_supervisor", overall_supervisor)
overall_graph.set_entry_point("overall_supervisor")
# Connect the nodes: each team reports back to the supervisor. The supervisor
# routes conditionally, because unconditional edges to both teams and END would
# run both teams in parallel on every pass and never stop.
overall_graph.add_edge("research_team", "overall_supervisor")
overall_graph.add_edge("writing_team", "overall_supervisor")
overall_graph.add_conditional_edges(
    "overall_supervisor",
    _route_overall,
    {"research_team": "research_team", "writing_team": "writing_team", END: END},
)
# Compile the graph
app = overall_graph.compile()