|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import json |
|
|
import logging |
|
|
import re |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any |
|
|
|
|
|
from flask import Flask, request, jsonify |
|
|
from flask_cors import CORS |
|
|
from dotenv import load_dotenv |
|
|
from unstructured.partition.pdf import partition_pdf |
|
|
|
|
|
|
|
|
from bloatectomy import bloatectomy |
|
|
|
|
|
|
|
|
from langchain_groq import ChatGroq |
|
|
from langgraph.prebuilt import create_react_agent |
|
|
|
|
|
|
|
|
from langgraph.graph import StateGraph, START, END |
|
|
from typing_extensions import TypedDict, NotRequired |
|
|
|
|
|
|
|
|
# Configure logging once at import time; the service logs through `logger`.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("health-agent")

# Load variables from a .env file (if present) before the os.getenv() lookups below.
load_dotenv()

from pathlib import Path  # redundant re-import kept as-is; Path is already imported above

# Filesystem locations, each overridable through an environment variable and
# resolved to an absolute path once at import time.
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", "reports")).resolve()
SSRI_FILE = Path(os.getenv("SSRI_FILE", "app/medicationCategories/SSRI_list.txt")).resolve()
MISC_FILE = Path(os.getenv("MISC_FILE", "app/medicationCategories/MISC_list.txt")).resolve()

GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
if not GROQ_API_KEY:
    # Surface the misconfiguration early instead of leaving it to an opaque
    # failure when the LLM client is first used.
    logger.warning("GROQ_API_KEY is not set; LLM calls are expected to fail")
|
|
|
|
|
|
|
|
# Chat model shared by both agents; the model id can be overridden via LLM_MODEL.
llm = ChatGroq(
    temperature=0.0,
    max_tokens=None,
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
)

# System prompt given to the general-purpose extraction agent.
NODE_BASE_INSTRUCTIONS = """
You are HealthAI — a clinical assistant producing JSON for downstream processing.
Produce only valid JSON (no extra text). Follow field types exactly. If missing data, return empty strings or empty arrays.
Be conservative: do not assert diagnoses; provide suggestions and ask physician confirmation where needed.
"""

# System prompt given to the agent that repairs malformed JSON replies.
_JSON_RESOLVER_INSTRUCTIONS = """
You are a JSON fixer. Input: a possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
Fix missing quotes, trailing commas, unescaped newlines, stray assistant labels, and ensure schema compliance.
"""

# Neither agent is given tools; they only differ in their system prompt.
agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
agent_json_resolver = create_react_agent(
    model=llm,
    tools=[],
    prompt=_JSON_RESOLVER_INSTRUCTIONS,
)
|
|
|
|
|
|
|
|
def extract_json_from_llm_response(raw_response: str) -> dict:
    """Parse the JSON payload embedded in an LLM reply.

    The payload is taken from the first fenced code block when one exists
    (otherwise from the whole reply) and narrowed to the outermost
    ``{ ... }`` span. If that text is already valid JSON it is returned
    unchanged; only when parsing fails are the best-effort textual repairs
    applied and parsing retried.

    Args:
        raw_response: Raw text produced by the model.

    Returns:
        The decoded JSON value (normally a dict).

    Raises:
        Exception: Whatever error prevented extraction (typically
            ``json.JSONDecodeError``); it is logged and re-raised.
    """
    try:
        fenced = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
        json_string = fenced.group(1).strip() if fenced else raw_response

        # Keep only the outermost object so surrounding chatter is dropped.
        first, last = json_string.find('{'), json_string.rfind('}')
        if 0 <= first < last:
            json_string = json_string[first:last+1]

        # Fast path: the repairs below are regex heuristics that can damage
        # string values of well-formed JSON, so never run them on valid input.
        try:
            return json.loads(json_string)
        except ValueError:
            pass  # fall through to the repair heuristics

        # Strip labels glued in front of an object and stray "assistant" keys.
        json_string = re.sub(r'\b\w+\s*{', '{', json_string)
        json_string = re.sub(r'"assistant"\s*:', '', json_string)
        # Remove a quote left dangling after a bare boolean (e.g. `true"`).
        json_string = re.sub(r'\b(false|true)"', r'\1', json_string)

        def _esc(m):
            # Escape quotes inside the free-text "logic" value.
            prefix, body = m.group(1), m.group(2)
            return prefix + body.replace('"', r'\"')

        json_string = re.sub(
            r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
            _esc,
            json_string,
        )

        # Drop trailing commas and collapse doubled commas.
        json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
        json_string = re.sub(r',\s*,', ',', json_string)

        # Remove surplus closing braces one at a time, skipping any whitespace
        # between them, instead of slicing off a fixed number of characters.
        surplus = json_string.count('}') - json_string.count('{')
        while surplus > 0:
            trimmed = json_string.rstrip()
            if not trimmed.endswith('}'):
                break
            json_string = trimmed[:-1]
            surplus -= 1

        def _escape_newlines_in_strings(s: str) -> str:
            # Raw line breaks inside string literals are invalid JSON.
            return re.sub(
                r'"((?:[^"\\]|\\.)*?)"',
                lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
                s,
                flags=re.DOTALL,
            )

        json_string = _escape_newlines_in_strings(json_string)

        return json.loads(json_string)
    except Exception as e:
        logger.error(f"Failed to extract JSON from LLM response: {e}")
        raise
