#!/usr/bin/env python3
"""
LangChain Medical Agents Architecture - Refactored
A multi-agent system for processing medical transcriptions and documents.
"""
import os
import re
from datetime import datetime
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
# Import modular components
from models import TemplateAnalysis, MedicalTranscription, SectionContent, InsertSectionsInput
from sftp_agent import create_sftp_downloader_agent, download_model_from_sftp
from template_analyzer import create_template_analyzer_agent, analyze_word_template
from transcription_processor import (
    create_transcription_corrector_chain,
    create_medical_analyzer_chain,
    create_title_generator_chain,
    load_transcription_with_user_id,
)
from section_generator import create_dynamic_section_prompt, fix_section_names
from document_assembler import create_document_assembler_agent
from document_validator import validate_generated_document, create_validation_chain
# Load environment variables
load_dotenv()
# Initialize LLM with Azure OpenAI
llm = AzureChatOpenAI(
azure_deployment="gtp-4o-eastus2",
openai_api_version="2024-02-15-preview",
azure_endpoint="https://voxist-gpt-eastus2.openai.azure.com/",
api_key="98db8190a2ff438b904c7e9862a13210",
temperature=0.1
)
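# A minimal .env consumed by load_dotenv() above would then provide the key
# (sketch; the variable name is an editor assumption, not confirmed elsewhere
# in the repo):
#
#   AZURE_OPENAI_API_KEY=<your-azure-openai-key>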


class MedicalDocumentOrchestrator:
    """Main orchestrator that coordinates all agents."""

    def __init__(self, template_path: str = None, transcription_path: str = None,
                 transcriptions_dir: str = "transcriptions"):
        self.template_path = template_path
        self.transcription_path = transcription_path
        self.transcriptions_dir = transcriptions_dir
        self.template_analysis = None
        self.corrected_transcription = None
        self.medical_data = None
        self.generated_sections = None
        self.generated_title = None
        self.downloaded_models = None

    def run_full_pipeline(self, output_path: str = None) -> str:
        """Run the complete medical document processing pipeline."""
        print("🚀 Starting LangChain Medical Document Pipeline...")

        # Step 0: Download only the model corresponding to the transcription
        print("\n📥 Step 0: Downloading model from SFTP for the selected transcription...")
        try:
            transcription_filename = os.path.basename(self.transcription_path)
            match = re.search(r'transcriptions_(.+)\.rtf_', transcription_filename)
            if match:
                model_id = match.group(1)
                model_filename = f"{model_id}.rtf"
                local_filename = f"{model_id}.doc"
                local_template_path = os.path.join("models", local_filename)
                print(f"🔎 Model identifier for this transcription: {model_id}")
                # Download only the required model via a simple agent
                simple_sftp_agent = create_openai_tools_agent(
                    llm=llm,
                    tools=[download_model_from_sftp],
                    prompt=ChatPromptTemplate.from_messages([
                        ("system", "You are an SFTP downloader. Download the specified model file."),
                        ("human", "Download the model file: {model_filename}"),
                        MessagesPlaceholder("agent_scratchpad"),
                    ])
                )
                simple_sftp_executor = AgentExecutor(
                    agent=simple_sftp_agent,
                    tools=[download_model_from_sftp],
                    verbose=True
                )
                result = simple_sftp_executor.invoke({
                    "model_filename": model_filename
                })
                print(f"✅ Model downloaded and available as: {local_template_path}")
                self.template_path = local_template_path
                self.downloaded_models = [{
                    'model_id': model_id,
                    'model_filename': model_filename,
                    'local_filename': local_filename,
                    'local_path': local_template_path,
                    'status': 'success',
                }]
            else:
                raise ValueError(
                    "Unable to extract the model identifier from the transcription filename.")
        except Exception as e:
            print(f"❌ Error during SFTP download step: {str(e)}")
            if self.template_path:
                print("⚠️ Continuing with pipeline using the provided template_path...")
            else:
                print("❌ No template path provided and SFTP download failed. Cannot continue.")
                raise Exception(
                    "Cannot continue without a template. SFTP download failed and no template path was provided.")
            # No model was downloaded this run; record an empty list
            self.downloaded_models = []

        # Step 1: Analyze template
        print("\n📋 Step 1: Analyzing template...")
        if not self.template_path:
            raise ValueError("No template path available for analysis")
        self.template_analysis = analyze_word_template(self.template_path)
        print(f"✅ Template analyzed: {len(self.template_analysis.get('sections', []))} sections found")

        # Step 2: Load and correct transcription
        print("\n✏️ Step 2: Correcting transcription...")
        raw_transcription, user_id = load_transcription_with_user_id(self.transcription_path)
        transcription_corrector_chain = create_transcription_corrector_chain(llm)
        self.corrected_transcription = transcription_corrector_chain.invoke({
            "transcription": raw_transcription
        }).content
        # Print the corrected transcription for inspection
        print("\n===== Transcription after correction =====")
        print(self.corrected_transcription)
        print("✅ Transcription corrected")

        # Step 3: Analyze medical data
        print("\n🔬 Step 3: Analyzing medical data...")
        medical_analyzer_chain = create_medical_analyzer_chain(llm)
        self.medical_data = medical_analyzer_chain.invoke({
            "corrected_transcription": self.corrected_transcription
        }).content
        print("✅ Medical data analyzed")

        # Step 4: Generate title
        print("\n📝 Step 4: Generating title...")
        title_generator_chain = create_title_generator_chain(llm)
        self.generated_title = title_generator_chain.invoke({
            "medical_data": self.medical_data
        }).content
        print(f"✅ Title generated: {self.generated_title}")

        # Step 5: Generate sections
        print("\n📝 Step 5: Generating sections...")
        # Extract sections from template analysis
        template_sections = []

        # Debug: see exactly what template_analysis contains
        print("\n--- DEBUG: Type and content of template_analysis ---")
        print(f"Type: {type(self.template_analysis)}")
        print(f"Content: {self.template_analysis}")
        if hasattr(self.template_analysis, '__dict__'):
            print(f"Attributes: {self.template_analysis.__dict__}")
        print("--- END DEBUG ---\n")

        # Always retrieve the sections list if possible
        try:
            if isinstance(self.template_analysis, dict) and 'sections' in self.template_analysis:
                template_sections = [section['text']
                                     for section in self.template_analysis['sections']]
            # dict-like objects (anything exposing .get and membership tests)
            elif hasattr(self.template_analysis, 'get') and 'sections' in self.template_analysis:
                template_sections = [section['text']
                                     for section in self.template_analysis['sections']]
            elif hasattr(self.template_analysis, 'output') and isinstance(self.template_analysis.output, dict) and 'sections' in self.template_analysis.output:
                template_sections = [section['text']
                                     for section in self.template_analysis.output['sections']]
        except Exception as e:
            print('Error extracting sections:', e)

        # Fallback: try to extract from the agent response text
        # (\xa0 is the non-breaking space French typography places before ":")
        if not template_sections:
            response_text = str(self.template_analysis)
            if 'Technique' in response_text and 'Résultat' in response_text and 'Conclusion' in response_text:
                template_sections = ['Technique\xa0:', 'Résultat\xa0:', 'Conclusion\xa0:']
            elif 'CONCLUSION' in response_text:
                template_sections = ['CONCLUSION\xa0:']

        # Create dynamic prompt based on template sections
        dynamic_section_prompt = create_dynamic_section_prompt(template_sections)
        section_generator_chain = dynamic_section_prompt | llm
        generated_content = section_generator_chain.invoke({
            "template_sections": template_sections,
            "medical_data": self.medical_data,
            "corrected_transcription": self.corrected_transcription
        }).content

        # Post-process to ensure exact section names are used
        self.generated_sections = fix_section_names(generated_content, template_sections)
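        # Illustrative example (hypothetical LLM output): if the model emitted a
        # heading "Resultat:" while the template expects "Résultat\xa0:",
        # fix_section_names() is expected to normalize the heading back to the
        # exact template spelling.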
print("\n--- DEBUG: Generated sections ---")
print(self.generated_sections)
print("--- END DEBUG ---\n")
print("\n--- DEBUG: Template sections ---")
print(template_sections)
print("--- END DEBUG ---\n")
print("\n--- DEBUG: Generated title ---")
print(self.generated_title)
print("--- END DEBUG ---\n")
        # Step 6: Assemble document
        print("\n📄 Step 6: Assembling document...")
        if output_path is None:
            # Generate output filename based on user_id:
            # replace the last extension with .docx
            if '.' in user_id:
                parts = user_id.split('.')
                parts[-1] = 'docx'
                output_filename = '.'.join(parts)
            else:
                # If no extension, just add .docx
                output_filename = f"{user_id}.docx"
            output_path = output_filename
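            # e.g. (hypothetical user_id values) "patient_42.rtf" becomes
            # "patient_42.docx", and an extensionless "patient_42" becomes
            # "patient_42.docx"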

        # Use the agent for assembly
        document_assembler_executor = create_document_assembler_agent(llm)
        result = document_assembler_executor.invoke({
            "template_path": self.template_path,
            "sections_text": self.generated_sections,
            "title": self.generated_title,
            "output_path": output_path
        })
        print(f"🎉 Pipeline completed! Document saved: {output_path}")

        # Step 7: Validate document
        print("\n📋 Step 7: Validating document...")
        validation_result = validate_generated_document(
            self.template_path, self.transcription_path, output_path)

        # Display validation results
        print("\n" + "=" * 60)
        print("📊 VALIDATION RESULTS")
        print("=" * 60)

        # Overall score
        score = validation_result["overall_score"]
        score_emoji = "🟢" if score >= 0.8 else "🟡" if score >= 0.6 else "🔴"
        print(f"{score_emoji} Overall Score: {score:.1%}")

        # Structure validation
        structure_valid = validation_result["structure_valid"]
        structure_emoji = "✅" if structure_valid else "❌"
        print(f"{structure_emoji} Structure Valid: {structure_valid}")
        if not structure_valid:
            missing = validation_result["missing_sections"]
            print(f"   Missing sections: {', '.join(missing)}")

        # Entities validation
        entities_coverage = validation_result["entities_coverage"]
        entities_emoji = "✅" if entities_coverage >= 80 else "⚠️"
        print(f"{entities_emoji} Medical Entities Coverage: {entities_coverage:.1f}%")
        if entities_coverage < 80:
            missing_entities = validation_result["missing_entities"][:5]
            print(f"   Missing entities: {', '.join(missing_entities)}")

        # Generate AI validation report
        print("\n📝 AI Validation Report:")
        print("-" * 40)

        # Extract content for AI validation, skipping date/time header lines
        # ("Heure" is French for "time")
        from docx import Document
        doc = Document(output_path)
        generated_content = []
        for paragraph in doc.paragraphs:
            text = paragraph.text.strip()
            if text and not text.startswith("Date:") and not text.startswith("Heure:"):
                generated_content.append(text)
        generated_text = "\n".join(generated_content)

        validation_chain = create_validation_chain(llm)
        ai_validation = validation_chain.invoke({
            "transcription": self.corrected_transcription,
            "generated_content": generated_text,
            "structure_valid": structure_valid,
            "entities_coverage": entities_coverage,
            "missing_sections": validation_result["missing_sections"],
            "missing_entities": validation_result["missing_entities"]
        })
        print(ai_validation.content)
        print("\n" + "=" * 60)
        print("✅ Document validated")

        # Remove the local model after validation
        try:
            if self.template_path and os.path.exists(self.template_path):
                os.remove(self.template_path)
                print(f"🗑️ Deleted local model file: {self.template_path}")
        except Exception as e:
            print(f"⚠️ Could not delete local model file: {e}")

        return output_path


def main():
    """Main function to run the LangChain medical document pipeline."""
    print("🏥 LangChain Medical Document Agents - Refactored")
    print("=" * 60)

    # Initialize orchestrator with the default template and transcription paths
    orchestrator = MedicalDocumentOrchestrator(
        template_path="default.528.251014072.doc",
        transcription_path="transciption.txt"
    )

    # Run the complete pipeline
    output_file = orchestrator.run_full_pipeline()
    print(f"\n✅ Final document: {output_file}")
    print("🎉 LangChain pipeline completed successfully!")


if __name__ == "__main__":
    main()