KeenWoo committed on
Commit fe8f736 · verified · 1 Parent(s): 6168c98

Update alz_companion/agent.py

Files changed (1)
  1. alz_companion/agent.py +65 -23
alz_companion/agent.py CHANGED
@@ -233,7 +233,6 @@ def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6, stop: Opt
 # -----------------------------
 # Prompting & RAG Chain
 # -----------------------------
-
 def make_rag_chain(
     vs_general: FAISS,
     vs_personal: FAISS,
@@ -245,68 +244,107 @@ def make_rag_chain(
     caregiver_name: str = "the caregiver",
     tone: str = "warm",
 ):
+    """Returns a callable that performs the complete, intelligent RAG process."""
+
     def _format_docs(docs: List[Document], default_msg: str) -> str:
         if not docs: return default_msg
         unique_docs = {doc.page_content: doc for doc in docs}.values()
         return "\n".join([f"- {d.page_content.strip()}" for d in unique_docs])
 
     def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None, topic_tag: Optional[str] = None, context_tags: Optional[List[str]] = None) -> Dict[str, Any]:
+
         router_messages = [{"role": "user", "content": ROUTER_PROMPT.format(query=query)}]
         query_type = call_llm(router_messages, temperature=0.0).strip().lower()
         print(f"Query classified as: {query_type}")
 
         system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
-        messages = [{"role": "system", "content": system_message}, *chat_history]
+        messages = [{"role": "system", "content": system_message}]
+        messages.extend(chat_history)
 
         if "general_knowledge_question" in query_type:
             user_prompt = ANSWER_TEMPLATE_GENERAL_KNOWLEDGE.format(question=query, language=language)
             messages.append({"role": "user", "content": user_prompt})
-            return {"answer": call_llm(messages, temperature=temperature), "sources": ["General Knowledge"]}
+            answer = call_llm(messages, temperature=temperature)
+            return {"answer": answer, "sources": ["General Knowledge"]}
 
         elif "factual_question" in query_type:
+            print(f"Performing query expansion for: '{query}'")
             expansion_prompt = QUERY_EXPANSION_PROMPT.format(question=query)
             expansion_response = call_llm([{"role": "user", "content": expansion_prompt}], temperature=0.1)
             try:
-                expanded_queries = json.loads(expansion_response.strip().replace("```json", "").replace("```", ""))
+                clean_response = expansion_response.strip().replace("```json", "").replace("```", "")
+                expanded_queries = json.loads(clean_response)
                 search_queries = [query] + expanded_queries
             except json.JSONDecodeError:
                 search_queries = [query]
-
+            print(f"Searching with queries: {search_queries}")
             all_docs = []
             for q in search_queries:
                 all_docs.extend(vs_personal.similarity_search(q, k=2))
                 all_docs.extend(vs_general.similarity_search(q, k=2))
-            context = _format_docs(all_docs, "(No relevant information found.)")
+            context = _format_docs(all_docs, "(No relevant information found in the memory journal.)")
             user_prompt = ANSWER_TEMPLATE_FACTUAL.format(context=context, question=query, language=language)
             messages.append({"role": "user", "content": user_prompt})
-            return {"answer": call_llm(messages, temperature=temperature), "sources": list(set(d.metadata.get("source", "unknown") for d in all_docs))}
+            answer = call_llm(messages, temperature=temperature)
+            sources = list(set(d.metadata.get("source", "unknown") for d in all_docs))
+            return {"answer": answer, "sources": sources}
 
         elif "general_conversation" in query_type:
             user_prompt = ANSWER_TEMPLATE_GENERAL.format(question=query, language=language)
             messages.append({"role": "user", "content": user_prompt})
-            return {"answer": call_llm(messages, temperature=temperature), "sources": []}
-
-        else: # Default to caregiving logic
-            search_filter = {}
-            if scenario_tag: search_filter["behaviors"] = scenario_tag.lower()
-            if emotion_tag: search_filter["emotion"] = emotion_tag.lower()
-            if topic_tag: search_filter["topic_tags"] = topic_tag.lower()
-            if context_tags: search_filter["context_tags"] = {"in": [tag.lower() for tag in context_tags]}
+            answer = call_llm(messages, temperature=temperature)
+            return {"answer": answer, "sources": []}
 
+        else: # Default to the original caregiving logic
+            # --- Reworked search strategy to handle filters correctly ---
+
+            # 1. Start with a general, unfiltered search to always get text-based matches.
             personal_docs = vs_personal.similarity_search(query, k=3)
             general_docs = vs_general.similarity_search(query, k=3)
-            if search_filter:
-                personal_docs.extend(vs_personal.similarity_search(query, k=3, filter=search_filter))
-                general_docs.extend(vs_general.similarity_search(query, k=3, filter=search_filter))
-
+
+            # 2. Build a filter for simple equality checks (FAISS supported).
+            simple_search_filter = {}
+            if scenario_tag and scenario_tag != "None":
+                simple_search_filter["behaviors"] = scenario_tag.lower()
+            if emotion_tag and emotion_tag != "None":
+                simple_search_filter["emotion"] = emotion_tag.lower()
+            if topic_tag and topic_tag != "None":
+                simple_search_filter["topic_tags"] = topic_tag.lower()
+
+            # 3. If simple filters exist, perform a second, more specific search.
+            if simple_search_filter:
+                print(f"Performing additional search with filter: {simple_search_filter}")
+                personal_docs.extend(vs_personal.similarity_search(query, k=2, filter=simple_search_filter))
+                general_docs.extend(vs_general.similarity_search(query, k=2, filter=simple_search_filter))
+
+            # 4. If context_tags exist (unsupported by 'in'), loop through them and perform separate searches.
+            if context_tags:
+                print(f"Performing looped context tag search for: {context_tags}")
+                for tag in context_tags:
+                    context_filter = {"context_tags": tag.lower()}
+                    personal_docs.extend(vs_personal.similarity_search(query, k=1, filter=context_filter))
+                    general_docs.extend(vs_general.similarity_search(query, k=1, filter=context_filter))
+
+            # 5. Combine and de-duplicate all results.
             all_docs_care = list({doc.page_content: doc for doc in personal_docs + general_docs}.values())
+
+            # --- End of reworked search strategy ---
+
             personal_context = _format_docs([d for d in all_docs_care if d in personal_docs], "(No relevant personal memories found.)")
             general_context = _format_docs([d for d in all_docs_care if d in general_docs], "(No general guidance found.)")
 
-            first_emotion = next((d.metadata.get("emotion") for d in all_docs_care if d.metadata.get("emotion")), None)
+            first_emotion = None
+            for doc in all_docs_care:
+                if "emotion" in doc.metadata and doc.metadata["emotion"]:
+                    emotion_data = doc.metadata["emotion"]
+                    if isinstance(emotion_data, list): first_emotion = emotion_data[0]
+                    else: first_emotion = emotion_data
+                if first_emotion: break
+
             emotions_context = render_emotion_guidelines(first_emotion or emotion_tag)
-
-            template = ANSWER_TEMPLATE_ADQ if any([scenario_tag, emotion_tag, first_emotion]) else ANSWER_TEMPLATE_CALM
+            is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None") or (first_emotion is not None)
+            template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
+
             if template == ANSWER_TEMPLATE_ADQ:
                 user_prompt = template.format(general_context=general_context, personal_context=personal_context, question=query, scenario_tag=scenario_tag, emotions_context=emotions_context, role=role, language=language)
             else:
@@ -316,12 +354,16 @@ def make_rag_chain(
             messages.append({"role": "user", "content": user_prompt})
             answer = call_llm(messages, temperature=temperature)
 
-            if scenario_tag and scenario_tag.lower() in ["exit_seeking", "wandering"]:
+            high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
+            if scenario_tag and scenario_tag.lower() in high_risk_scenarios:
                 answer += f"\n\n---\n{RISK_FOOTER}"
 
-            return {"answer": answer, "sources": list(set(d.metadata.get("source", "unknown") for d in all_docs_care))}
+            sources = list(set(d.metadata.get("source", "unknown") for d in all_docs_care))
+            return {"answer": answer, "sources": sources}
+
     return _answer_fn
 
+
 def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
     if not callable(chain): return {"answer": "[Error: RAG chain is not callable]", "sources": []}
     try:
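A note on the query-expansion parse above: LLMs often wrap JSON output in Markdown code fences, so the factual_question branch strips the fences before json.loads and falls back to the original query on a JSONDecodeError. A standalone sketch of that parse (the raw_response string is invented for illustration):

import json

# Simulated model output: a JSON array wrapped in Markdown code fences.
raw_response = '```json\n["where does she keep her keys", "usual spot for the house keys"]\n```'
clean_response = raw_response.strip().replace("```json", "").replace("```", "")
try:
    expanded_queries = json.loads(clean_response)
except json.JSONDecodeError:
    expanded_queries = []  # fall back to searching with the original query alone

search_queries = ["where are my keys"] + expanded_queries
print(search_queries)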
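The reworked caregiving branch replaces the old {"in": [...]} metadata filter, which the commit's comments say FAISS does not support, with one equality-filtered search per tag followed by dict-based de-duplication. A minimal, self-contained sketch of that pattern; FakeEmbeddings and the toy documents are stand-ins for the app's real embedding model and memory journal, and filter semantics can vary across langchain-community versions, so only simple equality filters are assumed:

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

docs = [
    Document(page_content="Soft music after dinner settles him.", metadata={"context_tags": "evening"}),
    Document(page_content="He paces near the front door at dusk.", metadata={"context_tags": "dusk"}),
]
vs = FAISS.from_documents(docs, FakeEmbeddings(size=32))

# One equality-filtered search per tag, pooling the hits.
hits = []
for tag in ["evening", "dusk"]:
    hits.extend(vs.similarity_search("how to calm him down", k=1, filter={"context_tags": tag}))

# De-duplicate by page_content; the dict collapses repeated hits while
# preserving first-appearance order.
unique_hits = list({d.page_content: d for d in hits}.values())
print([d.page_content for d in unique_hits])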
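For orientation, hypothetical wiring of the chain this commit modifies. The vector stores and argument values are illustrative, only parameters visible in the diff are passed (the elided middle parameters are assumed to keep their defaults), and answer_query wraps the same call with error handling:

# vs_general / vs_personal are prebuilt FAISS indexes created elsewhere in the app.
chain = make_rag_chain(vs_general, vs_personal, caregiver_name="Dana", tone="warm")
result = chain(
    "She keeps trying to leave the house after dinner.",
    chat_history=[],
    scenario_tag="exit_seeking",  # a high-risk tag, so RISK_FOOTER is appended
)
print(result["answer"])
print("Sources:", result["sources"])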