import gradio as gr
import pandas as pd
import faiss
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# ===============================
# Load Retrieval Components
# ===============================
print("Loading corpus and FAISS index...")
df = pd.read_csv("retrieval_corpus.csv")
index = faiss.read_index("faiss_index.bin")

print("Loading embedding model...")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
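# Note: queries must be embedded with the same model that built the index.
# A minimal sketch of how "faiss_index.bin" could have been produced from the
# corpus (an assumption; the actual build script is not part of this file):
#
#   embeddings = embedding_model.encode(df["text"].tolist(), convert_to_numpy=True)
#   build_index = faiss.IndexFlatL2(embeddings.shape[1])
#   build_index.add(embeddings.astype("float32"))
#   faiss.write_index(build_index, "faiss_index.bin")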
# ===============================
# Load LLM on CPU
# ===============================
model_id = "BioMistral/BioMistral-7B"
print(f"Loading tokenizer and model: {model_id}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # bfloat16 instead of float16: many torch CPU builds lack half-precision
    # kernels, so fp16 generate() can fail on CPU; bf16 has the same memory
    # footprint and is supported for CPU inference.
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to("cpu")

# The tokenizer defines no pad token by default; reuse EOS so padding and
# truncation behave consistently.
tokenizer.pad_token = tokenizer.eos_token
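# Rough RAM estimate for this configuration (assuming ~7B parameters at
# 2 bytes each in bfloat16): 7e9 * 2 bytes ≈ 14 GB for weights alone, before
# activations and the KV cache, so a correspondingly large CPU instance is needed.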
# ===============================
# RAG Pipeline
# ===============================
def get_top_k_chunks(query, k=5):
    """Embed the query and return the text of the k nearest corpus chunks."""
    query_embedding = embedding_model.encode([query])
    # FAISS returns (distances, indices), each shaped (n_queries, k).
    scores, indices = index.search(np.array(query_embedding).astype("float32"), k)
    return df.iloc[indices[0]]["text"].tolist()

def build_prompt(query, chunks):
    """Assemble a numbered-context prompt that ends in 'Answer:' for the LLM."""
    context = "\n".join(f"{i+1}. {chunk}" for i, chunk in enumerate(chunks))
    prompt = (
        "You are a clinical reasoning assistant. Based on the following medical information, "
        "answer the query with a detailed explanation.\n\n"
        f"Context:\n{context}\n\n"
        f"Query: {query}\n"
        "Answer:"
    )
    return prompt
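# For illustration, a prompt built from two retrieved chunks would look like
# (chunk contents abbreviated; the actual corpus text is an assumption):
#
#   You are a clinical reasoning assistant. Based on the following medical
#   information, answer the query with a detailed explanation.
#
#   Context:
#   1. <first retrieved chunk>
#   2. <second retrieved chunk>
#
#   Query: <user query>
#   Answer: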
def generate_diagnosis(query):
    """Retrieve context for the query and generate a diagnostic explanation."""
    chunks = get_top_k_chunks(query)
    prompt = build_prompt(query, chunks)
    # Truncate so the prompt fits the context window; keep the attention mask
    # so generate() can distinguish padding from content.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded text includes the prompt; keep only what follows "Answer:".
    answer = generated_text.split("Answer:")[-1].strip()
    return answer, "\n\n".join(chunks)
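# Example call (hypothetical query), exercising retrieval and generation end to end:
#   answer, retrieved = generate_diagnosis("65-year-old male with acute shortness of breath")
# 'answer' is the model's explanation; 'retrieved' is the concatenated context shown in the UI.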
# ===============================
# Gradio UI
# ===============================
def run_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🧠 Clinical Diagnosis Assistant (RAG)")
        gr.Markdown("Enter a clinical query. The assistant retrieves relevant medical facts and generates a diagnostic explanation.")
        with gr.Row():
            query_input = gr.Textbox(label="Clinical Query", placeholder="e.g. 65-year-old male with shortness of breath...")
            generate_btn = gr.Button("Generate Diagnosis")
        with gr.Accordion("📄 Retrieved Context", open=False):
            context_output = gr.Textbox(label="Top-5 Retrieved Chunks", lines=10, interactive=False)
        answer_output = gr.Textbox(label="Generated Diagnosis", lines=8)
        generate_btn.click(
            fn=generate_diagnosis,
            inputs=query_input,
            outputs=[answer_output, context_output],
        )
    return demo
# ===============================
# Launch App
# ===============================
if __name__ == "__main__":
    demo = run_interface()
    demo.launch()