Spaces:
Sleeping
Sleeping
import os

from dotenv import load_dotenv
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from supabase.client import Client, create_client

from tools import calculator, duck_web_search, wiki_search, arxiv_search
from utils import load_prompt
load_dotenv()

# --- Vector store backing the "similar question" lookup -------------------
# Embedding model for queries; the original note says it yields 768-dim
# vectors. NOTE(review): presumably the gaia_documents embedding column uses
# the same dimension — confirm against the Supabase schema.
embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base")

supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)

vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="gaia_documents",
    query_name="match_documents_langchain",
)

# NOTE(review): this tool is not part of `tools` (the list handed to
# bind_tools/ToolNode); retriever_node queries `vector_store` directly.
# If it is ever bound, the name's space may be rejected by function-calling
# APIs that restrict tool names to [A-Za-z0-9_-] — rename before binding.
retriever = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="ModernBERT Retriever",
    description="A retriever of similar questions from a vector store.",
)
# Tools the model may call; the same list feeds bind_tools() and ToolNode.
tools = [
    calculator,
    duck_web_search,
    wiki_search,
    arxiv_search,
]

# --- Chat model -------------------------------------------------------------
model_id = "Qwen/Qwen3-32B"

# Hugging Face inference endpoint; provider="auto" lets the hub choose the
# serving provider. temperature=0 minimises sampling randomness.
llm = HuggingFaceEndpoint(
    repo_id=model_id,
    temperature=0,
    repetition_penalty=1.03,
    provider="auto",
    huggingfacehub_api_token=os.environ.get("HF_INFERENCE_KEY"),
)

# Chat wrapper around the endpoint, plus a variant with the tools bound so
# its replies can carry tool calls.
agent = ChatHuggingFace(llm=llm)
agent_with_tools = agent.bind_tools(tools)
def retriever_node(state: MessagesState):
    """Add a reference Q&A pair retrieved from the vector store.

    Uses the content of the first message in the conversation (presumably the
    user's question) as the similarity-search query, and returns the closest
    stored document wrapped in a HumanMessage for the processor node to use
    as context.

    Args:
        state: Graph state; only ``state["messages"][0]`` is read.

    Returns:
        ``{"messages": [HumanMessage(...)]}`` with the reference text, or
        ``{"messages": []}`` (nothing appended) when the store has no match.
    """
    question = state["messages"][0].content
    # Only the top hit is used, so request a single result instead of the
    # default k=4.
    matches = vector_store.similarity_search(question, k=1)
    if not matches:
        # An empty result (e.g. empty table) would make matches[0] raise
        # IndexError and abort the whole graph run; skip the reference instead.
        return {"messages": []}
    reference = HumanMessage(
        f"Here I provide a similar question and answer for reference: \n\n{matches[0].page_content}"
    )
    return {"messages": [reference]}
def processor_node(state: MessagesState):
    """Agent node that answers questions.

    Sends the system prompt followed by the conversation so far to the
    tool-enabled chat model and returns its reply. If the reply requests
    tools, ``tools_condition`` routes the graph to the ToolNode.

    Args:
        state: Graph state holding the running ``messages`` list.

    Returns:
        ``{"messages": [response]}`` — the model's single reply, which the
        MessagesState reducer appends to the conversation.
    """
    system_prompt = load_prompt("prompt.yaml")
    # NOTE(review): load_prompt's return type is not visible here. LangChain
    # coerces a bare str inside a message list into a HumanMessage, which would
    # strip the prompt of its system role; wrap strings explicitly. Message
    # objects (or other message-likes) pass through unchanged.
    if isinstance(system_prompt, str):
        system_prompt = SystemMessage(content=system_prompt)
    messages = state.get("messages", [])
    response = agent_with_tools.invoke([system_prompt, *messages])
    return {"messages": [response]}
def agent_graph():
    """Build and compile the agent workflow.

    Flow: START -> retriever_node -> processor_node. From processor_node,
    ``tools_condition`` decides whether to run the ToolNode, whose output
    loops back into processor_node.

    Returns:
        The compiled LangGraph graph.
    """
    builder = StateGraph(MessagesState)

    # Register nodes in a fixed order (dict insertion order is preserved).
    node_table = {
        "retriever_node": retriever_node,
        "processor_node": processor_node,
        "tools": ToolNode(tools),
    }
    for node_name, node_action in node_table.items():
        builder.add_node(node_name, node_action)

    # Wire the graph: retrieval first, then the model/tool loop.
    builder.add_edge(START, "retriever_node")
    builder.add_edge("retriever_node", "processor_node")
    builder.add_conditional_edges("processor_node", tools_condition)
    builder.add_edge("tools", "processor_node")

    return builder.compile()