KeenWoo committed
Commit adda93b · verified · 1 Parent(s): 5408fc0

Update evaluate.py

Files changed (1)
  1. evaluate.py +81 -124
evaluate.py CHANGED
@@ -131,52 +131,63 @@ def _classify_error(gt: str, gen: str) -> str:
         return "omission"
     return "contradiction"
 
+## NEW
+# In evaluate.py
 
 def run_comprehensive_evaluation(
-    vs_general: FAISS,
-    vs_personal: FAISS,
-    nlu_vectorstore: FAISS,
-    config: Dict[str, Any]
+    vs_general: "Chroma",
+    nlu_vectorstore: "Chroma",
+    config: Dict[str, Any],
+    storage_path: Path
 ):
     global test_fixtures
     if not test_fixtures:
-        return "No test fixtures loaded. Please ensure conversation_test_fixtures_v10.jsonl exists.", [], []
+        # The return signature is now back to 3 items.
+        return "No test fixtures loaded.", [], []
+
+    vs_personal_test = None
+    personal_context_docs = []
+    personal_context_file = "sample_data/1 Complaints of a Dutiful Daughter.txt"
+
+    if os.path.exists(personal_context_file):
+        print(f"Found personal context file for evaluation: '{personal_context_file}'")
+        with open(personal_context_file, "r", encoding="utf-8") as f:
+            content = f.read()
+        doc = Document(page_content=content, metadata={"source": os.path.basename(personal_context_file)})
+        personal_context_docs.append(doc)
+    else:
+        print(f"WARNING: Personal context file not found at '{personal_context_file}'. Factual tests will likely fail.")
 
+    vs_personal_test = build_or_load_vectorstore(
+        personal_context_docs,
+        index_path="tmp/eval_personal_index",
+        is_personal=True
+    )
+    print(f"Successfully created temporary personal vectorstore with {len(personal_context_docs)} document(s) for this evaluation run.")
+
     def _norm(label: str) -> str:
         label = (label or "").strip().lower()
         return "factual_question" if "factual" in label else label
 
     print("Starting comprehensive evaluation...")
     results: List[Dict[str, Any]] = []
-
-    # ADD THESE LINES:
     total_fixtures = len(test_fixtures)
     print(f"\nπŸš€ STARTING EVALUATION on {total_fixtures} test cases...")
 
-    # In evaluate.py, before the evaluation loop
-    print("--- DEBUG: Checking personal vector store before evaluation ---")
-    if vs_personal and hasattr(vs_personal.docstore, '_dict'):
-        print(f"Personal vector store contains {len(vs_personal.docstore._dict)} documents.")
-    else:
-        print("Personal vector store appears to be empty or invalid.")
-
-    # REPLACE the original for loop with this one to get the counter 'i'
     for i, fx in enumerate(test_fixtures):
-    # for fx in test_fixtures:
         test_id = fx.get("test_id", "N/A")
-        # This print statement now works because we have 'i'
         print(f"--- Processing Test Case {i+1}/{total_fixtures}: ID = {test_id} ---")
 
-
         turns = fx.get("turns") or []
         api_chat_history = [{"role": t.get("role"), "content": t.get("text")} for t in turns]
         query = next((t["content"] for t in reversed(api_chat_history) if (t.get("role") or "user").lower() == "user"), "")
         if not query: continue
 
+        print(f'Query: "{query}"')
+
         ground_truth = fx.get("ground_truth", {})
         expected_route = _norm(ground_truth.get("expected_route", "caregiving_scenario"))
         expected_tags = ground_truth.get("expected_tags", {})
-
         actual_route = _norm(route_query_type(query))
         route_correct = (actual_route == expected_route)
 
@@ -203,25 +214,31 @@ def run_comprehensive_evaluation(
         }
 
         current_test_role = fx.get("test_role", "patient")
-        rag_chain = make_rag_chain(vs_general, vs_personal, role=current_test_role)
-
+        rag_chain = make_rag_chain(
+            vs_general, vs_personal_test, nlu_vectorstore=nlu_vectorstore,
+            config=config, role=current_test_role, for_evaluation=True
+        )
+
         t0 = time.time()
         response = answer_query(rag_chain, query, query_type=actual_route, chat_history=api_chat_history, **final_tags)
         latency_ms = round((time.time() - t0) * 1000.0, 1)
         answer_text = response.get("answer", "ERROR")
+        ground_truth_answer = ground_truth.get("ground_truth_answer")
+
+        category = _categorize_test(test_id)
+        error_class = _classify_error(ground_truth_answer, answer_text)
 
         expected_sources_set = set(map(str, ground_truth.get("expected_sources", [])))
         raw_sources = response.get("sources", [])
         actual_sources_set = set(map(str, raw_sources if isinstance(raw_sources, (list, tuple)) else [raw_sources]))
 
-        # --- START: ADD THIS STRATEGIC PRINT BLOCK ---
         print("\n" + "-"*20 + " SOURCE EVALUATION " + "-"*20)
         print(f" - Expected: {sorted(list(expected_sources_set))}")
         print(f" - Actual: {sorted(list(actual_sources_set))}")
 
         true_positives = expected_sources_set.intersection(actual_sources_set)
         false_positives = actual_sources_set - expected_sources_set
-        false_negatives = expected_sources_set - actual_sources_set
+        false_negatives = expected_sources_set - actual_sources_set
 
         if not false_positives and not false_negatives:
             print(" - Result: βœ… Perfect Match!")
@@ -231,35 +248,34 @@ def run_comprehensive_evaluation(
         if false_negatives:
             print(f" - πŸ”» False Negatives (hurts recall): {sorted(list(false_negatives))}")
         print("-"*59 + "\n")
-        # --- END: ADD THIS STRATEGIC PRINT BLOCK ---
-
+
         context_precision, context_recall = 0.0, 0.0
         if expected_sources_set or actual_sources_set:
-            true_positives = len(expected_sources_set.intersection(actual_sources_set))
-            if len(actual_sources_set) > 0: context_precision = true_positives / len(actual_sources_set)
-            if len(expected_sources_set) > 0: context_recall = true_positives / len(expected_sources_set)
+            tp = len(expected_sources_set.intersection(actual_sources_set))
+            if len(actual_sources_set) > 0: context_precision = tp / len(actual_sources_set)
+            if len(expected_sources_set) > 0: context_recall = tp / len(expected_sources_set)
         elif not expected_sources_set and not actual_sources_set:
             context_precision, context_recall = 1.0, 1.0
 
-        answer_correctness_score = None
-        ground_truth_answer = ground_truth.get("ground_truth_answer")
-        error_class = None  # initialise #NEW
+        print("\n" + "-"*20 + " ANSWER & CORRECTNESS EVALUATION " + "-"*20)
+        print(f" - Ground Truth Answer: {ground_truth_answer}")
+        print(f" - Generated Answer: {answer_text}")
+        print("-" * 59)
 
+        answer_correctness_score = None
         if ground_truth_answer and "ERROR" not in answer_text:
             try:
                 judge_msg = ANSWER_CORRECTNESS_JUDGE_PROMPT.format(ground_truth_answer=ground_truth_answer, generated_answer=answer_text)
+                print(f" - Judge Prompt Sent:\n{judge_msg}")
                 raw_correctness = call_llm([{"role": "user", "content": judge_msg}], temperature=0.0)
+                print(f" - Judge Raw Response: {raw_correctness}")
                 correctness_data = _parse_judge_json(raw_correctness)
-
                 if correctness_data and "correctness_score" in correctness_data:
                     answer_correctness_score = float(correctness_data["correctness_score"])
-
+                    print(f" - Final Score: {answer_correctness_score}")
             except Exception as e:
                 print(f"ERROR during answer correctness judging: {e}")
 
-        # --- NEW: derive error class for diagnostics ---
-        error_class = _classify_error(ground_truth_answer, answer_text)
-
         faithfulness = None
         source_docs = response.get("source_documents", [])
         if source_docs and "ERROR" not in answer_text:
@@ -279,9 +295,6 @@ def run_comprehensive_evaluation(
         sources_pretty = ", ".join(sorted(s)) if (s:=actual_sources_set) else ""
         results.append({
             "test_id": fx.get("test_id", "N/A"), "title": fx.get("title", "N/A"),
-            # NEW for debugging
-            "category": _categorize_test(test_id), "error_class": error_class,
-            # END
             "route_correct": "βœ…" if route_correct else "❌", "expected_route": expected_route, "actual_route": actual_route,
             "behavior_f1": f"{behavior_metrics['f1_score']:.2f}", "emotion_f1": f"{emotion_metrics['f1_score']:.2f}",
             "topic_f1": f"{topic_metrics['f1_score']:.2f}", "context_f1": f"{context_metrics['f1_score']:.2f}",
@@ -289,110 +302,54 @@ def run_comprehensive_evaluation(
             "latency_ms": latency_ms, "faithfulness": faithfulness,
             "context_precision": context_precision, "context_recall": context_recall,
             "answer_correctness": answer_correctness_score,
+            "category": category,
+            "error_class": error_class
         })
 
     df = pd.DataFrame(results)
-    output_path = "evaluation_results.csv"
+    summary_text, table_rows, headers = "No valid test fixtures found to evaluate.", [], []
+
     if not df.empty:
-        cols = [
-            "test_id", "title", "route_correct", "expected_route", "actual_route",
-            "context_precision", "context_recall", "faithfulness", "answer_correctness",
-            "behavior_f1", "emotion_f1", "topic_f1", "context_f1",
-            "source_count", "latency_ms", "sources", "generated_answer"
-        ]
+        cols = ["test_id", "title", "route_correct", "expected_route", "actual_route", "context_precision", "context_recall", "faithfulness", "answer_correctness", "behavior_f1", "emotion_f1", "topic_f1", "context_f1", "source_count", "latency_ms", "sources", "generated_answer", "category", "error_class"]
         df = df[[c for c in cols if c in df.columns]]
+        output_path = "evaluation_results.csv"
         df.to_csv(output_path, index=False, encoding="utf-8")
         print(f"Evaluation results saved to {output_path}")
 
-
-        # --- NEW: write detailed results to a log file instead of CSV ---
-        log_path = Path(__file__).parent / "evaluation_log.txt"
-        with open(log_path, "a", encoding="utf-8") as logf:
-            logf.write("\n===== Detailed Evaluation Run =====\n")
-            logf.write(df.to_string(index=False))
+        log_path = storage_path / "evaluation_log.txt"
+        with open(log_path, "w", encoding="utf-8") as logf:
+            logf.write("===== Detailed Evaluation Run =====\n")
+            df_string = df.to_string(index=False)
+            logf.write(df_string)
             logf.write("\n\n")
 
-        # --- NEW: per-category averages ---
-        try:
-            cat_means = df.groupby("category")["answer_correctness"].mean().reset_index()
-            print("\nπŸ“Š Correctness by Category:")
-            print(cat_means.to_string(index=False))
-            with open("evaluation_log.txt", "a", encoding="utf-8") as logf:
+            try:
+                cat_means = df.groupby("category")["answer_correctness"].mean().reset_index()
+                print("\nπŸ“Š Correctness by Category:")
+                print(cat_means.to_string(index=False))
                 logf.write("\nπŸ“Š Correctness by Category:\n")
                 logf.write(cat_means.to_string(index=False))
                 logf.write("\n")
-        except Exception as e:
-            print(f"WARNING: Could not compute category breakdown: {e}")
-
-        # --- NEW: confusion-style matrix ---
-        try:
-            confusion = pd.crosstab(df.get("category", []), df.get("error_class", []),
-                                    rownames=["Category"], colnames=["Error Class"], dropna=False)
-            print("\nπŸ“Š Error Class Distribution by Category:")
-            print(confusion.to_string())
-            with open("evaluation_log.txt", "a", encoding="utf-8") as logf:
+            except Exception as e:
+                print(f"WARNING: Could not compute category breakdown: {e}")
+
+            try:
+                confusion = pd.crosstab(df["category"], df["error_class"], rownames=["Category"], colnames=["Error Class"], dropna=False)
+                print("\nπŸ“Š Error Class Distribution by Category:")
+                print(confusion.to_string())
                 logf.write("\nπŸ“Š Error Class Distribution by Category:\n")
                 logf.write(confusion.to_string())
                 logf.write("\n")
-        except Exception as e:
-            print(f"WARNING: Could not build confusion matrix: {e}")
-
-
-        # NEW: save detailed results
-        df.to_csv("evaluation_results_detailed.csv", index=False, encoding="utf-8")
-
-        # NEW: per-category averages
-        try:
-            cat_means = df.groupby("category")["answer_correctness"].mean().reset_index()
-            print("\nπŸ“Š Correctness by Category:")
-            print(cat_means.to_string(index=False))
-            cat_means.to_csv("evaluation_correctness_by_category.csv", index=False)
-        except Exception as e:
-            print(f"WARNING: Could not compute category breakdown: {e}")
-
-        # NEW: confusion-style matrix
-        try:
-            confusion = pd.crosstab(df.get("category", []), df.get("error_class", []),
-                                    rownames=["Category"], colnames=["Error Class"], dropna=False)
-            print("\nπŸ“Š Error Class Distribution by Category:")
-            print(confusion.to_string())
-            confusion.to_csv("evaluation_confusion_matrix.csv")
-        except Exception as e:
-            print(f"WARNING: Could not build confusion matrix: {e}")
-
+            except Exception as e:
+                print(f"WARNING: Could not build confusion matrix: {e}")
 
         pct = df["route_correct"].value_counts(normalize=True).get("βœ…", 0) * 100
         to_f = lambda s: pd.to_numeric(s, errors="coerce")
-
-        cp_mean = to_f(df["context_precision"]).mean()
-        cr_mean = to_f(df["context_recall"]).mean()
-        faith_mean = to_f(df["faithfulness"]).mean()
-        correct_mean = to_f(df["answer_correctness"]).mean()
-        rag_with_sources_pct = (df["source_count"] > 0).mean() * 100 if "source_count" in df else 0
-
-        summary_text = f"""
-## Evaluation Summary
-- **Routing Accuracy**: {pct:.2f}%
-- **Behaviour F1 (avg)**: {(to_f(df["behavior_f1"]).mean() * 100):.2f}%
-- **Emotion F1 (avg)**: {(to_f(df["emotion_f1"]).mean() * 100):.2f}%
-- **Topic F1 (avg)**: {(to_f(df["topic_f1"]).mean() * 100):.2f}%
-- **Context F1 (avg)**: {(to_f(df["context_f1"]).mean() * 100):.2f}%
-- **RAG: Context Precision**: {"N/A" if pd.isna(cp_mean) else f'{(cp_mean * 100):.1f}%'}
-- **RAG: Context Recall**: {"N/A" if pd.isna(cr_mean) else f'{(cr_mean * 100):.1f}%'}
-- **RAG: Faithfulness (LLM-judge)**: {"N/A" if pd.isna(faith_mean) else f'{(faith_mean * 100):.1f}%'}
-- **RAG: Answer Correctness (LLM-judge)**: {"N/A" if pd.isna(correct_mean) else f'{(correct_mean * 100):.1f}%'}
-- **RAG Answers w/ Sources**: {rag_with_sources_pct:.1f}%
-- **RAG: Avg Latency (ms)**: {to_f(df["latency_ms"]).mean():.1f}
-"""
-        df_display = df.rename(columns={
-            "context_precision": "Ctx. Precision", "context_recall": "Ctx. Recall",
-            "answer_correctness": "Answer Correct.", "faithfulness": "Faithfulness",
-            "behavior_f1": "Behav. F1", "emotion_f1": "Emo. F1", "topic_f1": "Topic F1", "context_f1": "Ctx. F1"
-        })
+        summary_text = f"""## Evaluation Summary\n- **Routing Accuracy**: {pct:.2f}%\n- **RAG: Context Precision**: {(to_f(df["context_precision"]).mean() * 100):.1f}%\n- **RAG: Context Recall**: {(to_f(df["context_recall"]).mean() * 100):.1f}%\n- **RAG: Answer Correctness (LLM-judge)**: {(to_f(df["answer_correctness"]).mean() * 100):.1f}%"""
+        df_display = df.rename(columns={"context_precision": "Ctx. Precision", "context_recall": "Ctx. Recall"})
         table_rows = df_display.values.tolist()
        headers = df_display.columns.tolist()
-    else:
-        summary_text = "No valid test fixtures found to evaluate."
-        table_rows, headers = [], []
-
+
     return summary_text, table_rows, headers
+
+## END
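
For reference, a minimal, self-contained sketch of the set-based context precision/recall arithmetic the updated loop performs for each test case. The source names below are invented placeholders; only the math mirrors the diff.

# Hypothetical expected vs. retrieved source names, for illustration only.
expected_sources = {"care_guide.pdf", "medication_notes.txt"}
actual_sources = {"care_guide.pdf", "unrelated_doc.txt"}

context_precision, context_recall = 0.0, 0.0
if expected_sources or actual_sources:
    tp = len(expected_sources & actual_sources)       # true positives
    if actual_sources:
        context_precision = tp / len(actual_sources)  # lowered by false positives
    if expected_sources:
        context_recall = tp / len(expected_sources)   # lowered by false negatives
else:
    # Nothing expected and nothing retrieved counts as a perfect match.
    context_precision, context_recall = 1.0, 1.0

print(f"precision={context_precision:.2f}, recall={context_recall:.2f}")  # -> 0.50, 0.50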
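
Likewise, a small sketch of the per-category reporting this commit adds (a groupby mean plus a confusion-style crosstab of error classes). The rows are fabricated; only the pandas calls mirror the diff.

import pandas as pd

# Fabricated result rows with the diagnostic columns the loop now records.
rows = [
    {"test_id": "T1", "category": "factual", "error_class": "correct", "answer_correctness": 1.0},
    {"test_id": "T2", "category": "factual", "error_class": "omission", "answer_correctness": 0.4},
    {"test_id": "T3", "category": "scenario", "error_class": "contradiction", "answer_correctness": 0.2},
]
df = pd.DataFrame(rows)

# Mean answer correctness per category (written to the log in the diff).
cat_means = df.groupby("category")["answer_correctness"].mean().reset_index()
print(cat_means.to_string(index=False))

# Confusion-style matrix: how error classes distribute across categories.
confusion = pd.crosstab(df["category"], df["error_class"],
                        rownames=["Category"], colnames=["Error Class"], dropna=False)
print(confusion.to_string())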