|
|
|
|
|
|
|
|
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """Run bloatectomy over *text* and return its tokens joined by newlines.

    Falls back to the unmodified *text* when bloatectomy exposes no tokens
    or raises; a failure is logged with its traceback.
    """
    try:
        cleaner = bloatectomy(text, style=style, output="html")
        pieces = getattr(cleaner, "tokens", None)
        return "\n".join(pieces) if pieces else text
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
|
|
|
|
|
|
|
|
def readDrugs_from_file(path: Path):
    """Load a pipe-delimited drug list.

    Each usable line has the form ``generic|name|name...``. The text before
    the first ``|`` is the generic name and the whole lower-cased line is
    stored as that generic's pattern. Lines without a ``|`` or with an empty
    generic are skipped, so every generic is always paired with its own line.

    Args:
        path: Location of the list file.

    Returns:
        ``(listing, generics)`` where ``listing`` maps generic -> pattern line
        and ``generics`` lists the generics in file order. ``({}, [])`` when
        the file is missing or cannot be read.
    """
    try:
        if not path.exists():
            return {}, []
        txt = path.read_text(encoding="utf-8", errors="ignore")
        listing = {}
        generics = []
        for raw_line in txt.splitlines():
            entry = raw_line.strip().lower()
            generic, sep, _rest = entry.partition("|")
            generic = generic.strip()
            if not sep or not generic:
                # Comment, blank, or malformed line: no generic to key on.
                continue
            generics.append(generic)
            listing[generic] = entry
        return listing, generics
    except Exception:
        logger.exception(f"Failed to read drugs from file: {path}")
        return {}, []
|
|
|
|
|
def addToDrugs_line(line: str, drugs_flags: List[int], listing: Dict[str,str], genList: List[str]) -> List[int]:
    """Flag every generic whose pattern occurs in *line*.

    Each value in *listing* is used as a case-insensitive regular expression.
    Only non-empty matches count, so a pattern containing an empty
    alternative (for example a trailing ``|``) does not flag every line.

    Args:
        line: One line of report text.
        drugs_flags: Flag list, updated in place; positions follow *genList*.
        listing: Mapping of generic name -> regex pattern line.
        genList: Generic names in flag order.

    Returns:
        The same *drugs_flags* list, also when an unexpected error is logged.
    """
    try:
        gen_index = {g: i for i, g in enumerate(genList)}
        for generic, pattern_line in listing.items():
            try:
                matched = any(
                    hit.group(0) for hit in re.finditer(pattern_line, line, re.I)
                )
            except re.error:
                # Pattern line is not a valid regex; skip this entry.
                continue
            if not matched:
                continue
            idx = gen_index.get(generic)
            if idx is not None:
                drugs_flags[idx] = 1
        return drugs_flags
    except Exception:
        logger.exception("Error in addToDrugs_line")
        return drugs_flags
