Spaces:
Sleeping
Sleeping
create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,646 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import sqlite3
|
| 3 |
+
import bcrypt
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import re
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
|
| 10 |
+
import os
|
| 11 |
+
import logging
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
from openai import OpenAI
|
| 14 |
+
load_dotenv() # Loads .env file
|
| 15 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 16 |
+
import json
|
| 17 |
+
from fpdf import FPDF
|
| 18 |
+
|
| 19 |
+
# --------------------------
|
| 20 |
+
# Environment Setup
|
| 21 |
+
# --------------------------
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 23 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
print("Using device:", device)
|
| 25 |
+
|
| 26 |
+
# --------------------------
|
| 27 |
+
# Global Tokenizer and Hybrid Model for Treatment Prediction
|
| 28 |
+
# --------------------------
|
| 29 |
+
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class HybridMentalHealthModel(nn.Module):
|
| 33 |
+
def __init__(self, bert_model, num_genders, num_medications, num_therapies, hidden_size=128):
|
| 34 |
+
super(HybridMentalHealthModel, self).__init__()
|
| 35 |
+
self.bert = AutoModel.from_pretrained(bert_model)
|
| 36 |
+
bert_output_size = self.bert.config.hidden_size
|
| 37 |
+
self.age_fc = nn.Linear(1, 16)
|
| 38 |
+
self.gender_fc = nn.Embedding(num_genders, 16)
|
| 39 |
+
self.fc = nn.Linear(bert_output_size + 32, hidden_size)
|
| 40 |
+
self.medication_head = nn.Linear(hidden_size, num_medications)
|
| 41 |
+
self.therapy_head = nn.Linear(hidden_size, num_therapies)
|
| 42 |
+
|
| 43 |
+
def forward(self, input_ids, attention_mask, age, gender):
|
| 44 |
+
bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
|
| 45 |
+
age_out = self.age_fc(age)
|
| 46 |
+
gender_out = self.gender_fc(gender)
|
| 47 |
+
combined = torch.cat((bert_output, age_out, gender_out), dim=1)
|
| 48 |
+
hidden = torch.relu(self.fc(combined))
|
| 49 |
+
return self.medication_head(hidden), self.therapy_head(hidden)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# --------------------------
|
| 53 |
+
# Global Label Mappings and Age Scaler
|
| 54 |
+
# --------------------------
|
| 55 |
+
medication_classes = ["Anxiolytics", "Benzodiazepines", "Antidepressants", "Mood Stabilizers", "Antipsychotics", "Stimulants"]
|
| 56 |
+
therapy_classes = ["Cognitive Behavioral Therapy", "Dialectical Behavioral Therapy", "Interpersonal Therapy", "Mindfulness-Based Therapy"] # Update with your types
|
| 57 |
+
gender_classes = ["Male", "Female", "Other"]
|
| 58 |
+
|
| 59 |
+
medication_encoder = {name: idx for idx, name in enumerate(medication_classes)}
|
| 60 |
+
inv_medication_encoder = {idx: name for name, idx in medication_encoder.items()}
|
| 61 |
+
therapy_encoder = {name: idx for idx, name in enumerate(therapy_classes)}
|
| 62 |
+
inv_therapy_encoder = {idx: name for name, idx in therapy_encoder.items()}
|
| 63 |
+
gender_encoder = {name: idx for idx, name in enumerate(gender_classes)}
|
| 64 |
+
|
| 65 |
+
mean_age = 50
|
| 66 |
+
std_age = 10
|
| 67 |
+
|
| 68 |
+
def scale_age(age):
|
| 69 |
+
return (age - mean_age) / std_age
|
| 70 |
+
|
| 71 |
+
# --------------------------
|
| 72 |
+
# Load the Hybrid Model (Treatment Prediction)
|
| 73 |
+
# --------------------------
|
| 74 |
+
num_genders = len(gender_classes)
|
| 75 |
+
num_medications = len(medication_classes)
|
| 76 |
+
num_therapies = len(therapy_classes)
|
| 77 |
+
MODEL_SAVE_PATH = "22.03.2025-16.02-ML128E10" # Update accordingly
|
| 78 |
+
|
| 79 |
+
model = HybridMentalHealthModel("emilyalsentzer/Bio_ClinicalBERT", num_genders, num_medications, num_therapies)
|
| 80 |
+
state_dict = torch.load(MODEL_SAVE_PATH, map_location=device)
|
| 81 |
+
if "gender_fc.weight" in state_dict:
|
| 82 |
+
del state_dict["gender_fc.weight"]
|
| 83 |
+
model.load_state_dict(state_dict, strict=False)
|
| 84 |
+
model.to(device)
|
| 85 |
+
model.eval()
|
| 86 |
+
|
| 87 |
+
# --------------------------
|
| 88 |
+
# Global Diagnosis Model (Mental Health Diagnosis)
|
| 89 |
+
# --------------------------
|
| 90 |
+
diagnosis_tokenizer = AutoTokenizer.from_pretrained("ethandavey/mental-health-diagnosis-bert") # Update with your model ID
|
| 91 |
+
diagnosis_model = AutoModelForSequenceClassification.from_pretrained("ethandavey/mental-health-diagnosis-bert") # Update with your model ID
|
| 92 |
+
diagnosis_model.to(device)
|
| 93 |
+
diagnosis_model.eval()
|
| 94 |
+
|
| 95 |
+
def predict_disease(text):
|
| 96 |
+
inputs = diagnosis_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
| 97 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 98 |
+
with torch.no_grad():
|
| 99 |
+
outputs = diagnosis_model(**inputs)
|
| 100 |
+
probabilities = F.softmax(outputs.logits, dim=1).squeeze()
|
| 101 |
+
label_mapping = {0: "Anxiety", 1: "Normal", 2: "Depression", 3: "Suicidal", 4: "Stress"}
|
| 102 |
+
|
| 103 |
+
topk = torch.topk(probabilities, k=3)
|
| 104 |
+
top_preds = [(label_mapping[i.item()], probabilities[i].item()) for i in topk.indices]
|
| 105 |
+
return top_preds
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def predict_med_therapy(symptoms, age, gender):
|
| 109 |
+
encoding = tokenizer(symptoms, return_tensors="pt", truncation=True, padding='max_length', max_length=128)
|
| 110 |
+
input_ids = encoding["input_ids"].to(device)
|
| 111 |
+
attention_mask = encoding["attention_mask"].to(device)
|
| 112 |
+
age_norm = torch.tensor([[scale_age(age)]], dtype=torch.float32).to(device)
|
| 113 |
+
gender_idx = gender_encoder.get(gender, 0)
|
| 114 |
+
gender_tensor = torch.tensor([gender_idx], dtype=torch.long).to(device)
|
| 115 |
+
with torch.no_grad():
|
| 116 |
+
med_logits, therapy_logits = model(input_ids, attention_mask, age_norm, gender_tensor)
|
| 117 |
+
med_probabilities = torch.softmax(med_logits, dim=1)
|
| 118 |
+
therapy_probabilities = torch.softmax(therapy_logits, dim=1)
|
| 119 |
+
med_pred = torch.argmax(med_probabilities, dim=1).item()
|
| 120 |
+
therapy_pred = torch.argmax(therapy_probabilities, dim=1).item()
|
| 121 |
+
med_confidence = med_probabilities[0][med_pred].item()
|
| 122 |
+
therapy_confidence = therapy_probabilities[0][therapy_pred].item()
|
| 123 |
+
predicted_med = inv_medication_encoder.get(med_pred, "Unknown")
|
| 124 |
+
predicted_therapy = inv_therapy_encoder.get(therapy_pred, "Unknown")
|
| 125 |
+
return (predicted_med, med_confidence), (predicted_therapy, therapy_confidence)
|
| 126 |
+
|
| 127 |
+
# --------------------------
|
| 128 |
+
# OpenAI Functions (Summarization and Explanation)
|
| 129 |
+
# --------------------------
|
| 130 |
+
def get_concise_rewrite(text, max_tokens, temperature=0.7):
|
| 131 |
+
messages = [
|
| 132 |
+
{"role": "system", "content": "You are an expert rewriting assistant. Rewrite the given statement into a concise version while preserving its tone and vocabulary."},
|
| 133 |
+
{"role": "user", "content": text}
|
| 134 |
+
]
|
| 135 |
+
try:
|
| 136 |
+
response = client.chat.completions.create(model="gpt-4o-mini", messages=messages, max_tokens=max_tokens, temperature=temperature)
|
| 137 |
+
concise_text = response.choices[0].message.content.strip()
|
| 138 |
+
except Exception as e:
|
| 139 |
+
concise_text = f"API call failed: {e}"
|
| 140 |
+
return concise_text
|
| 141 |
+
|
| 142 |
+
def get_explanation(patient_statement, predicted_diagnosis):
|
| 143 |
+
messages = [
|
| 144 |
+
{"role": "system", "content": "You are an expert mental health assistant. Provide a concise, evidence-based explanation of how the patient's statement supports the diagnosis."},
|
| 145 |
+
{"role": "user", "content": f"Patient statement: {patient_statement}\nPredicted diagnosis: {predicted_diagnosis}\nExplain briefly."}
|
| 146 |
+
]
|
| 147 |
+
try:
|
| 148 |
+
response = client.chat.completions.create(model="gpt-4o-mini", messages=messages, max_tokens=256)
|
| 149 |
+
explanation = response.choices[0].message.content.strip()
|
| 150 |
+
except Exception as e:
|
| 151 |
+
explanation = "API call failed."
|
| 152 |
+
return explanation
|
| 153 |
+
|
| 154 |
+
# --------------------------
|
| 155 |
+
# Database Functions
|
| 156 |
+
# --------------------------
|
| 157 |
+
def init_db():
|
| 158 |
+
conn = sqlite3.connect("users.db")
|
| 159 |
+
c = conn.cursor()
|
| 160 |
+
c.execute("""
|
| 161 |
+
CREATE TABLE IF NOT EXISTS users (
|
| 162 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 163 |
+
username TEXT UNIQUE NOT NULL,
|
| 164 |
+
password TEXT NOT NULL,
|
| 165 |
+
full_name TEXT,
|
| 166 |
+
email TEXT
|
| 167 |
+
)
|
| 168 |
+
""")
|
| 169 |
+
c.execute("""
|
| 170 |
+
CREATE TABLE IF NOT EXISTS chat_history (
|
| 171 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 172 |
+
username TEXT NOT NULL,
|
| 173 |
+
message TEXT NOT NULL,
|
| 174 |
+
response TEXT NOT NULL,
|
| 175 |
+
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
| 176 |
+
)
|
| 177 |
+
""")
|
| 178 |
+
c.execute("""
|
| 179 |
+
CREATE TABLE IF NOT EXISTS patient_sessions (
|
| 180 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 181 |
+
username TEXT,
|
| 182 |
+
patient_name TEXT,
|
| 183 |
+
age REAL,
|
| 184 |
+
gender TEXT,
|
| 185 |
+
symptoms TEXT,
|
| 186 |
+
diagnosis TEXT,
|
| 187 |
+
medication TEXT,
|
| 188 |
+
therapy TEXT,
|
| 189 |
+
summary TEXT,
|
| 190 |
+
explanation TEXT,
|
| 191 |
+
pdf_report TEXT,
|
| 192 |
+
session_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
| 193 |
+
appointment_date DATE
|
| 194 |
+
)
|
| 195 |
+
""")
|
| 196 |
+
conn.commit()
|
| 197 |
+
conn.close()
|
| 198 |
+
|
| 199 |
+
def register_user(username, password, full_name, email):
|
| 200 |
+
if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email):
|
| 201 |
+
return "Invalid email format."
|
| 202 |
+
if len(password) <= 8:
|
| 203 |
+
return "Password must be more than 8 characters."
|
| 204 |
+
conn = sqlite3.connect("users.db")
|
| 205 |
+
c = conn.cursor()
|
| 206 |
+
hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
|
| 207 |
+
try:
|
| 208 |
+
c.execute("INSERT INTO users (username, password, full_name, email) VALUES (?, ?, ?, ?)", (username, hashed, full_name, email))
|
| 209 |
+
conn.commit()
|
| 210 |
+
return "User registered successfully."
|
| 211 |
+
except sqlite3.IntegrityError:
|
| 212 |
+
return "Username already exists."
|
| 213 |
+
finally:
|
| 214 |
+
conn.close()
|
| 215 |
+
|
| 216 |
+
def login_user(username, password):
|
| 217 |
+
conn = sqlite3.connect("users.db")
|
| 218 |
+
c = conn.cursor()
|
| 219 |
+
c.execute("SELECT password FROM users WHERE username = ?", (username,))
|
| 220 |
+
user = c.fetchone()
|
| 221 |
+
conn.close()
|
| 222 |
+
if user and bcrypt.checkpw(password.encode(), user[0]):
|
| 223 |
+
return True
|
| 224 |
+
return False
|
| 225 |
+
|
| 226 |
+
def get_user_info(username):
|
| 227 |
+
conn = sqlite3.connect("users.db")
|
| 228 |
+
c = conn.cursor()
|
| 229 |
+
c.execute("SELECT username, email, full_name FROM users WHERE username = ?", (username,))
|
| 230 |
+
user = c.fetchone()
|
| 231 |
+
conn.close()
|
| 232 |
+
if user:
|
| 233 |
+
return f"Username: {user[0]}\nFull Name: {user[2]}\nEmail: {user[1]}"
|
| 234 |
+
else:
|
| 235 |
+
return "User info not found."
|
| 236 |
+
|
| 237 |
+
def get_chat_history(username):
|
| 238 |
+
conn = sqlite3.connect("users.db")
|
| 239 |
+
c = conn.cursor()
|
| 240 |
+
c.execute("SELECT message, response, timestamp FROM chat_history WHERE username = ? ORDER BY timestamp DESC", (username,))
|
| 241 |
+
history = c.fetchall()
|
| 242 |
+
conn.close()
|
| 243 |
+
return history
|
| 244 |
+
|
| 245 |
+
def get_patient_sessions(filter_name="", filter_date=""):
|
| 246 |
+
conn = sqlite3.connect("users.db")
|
| 247 |
+
c = conn.cursor()
|
| 248 |
+
query = "SELECT patient_name, age, gender, symptoms, diagnosis, medication, therapy, summary, explanation, pdf_report, session_timestamp FROM patient_sessions WHERE 1=1"
|
| 249 |
+
params = []
|
| 250 |
+
if filter_name:
|
| 251 |
+
query += " AND patient_name LIKE ?"
|
| 252 |
+
params.append(f"%{filter_name}%")
|
| 253 |
+
if filter_date:
|
| 254 |
+
query += " AND DATE(session_timestamp)=?"
|
| 255 |
+
params.append(filter_date)
|
| 256 |
+
c.execute(query, params)
|
| 257 |
+
sessions = c.fetchall()
|
| 258 |
+
conn.close()
|
| 259 |
+
return sessions
|
| 260 |
+
|
| 261 |
+
def insert_patient_session(session_data):
|
| 262 |
+
conn = sqlite3.connect("users.db")
|
| 263 |
+
c = conn.cursor()
|
| 264 |
+
c.execute("""
|
| 265 |
+
INSERT INTO patient_sessions (username, patient_name, age, gender, symptoms, diagnosis, medication, therapy, summary, explanation, pdf_report, appointment_date)
|
| 266 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 267 |
+
""", (
|
| 268 |
+
session_data.get("username"), session_data.get("patient_name"), session_data.get("age"), session_data.get("gender"),
|
| 269 |
+
session_data.get("symptoms"), session_data.get("diagnosis"), session_data.get("medication"),
|
| 270 |
+
session_data.get("therapy"), session_data.get("summary"), session_data.get("explanation"),
|
| 271 |
+
session_data.get("pdf_report"), session_data.get("appointment_date")))
|
| 272 |
+
conn.commit()
|
| 273 |
+
conn.close()
|
| 274 |
+
|
| 275 |
+
# --------------------------
|
| 276 |
+
# PDF Report Generation Function
|
| 277 |
+
# --------------------------
|
| 278 |
+
def generate_pdf_report(session_data):
|
| 279 |
+
pdf = FPDF()
|
| 280 |
+
pdf.add_page()
|
| 281 |
+
pdf.set_font("Arial", size=12)
|
| 282 |
+
pdf.cell(200, 10, txt="Patient Session Report", ln=True, align='C')
|
| 283 |
+
pdf.ln(10)
|
| 284 |
+
for key, value in session_data.items():
|
| 285 |
+
pdf.multi_cell(0, 10, txt=f"{key.capitalize()}: {value}")
|
| 286 |
+
reports_dir = "reports"
|
| 287 |
+
os.makedirs(reports_dir, exist_ok=True)
|
| 288 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 289 |
+
filename = f"{reports_dir}/{session_data.get('patient_name')}_{timestamp}.pdf"
|
| 290 |
+
pdf.output(filename)
|
| 291 |
+
return filename
|
| 292 |
+
|
| 293 |
+
# --------------------------
|
| 294 |
+
# Helper: Autofill Previous Patient Info
|
| 295 |
+
# --------------------------
|
| 296 |
+
def get_previous_patient_info(selected_patient):
|
| 297 |
+
conn = sqlite3.connect("users.db")
|
| 298 |
+
c = conn.cursor()
|
| 299 |
+
c.execute("SELECT patient_name, age, gender FROM patient_sessions WHERE patient_name=? ORDER BY session_timestamp DESC LIMIT 1", (selected_patient,))
|
| 300 |
+
record = c.fetchone()
|
| 301 |
+
conn.close()
|
| 302 |
+
if record:
|
| 303 |
+
return record[0], record[1], record[2]
|
| 304 |
+
else:
|
| 305 |
+
return "", None, ""
|
| 306 |
+
|
| 307 |
+
def get_previous_patients():
|
| 308 |
+
conn = sqlite3.connect("users.db")
|
| 309 |
+
c = conn.cursor()
|
| 310 |
+
c.execute("SELECT DISTINCT patient_name FROM patient_sessions")
|
| 311 |
+
records = c.fetchall()
|
| 312 |
+
conn.close()
|
| 313 |
+
return [r[0] for r in records]
|
| 314 |
+
|
| 315 |
+
# --------------------------
|
| 316 |
+
# Gradio UI Setup with External CSS
|
| 317 |
+
# --------------------------
|
| 318 |
+
with gr.Blocks(css=open("styles.css", "r").read(), theme="soft") as app:
|
| 319 |
+
user_session = gr.State(value="")
|
| 320 |
+
profile_visible = gr.State(value=False)
|
| 321 |
+
session_data_state = gr.State(value="")
|
| 322 |
+
|
| 323 |
+
with gr.Row(elem_id="header") as header_row:
|
| 324 |
+
with gr.Column(scale=8):
|
| 325 |
+
gr.Markdown("## Mental Health Chatbot")
|
| 326 |
+
with gr.Column(scale=4) as profile_container:
|
| 327 |
+
profile_button = gr.Button("👤", elem_id="profile_button", variant="secondary")
|
| 328 |
+
with gr.Column(visible=False, elem_id="profile_info_box") as profile_info_box:
|
| 329 |
+
profile_info = gr.HTML()
|
| 330 |
+
logout_button = gr.Button("Logout", elem_id="logout_button")
|
| 331 |
+
|
| 332 |
+
with gr.Column(visible=True, elem_id="login_page") as login_page:
|
| 333 |
+
gr.Markdown("## Login")
|
| 334 |
+
with gr.Row():
|
| 335 |
+
username_login = gr.Textbox(label="Username")
|
| 336 |
+
password_login = gr.Textbox(label="Password", type="password")
|
| 337 |
+
login_btn = gr.Button("Login")
|
| 338 |
+
login_output = gr.Textbox(label="Login Status", interactive=False)
|
| 339 |
+
gr.Markdown("New user? Click below to register.")
|
| 340 |
+
go_to_register = gr.Button("Go to Register")
|
| 341 |
+
|
| 342 |
+
with gr.Column(visible=False, elem_id="register_page") as register_page:
|
| 343 |
+
gr.Markdown("## Register")
|
| 344 |
+
new_username = gr.Textbox(label="New Username")
|
| 345 |
+
new_password = gr.Textbox(label="New Password", type="password")
|
| 346 |
+
full_name = gr.Textbox(label="Full Name")
|
| 347 |
+
email = gr.Textbox(label="Email")
|
| 348 |
+
register_btn = gr.Button("Register")
|
| 349 |
+
register_output = gr.Textbox(label="Registration Status", interactive=False)
|
| 350 |
+
gr.Markdown("Already have an account?")
|
| 351 |
+
back_to_login = gr.Button("Back to Login")
|
| 352 |
+
|
| 353 |
+
with gr.Tabs(visible=False, elem_id="main_panel") as main_panel:
|
| 354 |
+
with gr.Tab("Chatbot"):
|
| 355 |
+
with gr.Row():
|
| 356 |
+
with gr.Column(scale=1):
|
| 357 |
+
previous_patient = gr.Dropdown(label="Previous Patients", choices=[], interactive=True)
|
| 358 |
+
patient_name_input = gr.Textbox(placeholder="Enter patient name", label="Patient Name")
|
| 359 |
+
gender_input = gr.Dropdown(choices=list(gender_encoder.keys()), label="Gender")
|
| 360 |
+
age_input = gr.Number(label="Age")
|
| 361 |
+
symptoms_input = gr.Textbox(placeholder="Describe symptoms", label="Symptoms", lines=4)
|
| 362 |
+
submit = gr.Button("Submit")
|
| 363 |
+
generate_report_btn = gr.Button("Generate Report", visible=False)
|
| 364 |
+
with gr.Column(scale=1):
|
| 365 |
+
with gr.Row():
|
| 366 |
+
with gr.Column(scale=4, min_width=240): # Textbox column
|
| 367 |
+
diagnosis_textbox = gr.Textbox(label="Diagnosis",
|
| 368 |
+
interactive=False)
|
| 369 |
+
with gr.Column(scale=1, min_width=120): # Confidence column
|
| 370 |
+
diagnosis_conf_html = gr.HTML(elem_classes=["confidence-container"])
|
| 371 |
+
|
| 372 |
+
with gr.Row():
|
| 373 |
+
with gr.Column(scale=4, min_width=240):
|
| 374 |
+
medication_textbox = gr.Textbox(label="Medication",
|
| 375 |
+
interactive=False)
|
| 376 |
+
with gr.Column(scale=1, min_width=120):
|
| 377 |
+
medication_conf_html = gr.HTML(elem_classes=["confidence-container"])
|
| 378 |
+
|
| 379 |
+
with gr.Row():
|
| 380 |
+
with gr.Column(scale=4, min_width=240):
|
| 381 |
+
therapy_textbox = gr.Textbox(label="Therapy",
|
| 382 |
+
interactive=False)
|
| 383 |
+
with gr.Column(scale=1, min_width=120):
|
| 384 |
+
therapy_conf_html = gr.HTML(elem_classes=["confidence-container"])
|
| 385 |
+
summary_textbox = gr.Textbox(label="Concise Summary", interactive=False)
|
| 386 |
+
explanation_textbox = gr.Textbox(label="Explanation", interactive=False)
|
| 387 |
+
with gr.Row():
|
| 388 |
+
report_download = gr.File(label="Download Report", interactive=False)
|
| 389 |
+
|
| 390 |
+
def handle_chat_extended(patient_name, gender, age, symptoms):
|
| 391 |
+
if age is None or age <= 0:
|
| 392 |
+
error_msg = "Age must be greater than 0."
|
| 393 |
+
return (error_msg, "", error_msg, "", error_msg, "", error_msg, error_msg, gr.update(visible=False))
|
| 394 |
+
|
| 395 |
+
if age > 150:
|
| 396 |
+
error_msg2 = "Age must be lower than 150"
|
| 397 |
+
return (error_msg2, "", error_msg2, "", error_msg2, "", error_msg2, error_msg2, gr.update(visible=False))
|
| 398 |
+
|
| 399 |
+
if len(symptoms.split()) > 512:
|
| 400 |
+
msg = "Input exceeds maximum allowed length of 512 words."
|
| 401 |
+
return (msg, "", msg, "", msg, "", msg, msg, gr.update(visible=False))
|
| 402 |
+
|
| 403 |
+
full_statement = f"Patient Name: {patient_name}, Gender: {gender}, Age: {age}, Symptoms: {symptoms}"
|
| 404 |
+
summary = get_concise_rewrite(full_statement, max_tokens=150, temperature=0.7)
|
| 405 |
+
|
| 406 |
+
# Predict top 3 diagnoses
|
| 407 |
+
diagnosis_preds = predict_disease(full_statement) # Now returns list of (label, confidence)
|
| 408 |
+
diagnosis_display = "\n".join([f"{label}" for label, _ in diagnosis_preds])
|
| 409 |
+
|
| 410 |
+
def get_confidence_class(percentage):
|
| 411 |
+
if percentage <= 50:
|
| 412 |
+
return "confidence-low"
|
| 413 |
+
elif percentage <= 74:
|
| 414 |
+
return "confidence-medium"
|
| 415 |
+
else:
|
| 416 |
+
return "confidence-high"
|
| 417 |
+
|
| 418 |
+
diagnosis_conf_html_val = "<div class='confidence-multi'>" + "<br>".join([
|
| 419 |
+
f"<div class='confidence-display'><span class='confidence-value {get_confidence_class(conf * 100)}'>{conf * 100:.1f}% confidence</span></div>"
|
| 420 |
+
for _, conf in diagnosis_preds
|
| 421 |
+
]) + "</div>"
|
| 422 |
+
|
| 423 |
+
# Predict medication and therapy
|
| 424 |
+
(med_pred, med_conf), (therapy_pred, therapy_conf) = predict_med_therapy(symptoms, age, gender)
|
| 425 |
+
med_percentage = med_conf * 100
|
| 426 |
+
therapy_percentage = therapy_conf * 100
|
| 427 |
+
|
| 428 |
+
def get_conf_html(percentage):
|
| 429 |
+
return f"""
|
| 430 |
+
<div class="confidence-display">
|
| 431 |
+
<span class="confidence-value {get_confidence_class(percentage)}">
|
| 432 |
+
{percentage:.1f}% confidence
|
| 433 |
+
</span>
|
| 434 |
+
</div>
|
| 435 |
+
"""
|
| 436 |
+
|
| 437 |
+
medication_conf_html_val = get_conf_html(med_percentage)
|
| 438 |
+
therapy_conf_html_val = get_conf_html(therapy_percentage)
|
| 439 |
+
|
| 440 |
+
# Explanation
|
| 441 |
+
top_diag_labels = ", ".join([label for label, _ in diagnosis_preds])
|
| 442 |
+
explanation = get_explanation(full_statement, f"{top_diag_labels}, {med_pred} and {therapy_pred}")
|
| 443 |
+
|
| 444 |
+
# Prepare session data
|
| 445 |
+
top_diag_with_conf = ", ".join([f"{label} ({conf * 100:.1f}%)" for label, conf in diagnosis_preds])
|
| 446 |
+
session_data = {
|
| 447 |
+
"patient_name": patient_name,
|
| 448 |
+
"age": age,
|
| 449 |
+
"gender": gender,
|
| 450 |
+
"symptoms": symptoms,
|
| 451 |
+
"diagnosis": top_diag_with_conf,
|
| 452 |
+
"medication": f"{med_pred} ({med_percentage:.1f}% confidence)",
|
| 453 |
+
"therapy": f"{therapy_pred} ({therapy_percentage:.1f}% confidence)",
|
| 454 |
+
"summary": summary,
|
| 455 |
+
"explanation": explanation,
|
| 456 |
+
"session_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 457 |
+
}
|
| 458 |
+
session_data_state.value = json.dumps(session_data)
|
| 459 |
+
|
| 460 |
+
# Save to chat history
|
| 461 |
+
conn = sqlite3.connect("users.db")
|
| 462 |
+
c = conn.cursor()
|
| 463 |
+
if user_session.value:
|
| 464 |
+
c.execute("INSERT INTO chat_history (username, message, response) VALUES (?, ?, ?)",
|
| 465 |
+
(user_session.value, full_statement, top_diag_with_conf))
|
| 466 |
+
conn.commit()
|
| 467 |
+
conn.close()
|
| 468 |
+
|
| 469 |
+
return (
|
| 470 |
+
diagnosis_display, diagnosis_conf_html_val,
|
| 471 |
+
med_pred, medication_conf_html_val,
|
| 472 |
+
therapy_pred, therapy_conf_html_val,
|
| 473 |
+
summary, explanation,
|
| 474 |
+
gr.update(visible=True)
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
submit.click(handle_chat_extended,
|
| 479 |
+
inputs=[patient_name_input, gender_input, age_input, symptoms_input],
|
| 480 |
+
outputs=[diagnosis_textbox, diagnosis_conf_html, medication_textbox, medication_conf_html,
|
| 481 |
+
therapy_textbox, therapy_conf_html, summary_textbox, explanation_textbox,
|
| 482 |
+
generate_report_btn])
|
| 483 |
+
|
| 484 |
+
def handle_generate_report():
|
| 485 |
+
try:
|
| 486 |
+
data = json.loads(session_data_state.value)
|
| 487 |
+
except:
|
| 488 |
+
return None
|
| 489 |
+
pdf_file = generate_pdf_report(data)
|
| 490 |
+
data["username"] = user_session.value
|
| 491 |
+
data["appointment_date"] = ""
|
| 492 |
+
data["pdf_report"] = pdf_file
|
| 493 |
+
insert_patient_session(data)
|
| 494 |
+
return pdf_file
|
| 495 |
+
|
| 496 |
+
generate_report_btn.click(handle_generate_report, outputs=report_download)
|
| 497 |
+
|
| 498 |
+
def autofill_previous(selected_patient):
|
| 499 |
+
name, age_val, gender_val = get_previous_patient_info(selected_patient)
|
| 500 |
+
return name, age_val, gender_val
|
| 501 |
+
|
| 502 |
+
previous_patient.change(autofill_previous,
|
| 503 |
+
inputs=[previous_patient],
|
| 504 |
+
outputs=[patient_name_input, age_input, gender_input])
|
| 505 |
+
|
| 506 |
+
with gr.Tab("Chat History"):
|
| 507 |
+
history_output = gr.Textbox(label="Chat History", interactive=False)
|
| 508 |
+
load_history_btn = gr.Button("Load History")
|
| 509 |
+
|
| 510 |
+
def load_history():
|
| 511 |
+
if user_session.value:
|
| 512 |
+
history = get_chat_history(user_session.value)
|
| 513 |
+
return "\n".join([f"[{h[2]}] {h[0]}\nBot: {h[1]}" for h in history])
|
| 514 |
+
else:
|
| 515 |
+
return "Please log in to view history."
|
| 516 |
+
|
| 517 |
+
load_history_btn.click(load_history, outputs=history_output)
|
| 518 |
+
|
| 519 |
+
with gr.Tab("Book an Appointment"):
|
| 520 |
+
with gr.Row():
|
| 521 |
+
with gr.Column():
|
| 522 |
+
patient_name_appt = gr.Textbox(label="Patient Name", placeholder="Enter your name")
|
| 523 |
+
doctor_name = gr.Dropdown(choices=["Dr. Smith", "Dr. Johnson", "Dr. Lee"], label="Select Doctor")
|
| 524 |
+
appointment_date = gr.Textbox(label="Appointment Date", placeholder="YYYY-MM-DD")
|
| 525 |
+
appointment_time = gr.Textbox(label="Appointment Time", placeholder="HH:MM (24-hour format)")
|
| 526 |
+
reason = gr.TextArea(label="Reason for Visit", placeholder="Describe your symptoms or reason for the visit")
|
| 527 |
+
book_button = gr.Button("Book Appointment")
|
| 528 |
+
with gr.Column():
|
| 529 |
+
booking_output = gr.Textbox(label="Booking Confirmation", interactive=False)
|
| 530 |
+
|
| 531 |
+
def book_appointment(patient_name, doctor_name, appointment_date, appointment_time, reason):
|
| 532 |
+
if not user_session.value:
|
| 533 |
+
return "Please log in to book an appointment."
|
| 534 |
+
patient_name = (patient_name or "").strip()
|
| 535 |
+
doctor_name = (doctor_name or "").strip()
|
| 536 |
+
appointment_date = (appointment_date or "").strip()
|
| 537 |
+
appointment_time = (appointment_time or "").strip()
|
| 538 |
+
reason = (reason or "").strip()
|
| 539 |
+
if not (patient_name and doctor_name and appointment_date and appointment_time and reason):
|
| 540 |
+
return "Please fill in all the fields."
|
| 541 |
+
if not re.fullmatch(r"[A-Za-z ]+", patient_name):
|
| 542 |
+
return "Patient name should contain only letters and spaces."
|
| 543 |
+
try:
|
| 544 |
+
datetime.strptime(appointment_date, "%Y-%m-%d")
|
| 545 |
+
except ValueError:
|
| 546 |
+
return "Appointment date must be in YYYY-MM-DD format."
|
| 547 |
+
try:
|
| 548 |
+
datetime.strptime(appointment_time, "%H:%M")
|
| 549 |
+
except ValueError:
|
| 550 |
+
return "Appointment time must be in HH:MM (24-hour) format."
|
| 551 |
+
confirmation = (f"Appointment booked for {patient_name} with {doctor_name} on {appointment_date} at {appointment_time}.\n\n"
|
| 552 |
+
f"Reason: {reason}")
|
| 553 |
+
return confirmation
|
| 554 |
+
|
| 555 |
+
book_button.click(book_appointment,
|
| 556 |
+
inputs=[patient_name_appt, doctor_name, appointment_date, appointment_time, reason],
|
| 557 |
+
outputs=booking_output)
|
| 558 |
+
|
| 559 |
+
with gr.Tab("Patient Sessions"):
|
| 560 |
+
gr.Markdown("### Search Patient Sessions")
|
| 561 |
+
search_name = gr.Textbox(label="Patient Name (optional)")
|
| 562 |
+
search_date = gr.Textbox(label="Date (YYYY-MM-DD, optional)")
|
| 563 |
+
search_button = gr.Button("Search")
|
| 564 |
+
sessions_output = gr.Textbox(label="Sessions", interactive=False)
|
| 565 |
+
|
| 566 |
+
def search_sessions(name, date):
|
| 567 |
+
sessions = get_patient_sessions(filter_name=name, filter_date=date)
|
| 568 |
+
if sessions:
|
| 569 |
+
output = "\n\n".join([f"Patient: {s[0]}\nAge: {s[1]}\nGender: {s[2]}\nSymptoms: {s[3]}\nDiagnosis: {s[4]}\nMedication: {s[5]}\nTherapy: {s[6]}\nSummary: {s[7]}\nExplanation: {s[8]}\nReport: {s[9]}\nSession Time: {s[10]}" for s in sessions])
|
| 570 |
+
return output
|
| 571 |
+
else:
|
| 572 |
+
return "No sessions found."
|
| 573 |
+
|
| 574 |
+
search_button.click(search_sessions, inputs=[search_name, search_date], outputs=sessions_output)
|
| 575 |
+
|
| 576 |
+
def handle_login(username, password):
|
| 577 |
+
if login_user(username, password):
|
| 578 |
+
user_session.value = username
|
| 579 |
+
prev_choices = get_previous_patients()
|
| 580 |
+
return (f"Welcome, {username}!",
|
| 581 |
+
gr.update(visible=True),
|
| 582 |
+
gr.update(visible=False),
|
| 583 |
+
gr.update(visible=True),
|
| 584 |
+
gr.update(choices=prev_choices))
|
| 585 |
+
else:
|
| 586 |
+
return "Invalid credentials.", gr.update(), gr.update(), gr.update(), gr.update()
|
| 587 |
+
|
| 588 |
+
def handle_register(username, password, full_name, email):
|
| 589 |
+
return register_user(username, password, full_name, email)
|
| 590 |
+
|
| 591 |
+
def go_to_register_page():
|
| 592 |
+
return gr.update(visible=False), gr.update(visible=True)
|
| 593 |
+
|
| 594 |
+
def back_to_login_page():
|
| 595 |
+
return gr.update(visible=True), gr.update(visible=False)
|
| 596 |
+
|
| 597 |
+
login_btn.click(handle_login,
|
| 598 |
+
inputs=[username_login, password_login],
|
| 599 |
+
outputs=[login_output, main_panel, login_page, header_row])
|
| 600 |
+
go_to_register.click(go_to_register_page, outputs=[login_page, register_page])
|
| 601 |
+
register_btn.click(handle_register,
|
| 602 |
+
inputs=[new_username, new_password, full_name, email],
|
| 603 |
+
outputs=register_output)
|
| 604 |
+
back_to_login.click(back_to_login_page, outputs=[login_page, register_page])
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
# Toggle profile function
|
| 608 |
+
def toggle_profile(user, current_visible):
|
| 609 |
+
if not user:
|
| 610 |
+
return gr.update(visible=False), False, ""
|
| 611 |
+
new_visible = not current_visible
|
| 612 |
+
info = get_user_info(user) if new_visible else ""
|
| 613 |
+
return gr.update(visible=new_visible), new_visible, info
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
# Connect profile button click with correct input order:
|
| 617 |
+
profile_button.click(
|
| 618 |
+
toggle_profile,
|
| 619 |
+
inputs=[user_session, profile_visible],
|
| 620 |
+
outputs=[profile_info_box, profile_visible, profile_info]
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
# Handle login: update previous patients dropdown
|
| 625 |
+
def handle_login(username, password):
|
| 626 |
+
if login_user(username, password):
|
| 627 |
+
user_session.value = username
|
| 628 |
+
prev_choices = get_previous_patients()
|
| 629 |
+
return (f"Welcome, {username}!",
|
| 630 |
+
gr.update(visible=True), # main_panel visible
|
| 631 |
+
gr.update(visible=False), # login_page hidden
|
| 632 |
+
gr.update(visible=True), # header_row visible
|
| 633 |
+
gr.update(choices=prev_choices)) # update dropdown choices
|
| 634 |
+
else:
|
| 635 |
+
return "Invalid credentials.", gr.update(), gr.update(), gr.update(), gr.update()
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
# Connect login button click:
|
| 639 |
+
login_btn.click(
|
| 640 |
+
handle_login,
|
| 641 |
+
inputs=[username_login, password_login],
|
| 642 |
+
outputs=[login_output, main_panel, login_page, header_row, previous_patient]
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
init_db()
|
| 646 |
+
app.launch()
|