|
|
|
|
|
import os |
|
|
import json |
|
|
import logging |
|
|
import re |
|
|
from typing import Dict, Any |
|
|
from pathlib import Path |
|
|
from unstructured.partition.pdf import partition_pdf |
|
|
from flask import Flask, request, jsonify |
|
|
from flask_cors import CORS |
|
|
from dotenv import load_dotenv |
|
|
from bloatectomy import bloatectomy |
|
|
from werkzeug.utils import secure_filename |
|
|
from langchain_groq import ChatGroq |
|
|
from typing_extensions import TypedDict, NotRequired |
|
|
|
|
|
|
|
|
# Configure logging once at import time: INFO level, with a timestamp and the
# level name in front of every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)

# Named logger used by the helpers and request handlers in this module.
logger = logging.getLogger("patient-assistant")
|
|
|
|
|
|
|
|
# Pull variables from a local .env file (if present) into the environment
# before the API key is read.
load_dotenv()

GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY not set in environment")
    # The bare `exit()` helper is injected by the `site` module and is not
    # guaranteed to exist (e.g. `python -S` or embedded interpreters), so
    # terminate by raising SystemExit directly.
    raise SystemExit(1)
|
|
|
|
|
|
|
|
# Filesystem layout: paths are resolved relative to this source file unless
# REPORTS_ROOT is overridden through the environment.
BASE_DIR = Path(__file__).resolve().parent
static_folder = BASE_DIR / "static"
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))

# Flask application; static assets are exposed under /static and CORS is
# enabled with the flask_cors defaults.
app = Flask(__name__, static_url_path="/static", static_folder=str(static_folder))
CORS(app)

# Make sure the reports directory exists before any request is handled.
REPORTS_ROOT.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
# Shared Groq chat client. The model name can be overridden with LLM_MODEL;
# temperature is fixed at 0.0 and replies are capped at 1024 tokens.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    max_tokens=1024,
    temperature=0.0,
)
|
|
|
|
|
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """Run *text* through bloatectomy and return its tokens joined by newlines.

    Args:
        text: Raw note text to clean.
        style: Value forwarded unchanged as bloatectomy's ``style`` argument.

    Returns:
        The newline-joined ``tokens`` attribute of the bloatectomy result, or
        *text* unchanged when that attribute is missing/empty or any step
        raises.
    """
    try:
        result = bloatectomy(text, style=style, output="html")
        cleaned_tokens = getattr(result, "tokens", None)
        return "\n".join(cleaned_tokens) if cleaned_tokens else text
    except Exception:
        # Cleaning is best-effort: record the failure and keep the raw text.
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
|
|
|
|
|
|
|
|
# System prompt sent with every chat request. It instructs the model to ask
# for the patient ID first and to answer with a single JSON object containing
# assistant_reply, patientDetails and (optionally) conversationSummary.
PATIENT_ASSISTANT_PROMPT = """
You are a patient assistant helping to analyze medical records and reports. Your primary task is to get the patient ID (PID) from the user at the start of the conversation.

Once you have the PID, you will be provided with a summary of the patient's medical reports. Use this information, along with the conversation history, to provide a comprehensive response.

Your tasks include:
- **First, ask for the patient ID.** Do not proceed with any other task until you have the PID.
- Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
- Suggesting preventive care based on the overall patient health history.
- Optimizing healthcare costs by comparing past visits and treatments.
- Offering personalized lifestyle recommendations.
- Generating a natural, helpful reply to the user.

STRICT OUTPUT FORMAT (JSON ONLY):
Return a single JSON object with the following keys:
- assistant_reply: string // a natural language reply to the user (short, helpful, always present)
- patientDetails: object // keys may include name, problem, pid (patient ID), city, contact (update if user shared info)
- conversationSummary: string (optional) // short summary of conversation + relevant patient docs

Rules:
- ALWAYS include `assistant_reply` as a non-empty string.
- Do NOT produce any text outside the JSON object.
- Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
- Do not make up information that is not present in the provided medical reports or conversation history.
"""
|
|
|
|
|
|
|
|
def extract_json_from_llm_response(raw_response: str) -> dict:
    """Extract and validate the assistant's JSON object from raw LLM output.

    The text may be wrapped in a Markdown code fence or surrounded by prose.
    The span from the first ``{`` to the last ``}`` is parsed strictly first;
    only if that fails are trailing commas before ``}``/``]`` removed and the
    parse retried, so commas inside valid string values are never altered.

    Args:
        raw_response: Text returned by the model.

    Returns:
        A dict with a non-empty string ``assistant_reply``, a dict
        ``patientDetails`` and a string ``conversationSummary``. When no
        valid object can be recovered, a default apology payload with those
        same keys is returned instead.
    """
    default = {
        "assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
        "patientDetails": {},
        "conversationSummary": "",
    }

    if not raw_response or not isinstance(raw_response, str):
        return default

    # Prefer the contents of a fenced block when the model emitted one.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    json_string = fence.group(1).strip() if fence else raw_response

    first = json_string.find("{")
    last = json_string.rfind("}")
    if first == -1 or last < first:
        # Without braces the text can only decode to a scalar or a list,
        # never to the object callers rely on.
        logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
        return default

    candidate = json_string[first:last + 1]
    try:
        parsed = json.loads(candidate)
    except (ValueError, RecursionError):
        # Strict parsing failed: repair trailing commas and retry once.
        repaired = re.sub(r",\s*(?=[}\]])", "", candidate)
        try:
            parsed = json.loads(repaired)
        except (ValueError, RecursionError) as e:
            logger.warning("Failed to parse JSON from LLM output: %s", e)
            return default

    reply = parsed.get("assistant_reply") if isinstance(parsed, dict) else None
    if not isinstance(reply, str) or not reply.strip():
        logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
        return default

    # Normalise the optional keys so callers can rely on their types even
    # when the model returned null or a value of the wrong type.
    if not isinstance(parsed.get("patientDetails"), dict):
        parsed["patientDetails"] = {}
    if not isinstance(parsed.get("conversationSummary"), str):
        parsed["conversationSummary"] = ""
    return parsed
