Spaces:
Sleeping
Sleeping
Commit
·
975d8af
1
Parent(s):
f925355
Updated Query Classifier
Browse files- langgraph_agent.py +30 -20
- main.py +0 -1
langgraph_agent.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
| 1 |
-
# Import LangGraph components
|
| 2 |
from langgraph.graph import StateGraph, END
|
| 3 |
-
from typing import TypedDict, List, Dict, Any
|
| 4 |
-
import operator
|
| 5 |
-
from pydantic import BaseModel, Field
|
| 6 |
from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
|
| 7 |
import os
|
| 8 |
-
|
|
|
|
| 9 |
class AgentState(TypedDict):
|
| 10 |
query: str
|
| 11 |
previous_conversation: str
|
|
@@ -14,18 +12,36 @@ class AgentState(TypedDict):
|
|
| 14 |
context: List[str]
|
| 15 |
response: str
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def query_classifier(state: AgentState) -> AgentState:
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
|
|
|
| 29 |
return state
|
| 30 |
|
| 31 |
def retrieve_documents(state: AgentState) -> AgentState:
|
|
@@ -183,18 +199,12 @@ def agent_with_db():
|
|
| 183 |
"requires_rag": False,
|
| 184 |
"context": [],
|
| 185 |
"response": "",
|
| 186 |
-
# "style": style
|
| 187 |
}
|
| 188 |
-
print("Initial state:", initial_state)
|
| 189 |
|
| 190 |
-
# Run the workflow
|
| 191 |
final_state = self.workflow.invoke(initial_state)
|
| 192 |
-
print("Final state:", final_state)
|
| 193 |
|
| 194 |
-
# Update conversation history with response
|
| 195 |
self.conversation_history += f"Assistant: {final_state['response']}\n"
|
| 196 |
|
| 197 |
-
# Return in the expected format
|
| 198 |
return {"result": final_state["response"]}
|
| 199 |
|
| 200 |
return HealthAgent(agent_workflow)
|
|
|
|
|
|
|
| 1 |
from langgraph.graph import StateGraph, END
|
| 2 |
+
from typing import TypedDict, List, Dict, Any
|
|
|
|
|
|
|
| 3 |
from agent import LLMChain, PromptTemplate, llm, DOCUMENT_DIR, load_documents, split_documents, CHROMA_PATH, load_vectordb, create_and_store_embeddings
|
| 4 |
import os
|
| 5 |
+
|
| 6 |
+
# state schema
|
| 7 |
class AgentState(TypedDict):
|
| 8 |
query: str
|
| 9 |
previous_conversation: str
|
|
|
|
| 12 |
context: List[str]
|
| 13 |
response: str
|
| 14 |
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# def query_classifier(state: AgentState) -> AgentState:
|
| 18 |
+
# """Determine if the query requires RAG retrieval based on keywords.
|
| 19 |
+
# Is not continued anymore, will be removed in future."""
|
| 20 |
+
# query_lower = state["query"].lower()
|
| 21 |
+
# rag_keywords = [
|
| 22 |
+
# "scheme", "schemes", "program", "programs", "policy", "policies",
|
| 23 |
+
# "public health engineering", "phe", "public health", "government",
|
| 24 |
+
# "benefit", "financial", "assistance", "aid", "initiative", "yojana",
|
| 25 |
+
# ]
|
| 26 |
+
|
| 27 |
+
# state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
|
| 28 |
+
# return state
|
| 29 |
|
| 30 |
def query_classifier(state: AgentState) -> AgentState:
|
| 31 |
+
"""Updated classifier to use LLM for intent classification."""
|
| 32 |
+
|
| 33 |
+
query = state["query"]
|
| 34 |
+
|
| 35 |
+
# Then classify intent
|
| 36 |
+
classification_prompt = f"""
|
| 37 |
+
Answer with only 'Yes' or 'No'.
|
| 38 |
+
Classify if this query is asking about government schemes, policies, or benefits.
|
| 39 |
+
The language may not be English, So first detect the language. and understand the query.:
|
| 40 |
+
Query: {query}
|
| 41 |
+
Remember Answer with only 'Yes' or 'No'."""
|
| 42 |
|
| 43 |
+
result = llm.predict(classification_prompt)
|
| 44 |
+
state["requires_rag"] = "yes" in result.lower()
|
| 45 |
return state
|
| 46 |
|
| 47 |
def retrieve_documents(state: AgentState) -> AgentState:
|
|
|
|
| 199 |
"requires_rag": False,
|
| 200 |
"context": [],
|
| 201 |
"response": "",
|
|
|
|
| 202 |
}
|
|
|
|
| 203 |
|
|
|
|
| 204 |
final_state = self.workflow.invoke(initial_state)
|
|
|
|
| 205 |
|
|
|
|
| 206 |
self.conversation_history += f"Assistant: {final_state['response']}\n"
|
| 207 |
|
|
|
|
| 208 |
return {"result": final_state["response"]}
|
| 209 |
|
| 210 |
return HealthAgent(agent_workflow)
|
main.py
CHANGED
|
@@ -61,7 +61,6 @@ async def retrieve(request:request, url:Request):
|
|
| 61 |
if origin is None:
|
| 62 |
origin = url.headers.get('referer')
|
| 63 |
print("origin: ", origin)
|
| 64 |
-
print("response: ", response)
|
| 65 |
return {"response": response["result"]}
|
| 66 |
|
| 67 |
except Exception as e:
|
|
|
|
| 61 |
if origin is None:
|
| 62 |
origin = url.headers.get('referer')
|
| 63 |
print("origin: ", origin)
|
|
|
|
| 64 |
return {"response": response["result"]}
|
| 65 |
|
| 66 |
except Exception as e:
|