# NOTE: three non-code artifact lines from the original paste ("Spaces:",
# "Runtime error" x2) were removed here; they were not part of the program.
# --- Imports -----------------------------------------------------------------
import os
import json
import uuid
import numpy as np
from datetime import datetime
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from werkzeug.utils import secure_filename
import google.generativeai as genai
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import faiss
import markdown

# --- Configuration -----------------------------------------------------------
# SECURITY: a real-looking API key was hard-coded in source control. It is now
# read from the GEMINI_API_KEY environment variable; the literal value is kept
# only as a backward-compatible fallback and should be rotated/replaced.
GEMINI_API_KEY = os.environ.get(
    "GEMINI_API_KEY",
    "AIzaSyBbb8rH6ksakMg_v2W6hvUNzgHDI3lxWk0",  # Replace with your actual API key
)
genai.configure(api_key=GEMINI_API_KEY)

# --- Flask app ---------------------------------------------------------------
# Static assets are served from the sibling ../frontend directory at the root URL.
app = Flask(__name__, static_folder="../frontend", static_url_path="")
CORS(app)  # allow the frontend to call this API from another origin
# --- RAG model initialization (runs once at import time) ---------------------
# NOTE(review): the startup log messages contained mojibake ("π", etc.) from a
# bad encoding round-trip; restored to plain ASCII text.
print("Initializing RAG System...")

# Load the medical guidelines corpus used as the retrieval knowledge base.
print("Loading dataset...")
dataset = load_dataset("epfl-llm/guidelines", split="train")
TITLE_COL = "title"  # dataset column holding each guideline's title
CONTENT_COL = "clean_text"  # dataset column holding the cleaned guideline body

# Sentence embedder for retrieval plus an extractive QA model used by
# summarize_report to pull structured fields out of free-text reports.
print("Loading AI models...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline(
    "question-answering", model="distilbert-base-cased-distilled-squad"
)

# Build the FAISS index over combined title+content embeddings (see embed_text).
print("Building FAISS index...")
def embed_text(batch):
    """Embed one dataset batch: each guideline's title plus its first 200 chars."""
    snippets = []
    for heading, body in zip(batch[TITLE_COL], batch[CONTENT_COL]):
        snippets.append(f"{heading} {body[:200]}")
    return {"embeddings": embedder.encode(snippets, show_progress_bar=False)}
# Compute embeddings for every guideline in batches of 32, then index the new
# "embeddings" column with FAISS for nearest-neighbour retrieval (rag_retrieval).
dataset = dataset.map(embed_text, batched=True, batch_size=32)
dataset.add_faiss_index(column="embeddings")
# Processing Functions
def format_response(text):
    """Render a Markdown string as HTML so the frontend can display it."""
    html = markdown.markdown(text)
    return html
def summarize_report(report):
    """Build a one-line clinical summary of a free-text report.

    Extractive QA pulls age/gender/symptoms/history from the report, then
    Gemini rewrites those fields into a short narrative. Returns HTML.
    """
    fields = {}
    for key, question in (
        ("age", "Patient's age?"),
        ("gender", "Patient's gender?"),
        ("symptoms", "Current symptoms?"),
        ("history", "Medical history?"),
    ):
        extracted = qa_pipeline(question=question, context=report)
        # Low-confidence extractions are reported as unknown rather than guessed.
        if extracted["score"] > 0.1:
            fields[key] = extracted["answer"]
        else:
            fields[key] = "Not specified"
    gemini = genai.GenerativeModel("gemini-1.5-flash")
    prompt = f"""Create clinical summary from:
- Age: {fields["age"]}
- Gender: {fields["gender"]}
- Symptoms: {fields["symptoms"]}
- History: {fields["history"]}
Format: "[Age] [Gender] with [History], presenting with [Symptoms]"
Add relevant medical context."""
    summary = gemini.generate_content(prompt).text.strip()
    print(f"Generated Summary: {summary}")  # Debugging log
    return format_response(summary)
def rag_retrieval(query, k=3):
    """Return the top-k guideline snippets most similar to *query*.

    Each hit is a dict with the guideline title, a 1000-char content preview,
    its source (or "N/A" when the dataset has no source column), and the
    FAISS similarity score as a float.
    """
    vector = embedder.encode([query])
    scores, examples = dataset.get_nearest_examples("embeddings", vector, k=k)
    titles = examples[TITLE_COL]
    contents = examples[CONTENT_COL]
    sources = examples.get("source", ["N/A"] * len(titles))
    hits = []
    for idx, (heading, body, distance) in enumerate(zip(titles, contents, scores)):
        hits.append(
            {
                "title": heading,
                "content": body[:1000],
                "source": sources[idx],
                "score": float(distance),
            }
        )
    return hits