|
|
|
|
|
|
|
|
@app.route("/", methods=["GET"])
def serve_frontend():
    """Serve the frontend page from the static folder.

    Returns:
        The ``frontend_p.html`` static file, or a short HTML hint with
        status 404 when sending it fails.
    """
    try:
        return app.send_static_file("frontend_p.html")
    except Exception:
        # The hint names the file that is actually served above.
        return "<h3>frontend_p.html not found in static/ — please add your frontend_p.html there.</h3>", 404
|
|
|
|
|
@app.route("/upload_report", methods=["POST"])
def upload_report():
    """Store an uploaded report file in the folder of the given patient.

    Expects a multipart form with a ``report`` file part and a ``patient_id``
    field. The file is saved under ``REPORTS_ROOT / "p_<patient_id>"``, where
    both the patient ID and the file name are first passed through
    ``secure_filename``.

    Returns:
        A JSON ``message`` with status 200 on success, or a JSON ``error``
        with status 400 when the file part, file name or patient ID is
        missing or reduces to an empty name after sanitising.
    """
    if "report" not in request.files:
        return jsonify({"error": "No file part in the request"}), 400

    file = request.files["report"]
    patient_id = request.form.get("patient_id")
    if not file.filename or not patient_id:
        return jsonify({"error": "No selected file or patient ID"}), 400

    # Both values are client-controlled and end up in a filesystem path, so
    # strip path separators and ".." segments before using them.
    safe_patient_id = secure_filename(patient_id)
    if not safe_patient_id:
        return jsonify({"error": "Invalid patient ID"}), 400
    filename = secure_filename(file.filename)
    if not filename:
        return jsonify({"error": "Invalid file name"}), 400

    patient_folder = REPORTS_ROOT / f"p_{safe_patient_id}"
    os.makedirs(patient_folder, exist_ok=True)
    file.save(patient_folder / filename)
    return jsonify(
        {"message": f"File '{filename}' uploaded successfully for patient ID '{safe_patient_id}'."}
    ), 200
|
|
|
|
|
def _last_user_message(chat_history) -> str:
    """Return the content of the most recent non-empty user message, or ""."""
    if not isinstance(chat_history, list):
        return ""
    for msg in reversed(chat_history):
        if isinstance(msg, dict) and msg.get("role") == "user" and msg.get("content"):
            return str(msg["content"])
    return ""