|
|
|
|
|
def extract_medications_from_text(text: str) -> List[str]:
    """Collect medication mentions from free text.

    Three sources are combined, line by line:
      * generics from the SSRI and MISC list files whose pattern matches,
      * text following a label such as ``Rx:`` or ``Medication -``,
      * capitalised words followed by a dose such as ``100 mg`` or ``100mg``.

    Args:
        text: Report text, possibly spanning several documents.

    Returns:
        Sorted list of distinct mentions; ``[]`` if extraction fails (the
        failure is logged).
    """
    try:
        ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
        misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
        combined_map = {**ssri_map, **misc_map}
        combined_generics = [*ssri_generics, *misc_generics]

        flags = [0] * len(combined_generics)
        meds_found = set()

        # Compiled once instead of on every line.
        labelled = re.compile(
            r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)",
            re.I,
        )
        dosed = re.compile(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)")
        # A candidate counts only if it ends in "<digit><unit>". Anchoring on
        # the digit (not a word boundary) accepts "100mg" as well as "100 mg".
        has_dose = re.compile(r"[0-9]\s*(?:mg|mcg|g|IU)$")

        for ln in text.splitlines():
            ln = ln.strip()
            if not ln:
                continue
            if combined_map:
                flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
            m = labelled.search(ln)
            if m:
                name = m.group(2).strip()
                if name:  # the capture may be whitespace only
                    meds_found.add(name)
            for s in dosed.findall(ln):
                s = s.strip()
                if has_dose.search(s):
                    meds_found.add(s)

        for i, f in enumerate(flags):
            if f == 1:
                meds_found.add(combined_generics[i])

        # Sorted so the same input always yields the same output order.
        return sorted(meds_found)
    except Exception:
        logger.exception("Failed to extract medications from text")
        return []
|
|
|
|
|
|
|
|
# Per-node task prompts. Each one is sent together with a payload and tells the
# model which JSON shape to return.

# Extracts the patient's identifying fields.
PATIENT_NODE_PROMPT = """
You will extract patientDetails from the provided document texts.
Return ONLY JSON with this exact shape:
{ "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
Fill fields using text evidence or leave empty strings.
"""

# Extracts the referring doctor.
DOCTOR_NODE_PROMPT = """
You will extract doctorDetails found in the documents.
Return ONLY JSON with this exact shape:
{ "doctorDetails": {"referredBy": ""} }
"""

# Extracts one entry per test, listing only abnormal findings.
TEST_REPORT_NODE_PROMPT = """
You will extract per-test structured results from the documents.
Return ONLY JSON with this exact shape:
{
"reports": [
{
"testName": "",
"dateReported": "",
"timeReported": "",
"abnormalFindings": [
{"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
],
"interpretation": "",
"trends": []
}
]
}
- Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
- For result numeric parsing, prefer numeric values; if not numeric, keep original string.
- Use statuses: Low, High, Borderline, Positive, Negative, Normal.
"""

# Produces the overall analysis. The example object must itself be valid JSON
# (a stray quote before "risk_prediction" previously made it malformed).
ANALYSIS_NODE_PROMPT = """
You will create an overallAnalysis based on the extracted reports (the agent will give you the 'reports' JSON).
Return ONLY JSON:
{ "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "","risk_prediction": "","drug_interaction": "" } }
Be conservative, evidence-based, and suggest follow-up steps for physicians.
"""

# Asks the model to validate the assembled result without modifying it.
CONDITION_LOOP_NODE_PROMPT = """
Validation and condition node:
Input: partial JSON (patientDetails, doctorDetails, reports, overallAnalysis).
Task: Check required keys exist and that each report has at least testName and abnormalFindings list.
Return ONLY JSON:
{ "valid": true, "missing": [] }
If missing fields, list keys in 'missing'. Do NOT modify content.
"""
|
|
|
|
|
|
|
|
def _agent_reply_text(resp):
    """Return the reply content carried by an agent result.

    Handles a plain string, an object with a ``content`` attribute, and a
    dict with a ``messages`` list (the last message is used). Anything else
    is converted with ``str``; a dict without messages is JSON-dumped.
    """
    if isinstance(resp, str):
        return resp
    if hasattr(resp, "content"):
        return resp.content
    if isinstance(resp, dict):
        msgs = resp.get("messages")
        if not msgs:
            return json.dumps(resp)
        last_msg = msgs[-1]
        if isinstance(last_msg, str):
            return last_msg
        if hasattr(last_msg, "content"):
            return last_msg.content
        if isinstance(last_msg, dict):
            return last_msg.get("content", "")
        return str(last_msg)
    return str(resp)