def generate_recommendations(report):
    """Generate treatment recommendations grounded in retrieved guidelines.

    Retrieves the most relevant guidelines for *report*, feeds them to Gemini
    as context, and returns a tuple of (HTML-formatted recommendations, list
    of guideline sources that were actually available).
    """
    guidelines = rag_retrieval(report)
    # The bullet character in the original source was mojibake ("β’");
    # restored to a plain "•" so the prompt sent to the model reads correctly.
    context = "Relevant Clinical Guidelines:\n" + "\n".join(
        [f"• {g['title']}: {g['content']} [Source: {g['source']}]" for g in guidelines]
    )
    model = genai.GenerativeModel("gemini-1.5-flash")
    prompt = f"""Generate treatment recommendations using these guidelines:
{context}
Patient Presentation:
{report}
Format with:
- Bold section headers
- Clear bullet points
- Evidence markers [Guideline #]
- Risk-benefit analysis
- Include references to the sources provided where applicable
"""
    recommendations = model.generate_content(prompt).text.strip()
    # Only surface references that exist in the dataset (skip the "N/A" filler).
    references = [g["source"] for g in guidelines if g["source"] != "N/A"]
    return format_response(recommendations), references
def generate_risk_assessment(summary):
    """Ask Gemini for a structured clinical risk assessment of *summary*.

    Returns HTML containing a 0-100 risk score, a traffic-light alert level,
    key risk factors, and recommended actions.
    """
    model = genai.GenerativeModel("gemini-1.5-flash")
    # The alert-level markers were mojibake ("π΄"/"π‘"/"π’") in the original
    # source; restored to the intended traffic-light emoji.
    prompt = f"""Analyze clinical risk:
{summary}
Output format:
Risk Score: 0-100
Alert Level: 🔴 High/🟡 Medium/🟢 Low
Key Risk Factors: bullet points
Recommended Actions: bullet points"""
    return format_response(model.generate_content(prompt).text.strip())
# --- Flask Endpoints ---------------------------------------------------------
# NOTE(review): the original file defined these view functions with no
# @app.route decorators anywhere, so Flask never registered them and every
# request would 404. A route is added below; confirm the URL path matches the
# frontend's fetch calls.
@app.route("/upload", methods=["POST"])
def handle_upload():
    """Process an uploaded .txt clinical report and return the full analysis.

    Expects a multipart form with a "file" field containing UTF-8 text.
    Returns JSON with a session id, timestamp, summary, recommendations,
    risk assessment, and guideline references; or a JSON error with a
    400/500 status on invalid input or processing failure.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    file = request.files["file"]
    if not file or not file.filename.endswith(".txt"):
        return jsonify({"error": "Invalid file, must be a .txt file"}), 400
    try:
        content = file.read().decode("utf-8")
        if not content.strip():
            return jsonify({"error": "File is empty"}), 400
        summary = summarize_report(content)
        recommendations, references = generate_recommendations(content)
        # Risk assessment is derived from the summary, not the raw report.
        risk_assessment = generate_risk_assessment(summary)
        response = {
            "session_id": str(uuid.uuid4()),
            "timestamp": datetime.now().isoformat(),
            "summary": summary,
            "recommendations": recommendations,
            "risk_assessment": risk_assessment,
            "references": references,
        }
        print(
            f"Response Sent to Frontend: {json.dumps(response, indent=2)}"
        )  # Debugging log
        return jsonify(response)
    except Exception as e:
        # Broad catch is deliberate: any model/decoding failure becomes a
        # JSON 500 instead of an HTML traceback page.
        return jsonify({"error": f"Processing failed: {str(e)}"}), 500
@app.route("/")  # route was missing in the original; the page was unreachable
def serve_index():
    """Serve the frontend's index.html from the static folder."""
    return send_from_directory(app.static_folder, "index.html")
@app.route("/<path:path>")  # route was missing in the original
def serve_static(path):
    """Serve any other static asset from the frontend directory."""
    return send_from_directory(app.static_folder, path)
| if __name__ == "__main__": | |
| app.run(host="0.0.0.0", port=5000, debug=True) | |