def _load_patient_documents(patient_id) -> list:
    """Return the cleaned text of each readable report stored for *patient_id*."""
    # The ID comes from the client; sanitise it before it becomes part of a
    # filesystem path so it cannot point outside REPORTS_ROOT.
    safe_id = secure_filename(str(patient_id))
    if not safe_id:
        return []

    patient_folder = REPORTS_ROOT / f"p_{safe_id}"
    if not patient_folder.is_dir():
        return []

    parts = []
    for fname in sorted(os.listdir(patient_folder)):
        file_path = patient_folder / fname
        page_text = ""
        if partition_pdf is not None and str(file_path).lower().endswith(".pdf"):
            try:
                elements = partition_pdf(filename=str(file_path))
                page_text = "\n".join(
                    el.text for el in elements if hasattr(el, "text") and el.text
                )
            except Exception:
                logger.exception("Failed to parse PDF %s", file_path)
        else:
            try:
                page_text = file_path.read_text(encoding="utf-8", errors="ignore")
            except OSError:
                # Unreadable entries (sub-directories, permission errors) are
                # skipped, but no longer silently.
                logger.warning("Could not read report file %s", file_path)

        if page_text:
            cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
            if cleaned:
                parts.append(cleaned)
    return parts


@app.route("/chat", methods=["POST"])
def chat():
    """Handle one chat turn with the assistant.

    Expects a JSON body with optional ``chat_history`` (list of
    ``{"role", "content"}`` messages) and ``patient_state`` (dict that may
    hold ``patientDetails`` and ``conversationSummary``). When no patient ID
    is known, the first run of digits in the last user message is taken as
    the ID. Stored reports for that ID are parsed, cleaned and appended to
    the conversation summary, and the resulting state is sent to the LLM.

    Returns:
        JSON with ``assistant_reply`` and ``updated_state`` (the validated
        LLM output); a JSON ``error`` with status 400 for a non-object body
        or 500 when the LLM call fails.
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        return jsonify({"error": "invalid request body"}), 400

    patient_state = data.get("patient_state")
    state = patient_state.copy() if isinstance(patient_state, dict) else {}

    # Work on a copy of the details so the request payload is not mutated and
    # a null / wrongly-typed value from the client cannot crash the handler.
    details = state.get("patientDetails")
    details = dict(details) if isinstance(details, dict) else {}
    state["patientDetails"] = details
    state["lastUserMessage"] = _last_user_message(data.get("chat_history"))

    patient_id = details.get("pid")
    if not patient_id:
        # Heuristic: any digit run in the last user message is treated as the
        # patient ID. Merge it into the existing details instead of replacing
        # them, so previously collected fields are kept.
        pid_match = re.search(r"\d+", state["lastUserMessage"])
        if pid_match:
            patient_id = pid_match.group(0)
            details["pid"] = patient_id

    combined_text_parts = _load_patient_documents(patient_id) if patient_id else []

    base_summary = state.get("conversationSummary") or ""
    if not isinstance(base_summary, str):
        base_summary = ""
    docs_summary = "\n\n".join(combined_text_parts)
    if docs_summary:
        state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
    else:
        state["conversationSummary"] = base_summary

    user_prompt = (
        "\n"
        f"Current patientDetails: {json.dumps(details)}\n"
        f"Current conversationSummary: {state['conversationSummary']}\n"
        f"Last user message: {state['lastUserMessage']}\n"
        "\n"
        "Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.\n"
    )

    messages = [
        {"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    try:
        logger.info("Invoking LLM with prepared state and prompt...")
        llm_response = llm.invoke(messages)
        if hasattr(llm_response, "content"):
            raw_response = llm_response.content
        else:
            raw_response = str(llm_response)
        logger.info("Raw LLM response: %s", raw_response)
        parsed_result = extract_json_from_llm_response(raw_response)
    except Exception as e:
        logger.exception("LLM invocation failed")
        return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500

    # Guard against a non-dict parse result so the `.get` below cannot fail.
    updated_state = parsed_result if isinstance(parsed_result, dict) else {}

    assistant_reply = updated_state.get("assistant_reply")
    if not isinstance(assistant_reply, str) or not assistant_reply.strip():
        assistant_reply = "I'm here to help — could you tell me more about your symptoms?"

    return jsonify(
        {
            "assistant_reply": assistant_reply,
            "updated_state": updated_state,
        }
    )
|
|
|
|
|
@app.route("/ping", methods=["GET"])
def ping():
    """Health-check endpoint that always answers with status ``ok``."""
    health = {"status": "ok"}
    return jsonify(health)
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: listen on all interfaces on PORT (default 5000).
    port = int(os.getenv("PORT", "5000"))
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must not be enabled by default while binding to 0.0.0.0.
    # Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=port, debug=debug)