KeenWoo commited on
Commit
e15286f
·
verified ·
1 Parent(s): b8d5bdd

Delete alz_companion/agent.py

Browse files
Files changed (1) hide show
  1. alz_companion/agent.py +0 -396
alz_companion/agent.py DELETED
@@ -1,396 +0,0 @@
1
-
2
- from __future__ import annotations
3
- import os
4
- import json
5
- import base64
6
- import time
7
- import tempfile
8
- import re
9
-
10
- from typing import List, Dict, Any, Optional
11
-
12
- try:
13
- from openai import OpenAI
14
- except Exception:
15
- OpenAI = None
16
-
17
- from langchain.schema import Document
18
- from langchain_community.vectorstores import FAISS
19
- from langchain_community.embeddings import HuggingFaceEmbeddings
20
-
21
- try:
22
- from gtts import gTTS
23
- except Exception:
24
- gTTS = None
25
-
26
- from .prompts import (
27
- SYSTEM_TEMPLATE, ANSWER_TEMPLATE_CALM, ANSWER_TEMPLATE_ADQ,
28
- SAFETY_GUARDRAILS, RISK_FOOTER, render_emotion_guidelines,
29
- NLU_ROUTER_PROMPT, SPECIALIST_CLASSIFIER_PROMPT,
30
- ROUTER_PROMPT,
31
- ANSWER_TEMPLATE_FACTUAL,
32
- ANSWER_TEMPLATE_GENERAL_KNOWLEDGE,
33
- ANSWER_TEMPLATE_GENERAL,
34
- QUERY_EXPANSION_PROMPT
35
- )
36
-
37
- # -----------------------------
38
- # Multimodal Processing Functions
39
- # -----------------------------
40
-
41
- def _openai_client() -> Optional[OpenAI]:
42
- api_key = os.getenv("OPENAI_API_KEY", "").strip()
43
- return OpenAI(api_key=api_key) if api_key and OpenAI else None
44
-
45
- def describe_image(image_path: str) -> str:
46
- client = _openai_client()
47
- if not client:
48
- return "(Image description failed: OpenAI API key not configured.)"
49
- try:
50
- extension = os.path.splitext(image_path)[1].lower()
51
- mime_type = f"image/{'jpeg' if extension in ['.jpg', '.jpeg'] else extension.strip('.')}"
52
- with open(image_path, "rb") as image_file:
53
- base64_image = base64.b64encode(image_file.read()).decode('utf-8')
54
- response = client.chat.completions.create(
55
- model="gpt-4o",
56
- messages=[
57
- {
58
- "role": "user",
59
- "content": [
60
- {"type": "text", "text": "Describe this image concisely for a memory journal. Focus on people, places, and key objects. Example: 'A photo of John and Mary smiling on a bench at the park.'"},
61
- {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{base64_image}"}}
62
- ],
63
- }
64
- ], max_tokens=100)
65
- return response.choices[0].message.content or "No description available."
66
- except Exception as e:
67
- return f"[Image description error: {e}]"
68
-
69
- # -----------------------------
70
- # NLU Classification Function (Dynamic Version)
71
- # -----------------------------
72
-
73
- def detect_tags_from_query(
74
- query: str,
75
- nlu_vectorstore: FAISS,
76
- behavior_options: list,
77
- emotion_options: list,
78
- topic_options: list,
79
- context_options: list,
80
- settings: dict = None
81
- ) -> Dict[str, Any]:
82
- """Uses a dynamic two-step NLU process: Route -> Retrieve Examples -> Classify."""
83
-
84
- # --- STEP 1: Route the query to determine the primary goal ---
85
- router_prompt = NLU_ROUTER_PROMPT.format(query=query)
86
- primary_goal_raw = call_llm([{"role": "user", "content": router_prompt}], temperature=0.0).strip().lower()
87
-
88
- # --- FIX START: Use separate variables for the filter (lowercase) and the prompt (Title Case) ---
89
- goal_for_filter = "practical_planning" if "practical" in primary_goal_raw else "emotional_support"
90
- goal_for_prompt = "Practical Planning" if "practical" in primary_goal_raw else "Emotional Support"
91
- # --- FIX END ---
92
-
93
- if settings and settings.get("debug_mode"):
94
- print(f"\n--- NLU Router ---\nGoal: {goal_for_prompt} (Filter: '{goal_for_filter}')\n------------------\n")
95
-
96
- # --- STEP 2: Retrieve relevant examples from the NLU vector store ---
97
- retriever = nlu_vectorstore.as_retriever(
98
- search_kwargs={"k": 2, "filter": {"primary_goal": goal_for_filter}} # <-- Use the correct lowercase filter
99
- )
100
- retrieved_docs = retriever.invoke(query)
101
-
102
- # Format the retrieved examples for the prompt
103
- selected_examples = "\n".join(
104
- f"User Query: \"{doc.page_content}\"\n{json.dumps(doc.metadata['classification'], indent=4)}"
105
- for doc in retrieved_docs
106
- )
107
- if not selected_examples:
108
- selected_examples = "(No relevant examples found)"
109
- if settings and settings.get("debug_mode"):
110
- print("WARNING: NLU retriever found no examples for this query.")
111
-
112
-
113
- # --- STEP 3: Use the Specialist Classifier with retrieved examples ---
114
- behavior_str = ", ".join(f'"{opt}"' for opt in behavior_options if opt != "None")
115
- emotion_str = ", ".join(f'"{opt}"' for opt in emotion_options if opt != "None")
116
- topic_str = ", ".join(f'"{opt}"' for opt in topic_options if opt != "None")
117
- context_str = ", ".join(f'"{opt}"' for opt in context_options if opt != "None")
118
-
119
- prompt = SPECIALIST_CLASSIFIER_PROMPT.format(
120
- primary_goal=goal_for_prompt, # Use Title Case for the prompt text
121
- examples=selected_examples,
122
- behavior_options=behavior_str,
123
- emotion_options=emotion_str,
124
- topic_options=topic_str,
125
- context_options=context_str,
126
- query=query
127
- )
128
-
129
- messages = [{"role": "system", "content": "You are a helpful NLU classification assistant."}, {"role": "user", "content": prompt}]
130
- response_str = call_llm(messages, temperature=0.1)
131
-
132
- if settings and settings.get("debug_mode"):
133
- print(f"\n--- NLU Specialist Full Response ---\n{response_str}\n----------------------------------\n")
134
-
135
- # --- STEP 4: Parse the final result ---
136
- result_dict = {"detected_behaviors": [], "detected_emotion": "None", "detected_topic": "None", "detected_contexts": []}
137
- try:
138
- start_brace = response_str.find('{')
139
- end_brace = response_str.rfind('}')
140
- if start_brace != -1 and end_brace > start_brace:
141
- json_str = response_str[start_brace : end_brace + 1]
142
- result = json.loads(json_str)
143
-
144
- behaviors = result.get("detected_behaviors")
145
- if behaviors: # This checks for both None and empty list
146
- result_dict["detected_behaviors"] = [b for b in behaviors if b in behavior_options]
147
-
148
- # Fix bug to properly handle null values from the LLM and will no longer raise the TypeError.
149
- # Use `or` to safely handle None, empty strings, etc.
150
- result_dict["detected_emotion"] = result.get("detected_emotion") or "None"
151
- result_dict["detected_topic"] = result.get("detected_topic") or "None"
152
-
153
- contexts = result.get("detected_contexts")
154
- if contexts: # This checks for both None and empty list
155
- result_dict["detected_contexts"] = [c for c in contexts if c in context_options]
156
-
157
- # Buggy code that can't handle a NULL case from LLM.
158
- # result_dict["detected_behaviors"] = [b for b in result.get("detected_behaviors", []) if b in behavior_options]
159
- # result_dict["detected_emotion"] = result.get("detected_emotion", "None")
160
- # result_dict["detected_topic"] = result.get("detected_topic", "None")
161
- # result_dict["detected_contexts"] = [c for c in result.get("detected_contexts", []) if c in context_options]
162
-
163
- return result_dict
164
- except (json.JSONDecodeError, AttributeError) as e:
165
- print(f"ERROR parsing NLU Specialist JSON: {e}")
166
- return result_dict
167
-
168
- # -----------------------------
169
- # Embeddings & VectorStore
170
- # -----------------------------
171
-
172
- def _default_embeddings():
173
- model_name = os.getenv("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
174
- return HuggingFaceEmbeddings(model_name=model_name)
175
-
176
- def build_or_load_vectorstore(docs: List[Document], index_path: str, is_personal: bool = False) -> FAISS:
177
- os.makedirs(os.path.dirname(index_path), exist_ok=True)
178
- if os.path.isdir(index_path) and os.path.exists(os.path.join(index_path, "index.faiss")):
179
- try:
180
- return FAISS.load_local(index_path, _default_embeddings(), allow_dangerous_deserialization=True)
181
- except Exception: pass
182
- if is_personal and not docs:
183
- docs = [Document(page_content="(This is the start of the personal memory journal.)", metadata={"source": "placeholder"})]
184
- vs = FAISS.from_documents(docs, _default_embeddings())
185
- vs.save_local(index_path)
186
- return vs
187
-
188
- def texts_from_jsonl(path: str) -> List[Document]:
189
- out: List[Document] = []
190
- try:
191
- with open(path, "r", encoding="utf-8") as f:
192
- for i, line in enumerate(f):
193
- obj = json.loads(line.strip())
194
- txt = obj.get("text") or ""
195
- if not txt.strip(): continue
196
- md = {"source": os.path.basename(path), "chunk": i}
197
- for k in ("behaviors", "emotion", "topic_tags", "context_tags"):
198
- if k in obj and obj[k]: md[k] = obj[k]
199
- out.append(Document(page_content=txt, metadata=md))
200
- except Exception: return []
201
- return out
202
-
203
- def bootstrap_vectorstore(sample_paths: List[str] | None = None, index_path: str = "data/faiss_index") -> FAISS:
204
- docs: List[Document] = []
205
- for p in (sample_paths or []):
206
- try:
207
- if p.lower().endswith(".jsonl"):
208
- docs.extend(texts_from_jsonl(p))
209
- else:
210
- with open(p, "r", encoding="utf-8", errors="ignore") as fh:
211
- docs.append(Document(page_content=fh.read(), metadata={"source": os.path.basename(p)}))
212
- except Exception: continue
213
- if not docs:
214
- docs = [Document(page_content="(empty index)", metadata={"source": "placeholder"})]
215
- return build_or_load_vectorstore(docs, index_path=index_path)
216
-
217
- # -----------------------------
218
- # LLM Call
219
- # -----------------------------
220
- def call_llm(messages: List[Dict[str, str]], temperature: float = 0.6, stop: Optional[List[str]] = None) -> str:
221
- client = _openai_client()
222
- model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
223
- if not client:
224
- return "(Offline Mode: OpenAI API key not configured.)"
225
- try:
226
- api_args = {"model": model, "messages": messages, "temperature": float(temperature if temperature is not None else 0.6)}
227
- if stop: api_args["stop"] = stop
228
- resp = client.chat.completions.create(**api_args)
229
- return (resp.choices[0].message.content or "").strip()
230
- except Exception as e:
231
- return f"[LLM API Error: {e}]"
232
-
233
- # -----------------------------
234
- # Prompting & RAG Chain
235
- # -----------------------------
236
- def make_rag_chain(
237
- vs_general: FAISS,
238
- vs_personal: FAISS,
239
- *,
240
- role: str = "patient",
241
- temperature: float = 0.6,
242
- language: str = "English",
243
- patient_name: str = "the patient",
244
- caregiver_name: str = "the caregiver",
245
- tone: str = "warm",
246
- ):
247
- """Returns a callable that performs the complete, intelligent RAG process."""
248
-
249
- def _format_docs(docs: List[Document], default_msg: str) -> str:
250
- if not docs: return default_msg
251
- unique_docs = {doc.page_content: doc for doc in docs}.values()
252
- return "\n".join([f"- {d.page_content.strip()}" for d in unique_docs])
253
-
254
- def _answer_fn(query: str, chat_history: List[Dict[str, str]], scenario_tag: Optional[str] = None, emotion_tag: Optional[str] = None, topic_tag: Optional[str] = None, context_tags: Optional[List[str]] = None) -> Dict[str, Any]:
255
-
256
- router_messages = [{"role": "user", "content": ROUTER_PROMPT.format(query=query)}]
257
- query_type = call_llm(router_messages, temperature=0.0).strip().lower()
258
- print(f"Query classified as: {query_type}")
259
-
260
- system_message = SYSTEM_TEMPLATE.format(tone=tone, language=language, patient_name=patient_name or "the patient", caregiver_name=caregiver_name or "the caregiver", guardrails=SAFETY_GUARDRAILS)
261
- messages = [{"role": "system", "content": system_message}]
262
- messages.extend(chat_history)
263
-
264
- if "general_knowledge_question" in query_type:
265
- user_prompt = ANSWER_TEMPLATE_GENERAL_KNOWLEDGE.format(question=query, language=language)
266
- messages.append({"role": "user", "content": user_prompt})
267
- answer = call_llm(messages, temperature=temperature)
268
- return {"answer": answer, "sources": ["General Knowledge"]}
269
-
270
- elif "factual_question" in query_type:
271
- print(f"Performing query expansion for: '{query}'")
272
- expansion_prompt = QUERY_EXPANSION_PROMPT.format(question=query)
273
- expansion_response = call_llm([{"role": "user", "content": expansion_prompt}], temperature=0.1)
274
- try:
275
- clean_response = expansion_response.strip().replace("```json", "").replace("```", "")
276
- expanded_queries = json.loads(clean_response)
277
- search_queries = [query] + expanded_queries
278
- except json.JSONDecodeError:
279
- search_queries = [query]
280
- print(f"Searching with queries: {search_queries}")
281
- all_docs = []
282
- for q in search_queries:
283
- all_docs.extend(vs_personal.similarity_search(q, k=2))
284
- all_docs.extend(vs_general.similarity_search(q, k=2))
285
- context = _format_docs(all_docs, "(No relevant information found in the memory journal.)")
286
- user_prompt = ANSWER_TEMPLATE_FACTUAL.format(context=context, question=query, language=language)
287
- messages.append({"role": "user", "content": user_prompt})
288
- answer = call_llm(messages, temperature=temperature)
289
- sources = list(set(d.metadata.get("source", "unknown") for d in all_docs))
290
- return {"answer": answer, "sources": sources}
291
-
292
- elif "general_conversation" in query_type:
293
- user_prompt = ANSWER_TEMPLATE_GENERAL.format(question=query, language=language)
294
- messages.append({"role": "user", "content": user_prompt})
295
- answer = call_llm(messages, temperature=temperature)
296
- return {"answer": answer, "sources": []}
297
-
298
- else: # Default to the original caregiving logic
299
- # --- Reworked search strategy to handle filters correctly ---
300
-
301
- # 1. Start with a general, unfiltered search to always get text-based matches.
302
- personal_docs = vs_personal.similarity_search(query, k=3)
303
- general_docs = vs_general.similarity_search(query, k=3)
304
-
305
- # 2. Build a filter for simple equality checks (FAISS supported).
306
- simple_search_filter = {}
307
- if scenario_tag and scenario_tag != "None":
308
- simple_search_filter["behaviors"] = scenario_tag.lower()
309
- if emotion_tag and emotion_tag != "None":
310
- simple_search_filter["emotion"] = emotion_tag.lower()
311
- if topic_tag and topic_tag != "None":
312
- simple_search_filter["topic_tags"] = topic_tag.lower()
313
-
314
- # 3. If simple filters exist, perform a second, more specific search.
315
- if simple_search_filter:
316
- print(f"Performing additional search with filter: {simple_search_filter}")
317
- personal_docs.extend(vs_personal.similarity_search(query, k=2, filter=simple_search_filter))
318
- general_docs.extend(vs_general.similarity_search(query, k=2, filter=simple_search_filter))
319
-
320
- # 4. If context_tags exist (unsupported by 'in'), loop through them and perform separate searches.
321
- if context_tags:
322
- print(f"Performing looped context tag search for: {context_tags}")
323
- for tag in context_tags:
324
- context_filter = {"context_tags": tag.lower()}
325
- personal_docs.extend(vs_personal.similarity_search(query, k=1, filter=context_filter))
326
- general_docs.extend(vs_general.similarity_search(query, k=1, filter=context_filter))
327
-
328
- # 5. Combine and de-duplicate all results.
329
- all_docs_care = list({doc.page_content: doc for doc in personal_docs + general_docs}.values())
330
-
331
- # --- End of reworked search strategy ---
332
-
333
- personal_context = _format_docs([d for d in all_docs_care if d in personal_docs], "(No relevant personal memories found.)")
334
- general_context = _format_docs([d for d in all_docs_care if d in general_docs], "(No general guidance found.)")
335
-
336
- first_emotion = None
337
- for doc in all_docs_care:
338
- if "emotion" in doc.metadata and doc.metadata["emotion"]:
339
- emotion_data = doc.metadata["emotion"]
340
- if isinstance(emotion_data, list): first_emotion = emotion_data[0]
341
- else: first_emotion = emotion_data
342
- if first_emotion: break
343
-
344
- emotions_context = render_emotion_guidelines(first_emotion or emotion_tag)
345
- is_tagged_scenario = (scenario_tag and scenario_tag != "None") or (emotion_tag and emotion_tag != "None") or (first_emotion is not None)
346
- template = ANSWER_TEMPLATE_ADQ if is_tagged_scenario else ANSWER_TEMPLATE_CALM
347
-
348
- if template == ANSWER_TEMPLATE_ADQ:
349
- user_prompt = template.format(general_context=general_context, personal_context=personal_context, question=query, scenario_tag=scenario_tag, emotions_context=emotions_context, role=role, language=language)
350
- else:
351
- combined_context = f"General Guidance:\n{general_context}\n\nPersonal Memories:\n{personal_context}"
352
- user_prompt = template.format(context=combined_context, question=query, language=language)
353
-
354
- messages.append({"role": "user", "content": user_prompt})
355
- answer = call_llm(messages, temperature=temperature)
356
-
357
- high_risk_scenarios = ["exit_seeking", "wandering", "elopement"]
358
- if scenario_tag and scenario_tag.lower() in high_risk_scenarios:
359
- answer += f"\n\n---\n{RISK_FOOTER}"
360
-
361
- sources = list(set(d.metadata.get("source", "unknown") for d in all_docs_care))
362
- return {"answer": answer, "sources": sources}
363
-
364
- return _answer_fn
365
-
366
-
367
- def answer_query(chain, question: str, **kwargs) -> Dict[str, Any]:
368
- if not callable(chain): return {"answer": "[Error: RAG chain is not callable]", "sources": []}
369
- try:
370
- return chain(question, **kwargs)
371
- except Exception as e:
372
- print(f"ERROR in answer_query: {e}")
373
- return {"answer": f"[Error executing chain: {e}]", "sources": []}
374
-
375
- # -----------------------------
376
- # TTS & Transcription
377
- # -----------------------------
378
- def synthesize_tts(text: str, lang: str = "en"):
379
- if not text or gTTS is None: return None
380
- try:
381
- with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as fp:
382
- tts = gTTS(text=text, lang=(lang or "en"))
383
- tts.save(fp.name)
384
- return fp.name
385
- except Exception:
386
- return None
387
-
388
- def transcribe_audio(filepath: str, lang: str = "en"):
389
- client = _openai_client()
390
- if not client: return "[Transcription failed: API key not configured]"
391
- api_args = {"model": "whisper-1"}
392
- if lang and lang != "auto": api_args["language"] = lang
393
- with open(filepath, "rb") as audio_file:
394
- transcription = client.audio.transcriptions.create(file=audio_file, **api_args)
395
- return transcription.text
396
-