Update alz_companion/agent.py
alz_companion/agent.py +65 -23
alz_companion/agent.py
CHANGED
@@ -233,7 +233,6 @@ def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6, stop: Opt
 # -----------------------------
 # Prompting & RAG Chain
 # -----------------------------
-
 def make_rag_chain(
     vs_general: FAISS,
     vs_personal: FAISS,
@@ -245,68 +244,107 @@ def make_rag_chain(
     caregiver_name: str = "the caregiver",
     tone: str = "warm",
 ):
+    """Returns a callable that performs the complete, intelligent RAG process."""
+
     def _format_docs(docs: List[Document], default_msg: str) -> str:
         if not docs: return default_msg
         unique_docs = {doc.page_content: doc for doc in docs}.values()
         return "\n".join([f"- {d.page_content.strip()}" for d in unique_docs])
 
     def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None, topic_tag: Optional[str] = None, context_tags: Optional[List[str]] = None) -> Dict[str, Any]:
+
         router_messages = [{"role": "user", "content": ROUTER_PROMPT.format(query=query)}]
         query_type = call_llm(router_messages, temperature=0.0).strip().lower()
         print(f"Query classified as: {query_type}")
 
         system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
-        messages = [{"role": "system", "content": system_message}
+        messages = [{"role": "system", "content": system_message}]
+        messages.extend(chat_history)
 
         if "general_knowledge_question" in query_type:
             user_prompt = ANSWER_TEMPLATE_GENERAL_KNOWLEDGE.format(question=query, language=language)
             messages.append({"role": "user", "content": user_prompt})
+            answer = call_llm(messages, temperature=temperature)
+            return {"answer": answer, "sources": ["General Knowledge"]}
 
         elif "factual_question" in query_type:
+            print(f"Performing query expansion for: '{query}'")
            expansion_prompt = QUERY_EXPANSION_PROMPT.format(question=query)
             expansion_response = call_llm([{"role": "user", "content": expansion_prompt}], temperature=0.1)
             try:
+                clean_response = expansion_response.strip().replace("```json", "").replace("```", "")
+                expanded_queries = json.loads(clean_response)
                 search_queries = [query] + expanded_queries
             except json.JSONDecodeError:
                 search_queries = [query]
+            print(f"Searching with queries: {search_queries}")
             all_docs = []
             for q in search_queries:
                 all_docs.extend(vs_personal.similarity_search(q, k=2))
                 all_docs.extend(vs_general.similarity_search(q, k=2))
-            context = _format_docs(all_docs, "(No relevant information found.)")
+            context = _format_docs(all_docs, "(No relevant information found in the memory journal.)")
             user_prompt = ANSWER_TEMPLATE_FACTUAL.format(context=context, question=query, language=language)
             messages.append({"role": "user", "content": user_prompt})
+            answer = call_llm(messages, temperature=temperature)
+            sources = list(set(d.metadata.get("source", "unknown") for d in all_docs))
+            return {"answer": answer, "sources": sources}
 
         elif "general_conversation" in query_type:
             user_prompt = ANSWER_TEMPLATE_GENERAL.format(question=query, language=language)
             messages.append({"role": "user", "content": user_prompt})
+            answer = call_llm(messages, temperature=temperature)
+            return {"answer": answer, "sources": []}
 
-        else: # Default to caregiving logic
-            search_filter = {}
-            if scenario_tag: search_filter["behaviors"] = scenario_tag.lower()
-            if emotion_tag: search_filter["emotion"] = emotion_tag.lower()
-            if topic_tag: search_filter["topic_tags"] = topic_tag.lower()
-            if context_tags: search_filter["context_tags"] = {"in": [tag.lower() for tag in context_tags]}
+        else: # Default to the original caregiving logic
+            # --- Reworked search strategy to handle filters correctly ---
+
+            # 1. Start with a general, unfiltered search to always get text-based matches.
             personal_docs = vs_personal.similarity_search(query, k=3)
             general_docs = vs_general.similarity_search(query, k=3)
+
+            # 2. Build a filter for simple equality checks (FAISS supported).
+            simple_search_filter = {}
+            if scenario_tag and scenario_tag != "None":
+                simple_search_filter["behaviors"] = scenario_tag.lower()
+            if emotion_tag and emotion_tag != "None":
+                simple_search_filter["emotion"] = emotion_tag.lower()
+            if topic_tag and topic_tag != "None":
+                simple_search_filter["topic_tags"] = topic_tag.lower()
+
+            # 3. If simple filters exist, perform a second, more specific search.
+            if simple_search_filter:
+                print(f"Performing additional search with filter: {simple_search_filter}")
+                personal_docs.extend(vs_personal.similarity_search(query, k=2, filter=simple_search_filter))
+                general_docs.extend(vs_general.similarity_search(query, k=2, filter=simple_search_filter))
+
+            # 4. If context_tags exist (unsupported by 'in'), loop through them and perform separate searches.
+            if context_tags:
+                print(f"Performing looped context tag search for: {context_tags}")
+                for tag in context_tags:
+                    context_filter = {"context_tags": tag.lower()}
+                    personal_docs.extend(vs_personal.similarity_search(query, k=1, filter=context_filter))
+                    general_docs.extend(vs_general.similarity_search(query, k=1, filter=context_filter))
+
+            # 5. Combine and de-duplicate all results.
             all_docs_care = list({doc.page_content: doc for doc in personal_docs + general_docs}.values())
+
+            # --- End of reworked search strategy ---
+
             personal_context = _format_docs([d for d in all_docs_care if d in personal_docs], "(No relevant personal memories found.)")
             general_context = _format_docs([d for d in all_docs_care if d in general_docs], "(No general guidance found.)")
 
-            first_emotion =
+            first_emotion = None
+            for doc in all_docs_care:
+                if "emotion" in doc.metadata and doc.metadata["emotion"]:
+                    emotion_data = doc.metadata["emotion"]
+                    if isinstance(emotion_data, list): first_emotion = emotion_data[0]
+                    else: first_emotion = emotion_data
+                if first_emotion: break
+
             emotions_context = render_emotion_guidelines(first_emotion or emotion_tag)
-
-            template = ANSWER_TEMPLATE_ADQ if any([scenario_tag, emotion_tag, first_emotion]) else ANSWER_TEMPLATE_CALM
+            is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None") or (first_emotion is not None)
+            template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
+
             if template == ANSWER_TEMPLATE_ADQ:
                 user_prompt = template.format(general_context=general_context, personal_context=personal_context, question=query, scenario_tag=scenario_tag, emotions_context=emotions_context, role=role, language=language)
             else:
@@ -316,12 +354,16 @@ def make_rag_chain(
         messages.append({"role": "user", "content": user_prompt})
         answer = call_llm(messages, temperature=temperature)
 
+        high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
+        if scenario_tag and scenario_tag.lower() in high_risk_scenarios:
             answer += f"\n\n---\n{RISK_FOOTER}"
 
+        sources = list(set(d.metadata.get("source", "unknown") for d in all_docs_care))
+        return {"answer": answer, "sources": sources}
+
     return _answer_fn
 
+
 def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
     if not callable(chain): return {"answer": "[Error: RAG chain is not callable]", "sources": []}
     try:
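For reference, the reworked search strategy in the caregiving branch can be exercised on its own. The sketch below is a minimal, hypothetical reproduction of steps 1-5 against a throwaway index: it assumes langchain-community and faiss-cpu are installed, and it substitutes FakeEmbeddings plus invented documents, metadata, and tags for the app's real embedding model and memory journal, so retrieval quality is meaningless here; it only demonstrates the unfiltered pass, the equality-filtered pass, the per-tag loop, and the de-duplication.

# Minimal sketch of the reworked search strategy (assumptions: langchain-community +
# faiss-cpu available; toy documents and tags invented, not taken from alz_companion).
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

docs = [
    Document(page_content="Dad gets restless near the front door after dinner.",
             metadata={"source": "journal", "behaviors": "wandering", "context_tags": "evening"}),
    Document(page_content="Redirect gently with a familiar activity instead of correcting.",
             metadata={"source": "guide", "emotion": "anxious", "context_tags": "evening"}),
    Document(page_content="Keep mealtimes calm and predictable.",
             metadata={"source": "guide", "context_tags": "mealtime"}),
]
# Stands in for vs_personal / vs_general; FakeEmbeddings produces random vectors.
vs = FAISS.from_documents(docs, FakeEmbeddings(size=64))

query = "He keeps trying to leave after dinner"

# 1. Unfiltered search always contributes text-based matches.
results = vs.similarity_search(query, k=2)

# 2-3. Simple equality filter: one metadata field matched against one value.
simple_filter = {"behaviors": "wandering"}
results.extend(vs.similarity_search(query, k=2, filter=simple_filter))

# 4. Multiple context tags: one equality-filtered search per tag instead of an "in" filter.
for tag in ["evening", "mealtime"]:
    results.extend(vs.similarity_search(query, k=1, filter={"context_tags": tag}))

# 5. De-duplicate by page_content, as the agent does.
unique = list({d.page_content: d for d in results}.values())
for d in unique:
    print(d.metadata.get("source", "unknown"), "->", d.page_content)

Running one equality-filtered search per context tag costs a few extra queries compared with a single "in"-style filter, but it stays within the plain exact-match dictionary filters that the LangChain FAISS wrapper reliably supports, which appears to be the motivation for step 4 in this commit.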