def call_node_agent(node_prompt: str, payload: dict) -> dict:
    """Run the generic agent on a node prompt plus payload and parse its JSON.

    If the reply cannot be parsed, the JSON resolver agent is asked once to
    repair it. If the agent produced no reply at all (for example the
    invocation itself raised), the resolver is skipped because there is
    nothing to repair.

    Args:
        node_prompt: Task instructions for this node.
        payload: JSON-serialisable data the node should work on.

    Returns:
        The parsed JSON, or ``{}`` when neither attempt yields valid JSON.
    """
    # Bound before the try so the handler can tell "no reply" from "bad reply".
    raw = None
    try:
        content = {
            "prompt": node_prompt,
            "payload": payload,
        }
        resp = agent.invoke({"messages": [{"role": "user", "content": json.dumps(content)}]})
        raw = _agent_reply_text(resp)
        return extract_json_from_llm_response(raw)
    except Exception as e:
        if raw is None:
            logger.exception("Node agent produced no reply; skipping JSON resolver")
            return {}
        logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
        try:
            resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
            r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
            return extract_json_from_llm_response(_agent_reply_text(r))
        except Exception as e2:
            logger.exception("JSON resolver also failed: %s", e2)
            return {}
|
|
|
|
|
|
|
|
class State(TypedDict):
    # Shared state passed between the graph nodes.

    # Keys that must always be present.
    patient_id: str
    documents: List[Dict[str, Any]]
    medications: List[str]

    # Optional request metadata.
    patient_meta: NotRequired[Dict[str, Any]]

    # Optional keys written by the extraction and analysis nodes.
    patientDetails: NotRequired[Dict[str, Any]]
    doctorDetails: NotRequired[Dict[str, Any]]
    reports: NotRequired[List[Dict[str, Any]]]
    overallAnalysis: NotRequired[Dict[str, Any]]

    # Optional keys written by the validation node.
    valid: NotRequired[bool]
    missing: NotRequired[List[str]]
