Spaces:
Runtime error
Runtime error
| import re, uuid | |
| import base64 | |
| import bcrypt | |
| import gradio as gr | |
| from gradio_pdf import PDF | |
| from pathlib import Path | |
| import time | |
| import shutil | |
| from typing import AsyncGenerator, List, Optional, Tuple | |
| from gradio import ChatMessage | |
| from fpdf import FPDF | |
# Directory that generated PDF reports are written into; created at import
# time if it does not exist yet.
REPORT_DIR: Path = Path("reports")
REPORT_DIR.mkdir(exist_ok=True)

# bcrypt salt string ("$2b$" scheme, cost 12). It is the same prefix that is
# embedded in every hash in USERS below; the code shown here does not read it.
SALT = b"$2b$12$MC7djiqmIR7154Syul5Wme"

# Login table: username -> bcrypt hash of that user's password.
# SECURITY(review): credentials are hard-coded in source; consider loading
# them from a deployment secret / environment variable instead.
USERS = dict(
    test_user=b"$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6",
)
class ChatInterface:
    """
    A chat interface for interacting with a medical AI agent through Gradio.
    Handles file uploads, message processing, and chat history management.
    Supports both regular image files and DICOM medical imaging files.

    Upload, display and thread state is stored on the instance, so every
    caller that shares one instance also shares the current image and thread.
    """

    def __init__(self, agent, tools_dict):
        """
        Initialize the chat interface.

        Args:
            agent: The medical AI agent to handle requests
            tools_dict (dict): Dictionary of available tools for image processing
        """
        self.agent = agent
        self.tools_dict = tools_dict
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None
        # Separate storage for original and display paths
        self.original_file_path = None  # For LLM tools (.dcm or other)
        self.display_file_path = None  # For UI (always viewable format)
        # Viewable rendering of an uploaded DICOM. It is used wherever pixels are
        # needed (chat bubble, multimodal model) instead of the raw .dcm bytes.
        self._dicom_preview_path = None

    def handle_upload(
        self, file_path: Optional[str]
    ) -> Tuple[Optional[str], dict, dict, dict]:
        """
        Handle new file upload and set appropriate paths.

        The file is copied into ``upload_dir``. A ``.dcm`` file is additionally
        converted to a viewable image with the ``DicomProcessorTool``.

        Args:
            file_path (str): Path to the uploaded file (may be empty/None).
        Returns:
            A 4-tuple for the (image, analyze, ground, segment) outputs: the
            display path (``None`` when no file was given) followed by three
            button updates. After a successful upload the buttons are made
            interactive; otherwise they are left unchanged.
        """
        if not file_path:
            # The upload events drive four outputs, so always return four values.
            return None, gr.update(), gr.update(), gr.update()

        suffix = Path(file_path).suffix.lower()
        # The timestamp keeps names readable; the random tag stops two uploads
        # within the same second from overwriting each other's file.
        saved_path = (
            self.upload_dir / f"upload_{int(time.time())}_{uuid.uuid4().hex[:8]}{suffix}"
        )
        shutil.copy2(file_path, saved_path)
        self.original_file_path = str(saved_path)

        # Handle DICOM conversion for display only; tools still get the .dcm.
        if suffix == ".dcm":
            output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
            self._dicom_preview_path = output["image_path"]
        else:
            self._dicom_preview_path = None
        self.display_file_path = self._dicom_preview_path or str(saved_path)

        return (
            self.display_file_path,
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    def _viewable_path(self, image_path: Optional[str]) -> Optional[str]:
        """Return a renderable image for image_path: the preview of an uploaded DICOM, else unchanged."""
        if (
            image_path
            and Path(image_path).suffix.lower() == ".dcm"
            and self._dicom_preview_path
        ):
            return self._dicom_preview_path
        return image_path

    @staticmethod
    def _encode_image(image_path: str) -> dict:
        """Build an ``image_url`` message part holding image_path as a base64 data URL."""
        import mimetypes

        mime_type = mimetypes.guess_type(image_path)[0]
        # Keep the previous JPEG label when the type is unknown or not an image.
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"
        with open(image_path, "rb") as img_file:
            img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
        }

    def add_message(
        self, message: str, display_image: str, history: List[dict]
    ) -> Tuple[List[dict], gr.Textbox]:
        """
        Add a new message to the chat history.

        When an image is available, it is appended as its own user entry ahead
        of the text.

        Args:
            message (str): Text message to add
            display_image (str): Path to image being displayed
            history (List[dict]): Current chat history (mutated in place)
        Returns:
            Tuple[List[dict], gr.Textbox]: Updated history and a locked textbox
            still holding the message.
        """
        # The history is rendered in the browser, which cannot show DICOM, so an
        # uploaded .dcm is shown through its converted preview.
        image_path = self._viewable_path(self.original_file_path or display_image)
        if image_path is not None:
            history.append({"role": "user", "content": {"path": image_path}})
        if message is not None:
            history.append({"role": "user", "content": message})
        return history, gr.Textbox(value=message, interactive=False)

    async def process_message(
        self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
    ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
        """
        Process a message and stream the agent's responses.

        Args:
            message (str): User message to process
            display_image (Optional[str]): Path to currently displayed image
            chat_history (List[ChatMessage]): Current chat history
        Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history,
            display path, and an empty string. Every yield, including the error
            path, has these three items.
        """
        chat_history = chat_history or []

        # Initialize thread if needed
        if not self.current_thread_id:
            self.current_thread_id = str(time.time())

        messages = []
        image_path = self.original_file_path or display_image
        if image_path is not None:
            # Tools receive the original file (possibly DICOM) by path.
            messages.append({"role": "user", "content": f"image_path: {image_path}"})
            # The multimodal model receives pixels, so send a viewable image
            # labelled with its real MIME type.
            messages.append(
                {
                    "role": "user",
                    "content": [self._encode_image(self._viewable_path(image_path))],
                }
            )
        if message is not None:
            messages.append({"role": "user", "content": [{"type": "text", "text": message}]})

        try:
            for event in self.agent.workflow.stream(
                {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
            ):
                if not isinstance(event, dict):
                    continue

                if "process" in event:
                    content = event["process"]["messages"][-1].content
                    if content:
                        # Strip internal upload paths (upload_dir is "temp") from the reply.
                        content = re.sub(r"temp/[^\s]*", "", content)
                        chat_history.append(ChatMessage(role="assistant", content=content))
                        yield chat_history, self.display_file_path, ""

                elif "execute" in event:
                    # Named tool_msg so the ``message`` argument is not shadowed.
                    for tool_msg in event["execute"]["messages"]:
                        tool_name = tool_msg.name
                        # SECURITY(review): eval() executes arbitrary code, and this
                        # text comes from tool execution that the model and the user
                        # can influence. Consider ast.literal_eval if tool outputs are
                        # always plain Python literals.
                        tool_result = eval(tool_msg.content)[0]

                        if tool_result:
                            metadata = {"title": f"πΌοΈ Image from tool: {tool_name}"}
                            formatted_result = " ".join(
                                line.strip() for line in str(tool_result).splitlines()
                            ).strip()
                            metadata["description"] = formatted_result
                            chat_history.append(
                                ChatMessage(
                                    role="assistant",
                                    content=formatted_result,
                                    metadata=metadata,
                                )
                            )

                        # For image_visualizer, use display path (only if it returned one).
                        if tool_name == "image_visualizer" and tool_result:
                            self.display_file_path = tool_result["image_path"]
                            chat_history.append(
                                ChatMessage(
                                    role="assistant",
                                    content={"path": self.display_file_path},
                                )
                            )

                        yield chat_history, self.display_file_path, ""

        except Exception as e:
            chat_history.append(
                ChatMessage(
                    role="assistant", content=f"β Error: {str(e)}", metadata={"title": "Error"}
                )
            )
            # Same arity as the normal yields; the event has three outputs.
            yield chat_history, self.display_file_path, ""
def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    The app opens on a login page. After a successful bcrypt check against
    ``USERS``, the main page is shown. It holds the chat panel, an image tab
    (uploads, quick-action prompts, chat/thread controls) and a report tab
    (summary table and conclusion; approve -> PDF download and preview,
    reject -> feedback box).

    Args:
        agent: The medical AI agent to handle requests
        tools_dict (dict): Dictionary of available tools for image processing
    Returns:
        gr.Blocks: Gradio Blocks interface
    """
    interface = ChatInterface(agent, tools_dict)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        auth_state = gr.State(False)

        # ---- Login page ---------------------------------------------------
        with gr.Column(visible=True) as login_page:
            gr.Markdown("## π Login")
            username = gr.Textbox(label="Username")
            password = gr.Textbox(label="Password", type="password")
            login_button = gr.Button("Login")
            login_error = gr.Markdown(visible=False)

        # ---- Main page (hidden until login succeeds) ------------------------
        with gr.Column(visible=False) as main_page:
            gr.Markdown(
                """
                # π₯ MOLx - Powered by MedRAX
                """
            )
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        [],
                        height=800,
                        container=True,
                        show_label=True,
                        elem_classes="chat-box",
                        type="messages",
                        label="Agent",
                        avatar_images=(
                            None,
                            "assets/medrax_logo.jpg",
                        ),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Ask about the X-ray...",
                                container=False,
                            )

                with gr.Column(scale=3):
                    with gr.Tabs():
                        with gr.Tab(label="Image section"):
                            image_display = gr.Image(
                                label="Image", type="filepath", height=650, container=True
                            )
                            with gr.Row():
                                upload_button = gr.UploadButton(
                                    "π Upload X-Ray",
                                    file_types=["image"],
                                )
                                dicom_upload = gr.UploadButton(
                                    "π Upload DICOM",
                                    file_types=["file"],
                                )
                            with gr.Row():
                                # Enabled by a successful upload.
                                analyze_btn = gr.Button("Analyze", interactive=False)
                                ground_btn = gr.Button("Ground", interactive=False)
                                segment_btn = gr.Button("Segment", interactive=False)
                            with gr.Row():
                                clear_btn = gr.Button("Clear Chat")
                                new_thread_btn = gr.Button("New Thread")

                        with gr.Tab(label="Report section"):
                            generate_report_btn = gr.Button("Generate Report")
                            diseases_df = gr.Dataframe(
                                headers=["Disease", "Info"],
                                datatype=["str", "str"],
                                interactive=False,
                                visible=False,
                                max_height=220,
                            )
                            conclusion_tb = gr.Textbox(
                                label="Conclusion", interactive=False, visible=False
                            )
                            with gr.Row():
                                approve_btn = gr.Button("Approve", visible=False)
                                reject_btn = gr.Button("Reject", visible=False)
                            download_pdf_btn = gr.DownloadButton(
                                label="π₯ Download PDF", visible=False
                            )
                            pdf_preview = PDF(visible=False)
                            rejection_text = gr.Textbox(
                                show_label=False,
                                visible=False,
                                placeholder="Tell us what is wrong with the report",
                                container=False,
                                interactive=True,
                            )
                            with gr.Row():
                                # NOTE(review): submit_reject_btn has no event handler
                                # wired below; the rejection text is not consumed yet.
                                submit_reject_btn = gr.Button("Submit", visible=False)
                                cancel_reject_btn = gr.Button("Cancel", visible=False)

        # ---- Event handlers ---------------------------------------------------
        def authenticate(username, password):
            """Swap to the main page when the bcrypt check passes; otherwise show an error."""
            # NOTE(review): this only toggles page visibility. None of the handlers
            # below check auth_state, so they remain callable without logging in
            # (e.g. via the Gradio API). Confirm whether that is acceptable.
            hashed = USERS.get(username)
            if hashed and bcrypt.checkpw((password or "").encode(), hashed):
                return (
                    gr.update(visible=False),  # hide login
                    gr.update(visible=True),  # show main
                    gr.update(visible=False),  # hide error
                    True,  # set state
                )
            # gr.update() (rather than None) leaves both pages as they are.
            return (
                gr.update(),
                gr.update(),
                gr.update(value="β Incorrect username or password", visible=True),
                False,
            )

        def clear_chat():
            """Forget the current image and empty the chat and image panes."""
            interface.original_file_path = None
            interface.display_file_path = None
            return [], None

        def new_thread():
            """Start a fresh agent thread, keeping the current image on screen."""
            interface.current_thread_id = str(time.time())
            return [], interface.display_file_path

        def handle_file_upload(file):
            """Forward an UploadButton payload to ChatInterface.handle_upload."""
            # Gradio passes either a path string or an object exposing the path as
            # ``.name``. None falls through to handle_upload's "no file" branch.
            return interface.handle_upload(getattr(file, "name", file))

        def generate_report():
            """Summarize the current thread into the disease table and conclusion box."""
            # NOTE(review): current_thread_id is None until the first chat turn or
            # "New Thread". Confirm that summarize_message accepts that.
            result = interface.agent.summarize_message(interface.current_thread_id)
            table = [[d["name"], d["info"]] for d in result["Disease"]]
            return (
                gr.update(value=table, interactive=True, visible=True),
                gr.update(value=result["Conclusion"], lines=4, interactive=True, visible=True),
                gr.update(visible=True),
                gr.update(visible=True),
            )

        def _pdf_safe(text) -> str:
            """Coerce text to what FPDF's Latin-1 core fonts can encode ('?' for the rest)."""
            return str(text).encode("latin-1", "replace").decode("latin-1")

        def records_to_pdf(table, conclusion) -> Path:
            """
            Write a PDF report under ./reports/ and return its Path.

            Args:
                table: Rows with "Disease" and "Info" columns; iterated with
                    ``iterrows()`` (the Dataframe component's value).
                conclusion: Free-text conclusion printed at the end.
            """
            pdf = FPDF()
            pdf.set_auto_page_break(auto=True, margin=15)
            pdf.add_page()
            pdf.set_font(family="Helvetica", size=12)
            pdf.cell(0, 10, "Chest-X-ray Report", ln=1, align="C")
            pdf.ln(4)

            pdf.set_font(family="Helvetica", style="B")
            pdf.cell(60, 8, "Disease")
            pdf.cell(0, 8, "Information", ln=1)
            pdf.set_font(family="Helvetica", style="")
            for _, row in table.iterrows():
                # fpdf2's multi_cell leaves the cursor at the right margin by
                # default. Rewind it so the next full-width (w=0) cell has room.
                pdf.set_x(pdf.l_margin)
                pdf.multi_cell(0, 8, _pdf_safe(f"{row['Disease']}: {row['Info']}"))

            pdf.ln(4)
            pdf.set_font(family="Helvetica", style="B")
            pdf.cell(0, 8, "Conclusion", ln=1)
            pdf.set_font(family="Helvetica", style="")
            pdf.set_x(pdf.l_margin)
            pdf.multi_cell(0, 8, _pdf_safe(conclusion))

            pdf_path = REPORT_DIR / f"report_{uuid.uuid4().hex}.pdf"
            pdf.output(str(pdf_path))
            return pdf_path

        def build_pdf_and_preview(table, conclusion):
            """Render the approved report and reveal the download button and preview."""
            pdf_path = records_to_pdf(table, conclusion)
            return (
                gr.update(value=pdf_path, visible=True),  # for DownloadButton
                gr.update(value=str(pdf_path), visible=True),  # for PDF preview
            )

        def show_reject_ui():
            return gr.update(visible=True, value=""), gr.update(visible=True), gr.update(visible=True)

        def hide_reject_ui():
            return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(visible=False)

        def _then_respond(add_message_event):
            """After the user turn is added, stream the agent reply, then unlock the textbox."""
            add_message_event.then(
                interface.process_message,
                inputs=[txt, image_display, chatbot],
                outputs=[chatbot, image_display, txt],
            ).then(lambda: gr.Textbox(interactive=True), None, [txt])

        def _prefill(prompt):
            """Return a no-input callback that places prompt in the textbox."""
            return lambda: gr.update(value=prompt)

        # ---- Wiring ---------------------------------------------------------
        login_button.click(
            authenticate, [username, password], [login_page, main_page, login_error, auth_state]
        )

        # Free-text turn: add message -> agent reply -> unlock textbox.
        _then_respond(
            txt.submit(
                interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
            )
        )

        # Quick-action buttons: prefill a canned prompt, then run the same turn.
        for button, prompt in (
            (
                analyze_btn,
                "Analyze the above image and identify the probabilites of occurrence of various diseases along with proper reason",
            ),
            (ground_btn, "Ground the main disease in this CXR"),
            (segment_btn, "Segment the major affected lung"),
        ):
            _then_respond(
                button.click(_prefill(prompt), None, txt).then(
                    interface.add_message,
                    inputs=[txt, image_display, chatbot],
                    outputs=[chatbot, txt],
                )
            )

        upload_outputs = [image_display, analyze_btn, ground_btn, segment_btn]
        upload_button.upload(handle_file_upload, inputs=upload_button, outputs=upload_outputs)
        dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=upload_outputs)
        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(new_thread, outputs=[chatbot, image_display])

        generate_report_btn.click(
            generate_report, outputs=[diseases_df, conclusion_tb, approve_btn, reject_btn]
        )
        approve_btn.click(
            build_pdf_and_preview,
            inputs=[diseases_df, conclusion_tb],
            outputs=[download_pdf_btn, pdf_preview],
        )
        reject_btn.click(
            show_reject_ui, outputs=[rejection_text, submit_reject_btn, cancel_reject_btn]
        )
        cancel_reject_btn.click(
            hide_reject_ui, outputs=[rejection_text, submit_reject_btn, cancel_reject_btn]
        )

    return demo