|
|
|
|
|
|
|
|
def patient_details_node(state: State) -> dict:
    """Ask the agent for patientDetails and return them as a state update.

    Returns ``{"patientDetails": {}}`` when the agent output is not a dict or
    lacks the key.
    """
    logger.info("Running patient_details_node")
    out = call_node_agent(
        PATIENT_NODE_PROMPT,
        {
            "patient_meta": state.get("patient_meta", {}),
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(out, dict):
        return {"patientDetails": {}}
    return {"patientDetails": out.get("patientDetails", {})}
|
|
|
|
|
def doctor_details_node(state: State) -> dict:
    """Ask the agent for doctorDetails and return them as a state update.

    Returns ``{"doctorDetails": {}}`` when the agent output is not a dict or
    lacks the key.
    """
    logger.info("Running doctor_details_node")
    out = call_node_agent(
        DOCTOR_NODE_PROMPT,
        {
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(out, dict):
        return {"doctorDetails": {}}
    return {"doctorDetails": out.get("doctorDetails", {})}
|
|
|
|
|
def test_report_node(state: State) -> dict:
    """Ask the agent for per-test reports and return them as a state update.

    Returns ``{"reports": []}`` when the agent output is not a dict or lacks
    the key.
    """
    logger.info("Running test_report_node")
    out = call_node_agent(
        TEST_REPORT_NODE_PROMPT,
        {
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(out, dict):
        return {"reports": []}
    return {"reports": out.get("reports", [])}
|
|
|
|
|
def analysis_node(state: State) -> dict:
    """Ask the agent for an overallAnalysis of the data gathered so far.

    The payload carries the already-extracted patient, doctor, report and
    medication data. Returns ``{"overallAnalysis": {}}`` when the agent output
    is not a dict or lacks the key.
    """
    logger.info("Running analysis_node")
    out = call_node_agent(
        ANALYSIS_NODE_PROMPT,
        {
            "patientDetails": state.get("patientDetails", {}),
            "doctorDetails": state.get("doctorDetails", {}),
            "reports": state.get("reports", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(out, dict):
        return {"overallAnalysis": {}}
    return {"overallAnalysis": out.get("overallAnalysis", {})}
|
|
|
|
|
def condition_loop_node(state: State) -> dict:
    """Validate the assembled result and report which keys are missing.

    The agent's verdict is used when it returns a ``valid`` field. A string
    verdict such as ``"false"`` is interpreted by its text rather than by
    truthiness, and ``missing`` is normalised to a list. When the agent gives
    no verdict, a local check on ``patientDetails`` and ``reports`` is used.

    Returns:
        ``{"valid": bool, "missing": list}``.
    """
    payload = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "overallAnalysis": state.get("overallAnalysis", {}),
    }
    logger.info("Running condition_loop_node (validation)")
    out = call_node_agent(CONDITION_LOOP_NODE_PROMPT, payload)

    if isinstance(out, dict) and "valid" in out:
        verdict = out.get("valid")
        if isinstance(verdict, str):
            # bool("false") is True, so read the text instead.
            verdict = verdict.strip().lower() in ("true", "yes", "1")
        reported = out.get("missing", [])
        if isinstance(reported, str):
            reported = [reported]
        elif not isinstance(reported, list):
            reported = []
        return {"valid": bool(verdict), "missing": reported}

    # No usable verdict from the agent: fall back to a local check.
    missing = []
    if not state.get("patientDetails"):
        missing.append("patientDetails")
    if not state.get("reports"):
        missing.append("reports")
    return {"valid": len(missing) == 0, "missing": missing}
|
|
|
|
|
|
|
|
# Assemble the linear pipeline:
# START -> patient_details -> doctor_details -> test_report -> analysis
#       -> condition_loop -> END
graph_builder = StateGraph(State)

for _node_name, _node_fn in (
    ("patient_details", patient_details_node),
    ("doctor_details", doctor_details_node),
    ("test_report", test_report_node),
    ("analysis", analysis_node),
    ("condition_loop", condition_loop_node),
):
    graph_builder.add_node(_node_name, _node_fn)

# Each stage is connected to the one that follows it.
_pipeline = (
    START,
    "patient_details",
    "doctor_details",
    "test_report",
    "analysis",
    "condition_loop",
    END,
)
for _src, _dst in zip(_pipeline, _pipeline[1:]):
    graph_builder.add_edge(_src, _dst)

graph = graph_builder.compile()
|
|
|
|
|
|
|
|
|
|
|
# Flask application; static assets are served from the "static" directory that
# sits next to this file, under the /static URL prefix.
BASE_DIR = Path(__file__).resolve().parents[0]
static_folder = BASE_DIR.joinpath("static")
app = Flask(
    __name__,
    static_url_path="/static",
    static_folder=str(static_folder),
)
# Enable CORS on the app with flask_cors defaults.
CORS(app)
|
|
|
|
|
|
|
|
@app.route("/", methods=["GET"])
def serve_frontend():
    """Serve static/frontend.html, or a short 404 page if it cannot be served."""
    try:
        page = app.send_static_file("frontend.html")
    except Exception as e:
        logger.error(f"Failed to serve frontend.html: {e}")
        return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404
    return page
|
|
|
|
|
def _resolve_inside(base: Path, name):
    """Resolve *name* beneath *base*; return None if the result escapes *base*.

    *base* must already be an absolute, resolved path. Absolute names and
    ``..`` components are rejected because the resolved candidate no longer
    has *base* among its parents.
    """
    try:
        candidate = (base / str(name)).resolve()
    except (OSError, ValueError):
        return None
    return candidate if base in candidate.parents else None


@app.route("/process_reports", methods=["POST"])
def process_reports():
    """Extract, analyse and validate a patient's PDF reports.

    Expects a JSON object with ``patient_id``, ``filenames`` (list of file
    names inside that patient's folder under REPORTS_ROOT) and optional
    ``patientDetails``. Each PDF is partitioned to text and cleaned,
    medications are pre-extracted, and the node graph is run. If validation
    fails, missing parts are filled with defaults and analysis plus validation
    are run once more.

    Returns:
        200 with the structured result; 400 for a malformed request or when
        no document could be loaded; 404 when the patient folder does not
        exist; 500 when the pipeline raises.
    """
    try:
        data = request.get_json(force=True)
    except Exception as e:
        logger.error(f"Failed to parse JSON request: {e}")
        return jsonify({"error": "Invalid JSON request"}), 400

    # A JSON array or scalar parses fine but has no .get().
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid JSON request"}), 400

    patient_id = data.get("patient_id")
    filenames = data.get("filenames", [])
    extra_patient_meta = data.get("patientDetails", {})

    if not patient_id or not filenames:
        return jsonify({"error": "missing patient_id or filenames"}), 400
    if not isinstance(filenames, list):
        return jsonify({"error": "filenames must be a list of file names"}), 400

    # patient_id and filenames come from the client: make sure they cannot
    # point outside REPORTS_ROOT (e.g. via ".." or an absolute path).
    patient_folder = _resolve_inside(REPORTS_ROOT, patient_id)
    if patient_folder is None:
        return jsonify({"error": "invalid patient_id"}), 400
    if not patient_folder.exists() or not patient_folder.is_dir():
        # Report the id, not the absolute server-side path.
        return jsonify({"error": f"patient folder not found: {patient_id}"}), 404

    documents = []
    combined_text_parts = []

    for fname in filenames:
        file_path = _resolve_inside(patient_folder, fname)
        if file_path is None:
            logger.warning("rejected file name outside patient folder: %r", fname)
            continue
        if not file_path.is_file():
            logger.warning("file not found: %s", file_path)
            continue
        try:
            elements = partition_pdf(filename=str(file_path))
            page_text = "\n".join([el.text for el in elements if hasattr(el, "text") and el.text])
        except Exception:
            logger.exception(f"Failed to parse PDF {file_path}")
            page_text = ""
        try:
            cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
        except Exception:
            logger.exception("Failed to clean notes with bloatectomy")
            cleaned = page_text
        documents.append({
            "filename": fname,
            "raw_text": page_text,
            "cleaned_text": cleaned,
        })
        combined_text_parts.append(cleaned)

    if not documents:
        return jsonify({"error": "no valid documents found"}), 400

    combined_text = "\n\n".join(combined_text_parts)
    try:
        meds = extract_medications_from_text(combined_text)
    except Exception:
        logger.exception("Failed to extract medications")
        meds = []

    initial_state = {
        "patient_meta": extra_patient_meta,
        "patient_id": patient_id,
        "documents": documents,
        "medications": meds,
    }

    try:
        result_state = graph.invoke(initial_state)

        if not result_state.get("valid", True):
            missing = result_state.get("missing", [])
            logger.info(f"Validation failed; missing keys: {missing}")
            if "patientDetails" in missing:
                result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
            if "reports" in missing:
                result_state["reports"] = []

            # NOTE(review): the original indentation was lost; the re-run
            # is assumed to belong to the failed-validation branch — confirm.
            result_state.update(analysis_node(result_state))
            cond = condition_loop_node(result_state)
            result_state.update(cond)

        safe_response = {
            "patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
            "doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
            "reports": result_state.get("reports", []),
            "overallAnalysis": result_state.get("overallAnalysis", {"summary": "", "recommendations": "", "longTermTrends": ""}),
            "_pre_extracted_medications": result_state.get("medications", []),
            "_validation": {
                "valid": result_state.get("valid", True),
                "missing": result_state.get("missing", []),
            },
        }
        return jsonify(safe_response), 200

    except Exception as e:
        logger.exception("Node pipeline failed")
        return jsonify({"error": "Node pipeline failed", "detail": str(e)}), 500
|
|
|
|
|
@app.route("/ping", methods=["GET"])
def ping():
    """Liveness probe: always responds with {"status": "ok"}."""
    return jsonify(status="ok")
|
|
|
|
|
if __name__ == "__main__":
    # Port comes from PORT (default 7860).
    port = int(os.getenv("PORT", 7860))
    # The server binds to every interface, so the interactive debugger must not
    # be on by default; enable it explicitly with FLASK_DEBUG=1 when needed.
    debug = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(host="0.0.0.0", port=port, debug=debug)
|
|
|
|
|
|