#!/usr/bin/env python3
"""
Enhanced Medical Transcription Processor with ML-based correction
Uses pretrained models and embeddings for dynamic medical term correction
"""

import json
import logging
import os
import pickle
import re
from difflib import SequenceMatcher
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import spacy
import torch
from langchain.prompts import ChatPromptTemplate
from langchain.tools import tool
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from spacy.matcher import Matcher
from transformers import (
    AutoModelForMaskedLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MedicalTermCorrector:
    """ML-based medical term corrector using pretrained models and embeddings.

    On construction the instance loads a sentence-embedding model, a spaCy
    pipeline, a clinical NER pipeline and a masked language model, then loads
    (or builds and caches) embeddings for a fixed list of medical terms.
    """

    def __init__(self, cache_dir: str = "./model_cache"):
        """Create the corrector and eagerly load all models.

        Args:
            cache_dir: Directory used to cache the pickled term list and
                embeddings. It is created if it does not exist.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        # Initialize models
        self.sentence_transformer = None
        self.nlp = None
        self.medical_ner = None
        self.masked_lm_model = None
        self.masked_lm_tokenizer = None
        self.medical_embeddings = None
        self.medical_terms = None

        self._load_models()
        self._load_medical_knowledge()

    def _load_models(self):
        """Load pretrained models for medical text processing.

        Raises:
            Exception: Any loading error is logged and re-raised, because the
                corrector cannot work without its models.
        """
        try:
            # Load sentence transformer for semantic similarity
            logger.info("Loading sentence transformer model...")
            self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')

            # Load spaCy model for NER and linguistic processing
            logger.info("Loading spaCy model...")
            try:
                self.nlp = spacy.load("fr_core_news_sm")
            except OSError:
                logger.warning(
                    "French spaCy model not found. "
                    "Install with: python -m spacy download fr_core_news_sm"
                )
                self.nlp = spacy.load("en_core_web_sm")  # Fallback

            # Load medical NER model
            logger.info("Loading medical NER model...")
            self.medical_ner = pipeline(
                "ner",
                model="samrawal/bert-base-uncased_clinical-ner",
                tokenizer="samrawal/bert-base-uncased_clinical-ner",
                aggregation_strategy="simple",
            )

            # Load masked language model for context-aware corrections
            logger.info("Loading masked language model...")
            self.masked_lm_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
            self.masked_lm_model = AutoModelForMaskedLM.from_pretrained("camembert-base")
        except Exception as e:
            logger.error("Error loading models: %s", e)
            raise

    def _load_medical_knowledge(self):
        """Load the medical term list and embeddings, rebuilding if needed.

        The cache is used only when both pickle files exist and can be read.
        An unreadable cache (for example a file truncated by an interrupted
        write) triggers a rebuild instead of failing every start-up.
        """
        # Check if cached embeddings exist
        embeddings_path = os.path.join(self.cache_dir, "medical_embeddings.pkl")
        terms_path = os.path.join(self.cache_dir, "medical_terms.pkl")

        if os.path.exists(embeddings_path) and os.path.exists(terms_path):
            logger.info("Loading cached medical knowledge...")
            try:
                # SECURITY NOTE: pickle.load executes arbitrary code from the
                # file. This is only acceptable because cache_dir is written
                # by this process; never point cache_dir at untrusted data.
                with open(embeddings_path, 'rb') as f:
                    embeddings = pickle.load(f)
                with open(terms_path, 'rb') as f:
                    terms = pickle.load(f)
            except (OSError, EOFError, pickle.UnpicklingError,
                    AttributeError, ImportError, IndexError) as e:
                logger.warning(
                    "Cached medical knowledge is unreadable (%s); rebuilding", e
                )
            else:
                self.medical_embeddings = embeddings
                self.medical_terms = terms
                return

        logger.info("Building medical knowledge base...")
        self._build_medical_knowledge()

    def _build_medical_knowledge(self):
        """Embed the built-in medical term list and write it to the cache."""
        # Medical terms from various domains
        medical_terms = [
            # Radiology
            "mammographie", "échographie", "IRM", "TDM", "radiographie",
            "scintigraphie", "angiographie", "arthrographie",
            # Anatomy
            "utérus", "ovaires", "myomètre", "endomètre", "cervix",
            "ganglions", "axillaires", "mammaire", "pelvien", "thyroïde",
            "pancréas", "foie", "rate", "reins",
            # Pathology
            "adénomyose", "endométriose", "fibrome", "kyste", "carcinome",
            "adénome", "métastase", "tumeur", "inflammation", "nécrose",
            "hémorragie", "œdème",
            # Classifications
            "BI-RADS", "ACR", "TNM", "WHO", "BIRADS",
            # Procedures
            "biopsie", "dépistage", "surveillance", "contrôle", "ponction",
            "drainage", "résection", "ablation",
            # Measurements
            "millimètre", "centimètre", "millilitre", "gramme", "pourcentage",
            "degré", "unité", "concentration",
        ]

        # Generate embeddings for medical terms
        logger.info("Generating embeddings for medical terms...")
        self.medical_terms = medical_terms
        self.medical_embeddings = self.sentence_transformer.encode(medical_terms)

        # Cache the embeddings
        embeddings_path = os.path.join(self.cache_dir, "medical_embeddings.pkl")
        terms_path = os.path.join(self.cache_dir, "medical_terms.pkl")
        with open(embeddings_path, 'wb') as f:
            pickle.dump(self.medical_embeddings, f)
        with open(terms_path, 'wb') as f:
            pickle.dump(self.medical_terms, f)

    def find_similar_medical_term(self, word: str, threshold: float = 0.7) -> Optional[str]:
        """Find the most similar medical term using embeddings.

        Args:
            word: Candidate word to compare against the known medical terms.
            threshold: Minimum cosine similarity (exclusive) for a match.

        Returns:
            The best-matching known term, or ``None`` when nothing exceeds
            the threshold, no knowledge base is loaded, or the lookup fails.
        """
        # Explicit emptiness check: the previous `.any()` test raised
        # AttributeError on None and tested "any non-zero value" rather than
        # "is there any data".
        if (
            self.medical_embeddings is None
            or not self.medical_terms
            or len(self.medical_embeddings) == 0
        ):
            return None

        try:
            word_embedding = self.sentence_transformer.encode([word])
            similarities = cosine_similarity(word_embedding, self.medical_embeddings)[0]
            best_idx = int(np.argmax(similarities))
            if similarities[best_idx] > threshold:
                return self.medical_terms[best_idx]
        except Exception as e:
            # Best effort: a failed lookup means "no suggestion", not a crash.
            logger.warning("Error finding similar term for '%s': %s", word, e)

        return None

    def correct_with_context(self, sentence: str, target_word: str) -> str:
        """Use masked language model for context-aware correction.

        The first whole-word occurrence of ``target_word`` is replaced by the
        tokenizer's mask token, and the model's top five predictions for that
        position are examined in order.

        Args:
            sentence: Text that contains the word to correct.
            target_word: The word to mask and re-predict.

        Returns:
            A replacement for ``target_word``, or ``target_word`` itself when
            no suitable candidate is found or an error occurs.
        """
        try:
            mask_token = self.masked_lm_tokenizer.mask_token

            # Mask exactly one whole-word occurrence. Masking every substring
            # hit (the old str.replace) corrupted unrelated words and produced
            # several mask tokens of which only the first was ever read.
            pattern = r'\b' + re.escape(target_word) + r'\b'
            masked_sentence, n_masked = re.subn(
                pattern, lambda _m: mask_token, sentence, count=1
            )
            if n_masked == 0:
                # Target has no word-boundary match (e.g. it starts or ends
                # with punctuation): fall back to a plain first-occurrence
                # replacement.
                masked_sentence = sentence.replace(target_word, mask_token, 1)

            # Tokenize and get predictions. Truncate so that long
            # transcriptions do not exceed the model's maximum input length.
            inputs = self.masked_lm_tokenizer(
                masked_sentence, return_tensors="pt", truncation=True
            )
            with torch.no_grad():
                outputs = self.masked_lm_model(**inputs)
            predictions = outputs.logits

            # Find mask token position
            mask_positions = torch.where(
                inputs["input_ids"] == self.masked_lm_tokenizer.mask_token_id
            )[1]

            if len(mask_positions) > 0:
                mask_logits = predictions[0, mask_positions[0], :]
                top_tokens = torch.topk(mask_logits, 5).indices.tolist()

                # Get top predictions
                candidates = [
                    self.masked_lm_tokenizer.decode([token]) for token in top_tokens
                ]

                # Filter for medical relevance
                for candidate in candidates:
                    candidate = candidate.strip()
                    if len(candidate) > 2 and candidate.isalpha():
                        # Check if candidate is medically relevant
                        similar_term = self.find_similar_medical_term(
                            candidate, threshold=0.6
                        )
                        if similar_term:
                            return similar_term

                        # Check string similarity with original
                        ratio = SequenceMatcher(
                            None, target_word.lower(), candidate.lower()
                        ).ratio()
                        if ratio > 0.6:
                            return candidate
        except Exception as e:
            # Best effort: keep the original word when the model path fails.
            logger.warning("Error in context correction for '%s': %s", target_word, e)

        return target_word

    def extract_medical_entities(self, text: str) -> List[Dict]:
        """Extract medical entities using NER.

        Returns:
            The NER pipeline's output for ``text``, or an empty list when the
            pipeline raises.
        """
        try:
            return self.medical_ner(text)
        except Exception as e:
            logger.warning("Error in medical NER: %s", e)
            return []

    def correct_medical_text(self, text: str) -> str:
        """Main method to correct medical text using ML models.

        Every alphabetic, non-stop-word token longer than three characters is
        compared with the known medical terms. When a different known term is
        close enough, the token is replaced throughout the text
        (case-insensitively, on word boundaries).

        Args:
            text: Raw transcription text.

        Returns:
            The text with corrections applied. ``text`` is returned unchanged
            when it is empty or no spaCy pipeline is loaded.
        """
        # The NER pass that used to run here was dropped: its result was
        # never used, so it only cost a full model inference per call.
        if not text or not self.nlp:
            return text

        corrected_text = text
        already_seen = set()

        # Find words that might be medical terms but are misspelled
        for token in self.nlp(text):
            word = token.text
            if not (token.is_alpha and len(word) > 3
                    and not token.is_stop and not token.like_url):
                continue

            # A replacement covers every occurrence, so an identical token
            # needs no second round of model calls.
            if word in already_seen:
                continue
            already_seen.add(word)

            # Check if it's a potential medical term
            similar_term = self.find_similar_medical_term(word)
            if not similar_term or similar_term == word:
                continue

            # Use context-aware correction
            context_correction = self.correct_with_context(text, word)

            # Choose the best correction
            final_correction = (
                context_correction if context_correction != word else similar_term
            )

            # Apply correction with word boundaries. A callable replacement
            # keeps the correction literal (no backslash/group expansion).
            pattern = r'\b' + re.escape(word) + r'\b'
            corrected_text = re.sub(
                pattern,
                lambda _m, repl=final_correction: repl,
                corrected_text,
                flags=re.IGNORECASE,
            )
            logger.info("Corrected: '%s' -> '%s'", word, final_correction)

        return corrected_text


class DateTimeNormalizer:
    """Normalize dates and times in medical texts using regex patterns."""

    # French month name -> two-digit month number. Kept at class level so it
    # is built once and also drives the month alternation in the date regex.
    _FRENCH_MONTHS = {
        'janvier': '01', 'février': '02', 'mars': '03', 'avril': '04',
        'mai': '05', 'juin': '06', 'juillet': '07', 'août': '08',
        'septembre': '09', 'octobre': '10', 'novembre': '11', 'décembre': '12',
    }

    def __init__(self):
        """Build the (pattern, replacement) tables for dates and times.

        Patterns are anchored on word boundaries so that digits belonging to
        a longer number (e.g. "112 03 2024") are not rewritten as a date.
        """
        month_names = '|'.join(self._FRENCH_MONTHS)
        self.date_patterns = [
            # French date patterns
            (r'\b(\d{1,2})\s+(\d{1,2})\s+(\d{4})\b', r'\1/\2/\3'),
            (r'\b(\d{1,2})\s+(\d{1,2})\s+(\d{2})\s+(\d{2})\b', r'\1/\2/\3\4'),
            (r'\b(\d{1,2})\s+(' + month_names + r')\s+(\d{4})\b',
             self._convert_french_date),
        ]
        self.time_patterns = [
            (r'\b(\d{1,2})\s+heures?\s+(\d{1,2})\b', r'\1:\2'),
            (r'\b(\d{1,2})\s+h\s+(\d{1,2})\b', r'\1:\2'),
            # The lookbehind stops "après-midi 3" from becoming "après-12:3".
            (r'(?<![\w-])midi\s+(\d{1,2})\b', r'12:\1'),
            (r'(?<![\w-])minuit\s+(\d{1,2})\b', r'00:\1'),
        ]

    def _convert_french_date(self, match):
        """Convert a "<day> <French month> <year>" match to "day/MM/year".

        An unrecognised month name is left as written.
        """
        day, month, year = match.groups()
        month_number = self._FRENCH_MONTHS.get(month.lower(), month)
        return f"{day}/{month_number}/{year}"

    def normalize_dates_times(self, text: str) -> str:
        """Normalize all dates and times in the text.

        Date patterns are applied first, then time patterns, each in the
        order they are listed.
        """
        result = text
        # re.sub accepts both template strings and callables, so one loop
        # covers every pattern. IGNORECASE only matters for the month and
        # "heures"/"midi"/"minuit" words.
        for pattern, replacement in self.date_patterns + self.time_patterns:
            result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)
        return result


# Initialize global corrector instance
medical_corrector = None
datetime_normalizer = DateTimeNormalizer()


def initialize_medical_corrector(cache_dir: str = "./model_cache"):
    """Initialize the medical corrector (call once at startup).

    Does nothing when the module-level corrector already exists.
    """
    global medical_corrector
    if medical_corrector is None:
        medical_corrector = MedicalTermCorrector(cache_dir)


# NOTE: the docstring below doubles as the tool description that the @tool
# decorator exposes at runtime, so it is kept verbatim.
@tool
def load_transcription(transcription_path: str) -> str:
    """Load and return the raw transcription text from a file."""
    if not os.path.exists(transcription_path):
        raise FileNotFoundError(f"Transcription file not found: {transcription_path}")

    with open(transcription_path, 'r', encoding='utf-8') as f:
        return f.read().strip()


def load_transcription_with_user_id(transcription_path: str) -> Tuple[str, str]:
    """Load transcription text and user_id from a JSON file.

    Args:
        transcription_path: Path to a JSON file holding an object with
            optional ``transcription`` and ``user_id`` keys.

    Returns:
        ``(transcription_text, user_id)``. Missing keys default to ``''`` and
        ``'unknown'`` respectively.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if not os.path.exists(transcription_path):
        raise FileNotFoundError(f"Transcription file not found: {transcription_path}")

    with open(transcription_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    transcription_text = data.get('transcription', '')
    user_id = data.get('user_id', 'unknown')
    return transcription_text, user_id


def preprocess_medical_text(text: str) -> str:
    """Preprocess medical text using ML models.

    Applies term correction, then date/time normalization, then measurement
    formatting. The shared corrector is created on first use.
    """
    if medical_corrector is None:
        initialize_medical_corrector()

    # Apply ML-based medical term correction
    corrected_text = medical_corrector.correct_medical_text(text)

    # Normalize dates and times
    corrected_text = datetime_normalizer.normalize_dates_times(corrected_text)

    # Fix common formatting issues
    corrected_text = re.sub(r'(\d+)\s*x\s*(\d+)\s*mm', r'\1 x \2 mm', corrected_text)
    corrected_text = re.sub(
        r'(\d+)\s*mm\s*sur\s*(\d+)\s*mm', r'\1 mm x \2 mm', corrected_text
    )

    return corrected_text


# NOTE(review): the prompt texts below reproduce the wording of the reviewed
# source verbatim, including its single-space layout. They are split across
# adjacent string literals purely for source readability.
_CORRECTOR_SYSTEM_PROMPT = (
    "You are an expert medical transcriptionist with deep knowledge of French "
    "medical terminology. The text you receive has already been preprocessed "
    "with ML models for medical term correction. "
    "Your task is to refine the document structure and ensure professional "
    "medical report formatting: "
    "FORMATTING AND STRUCTURE: "
    "- Organize content into clear medical report sections: "
    "* Title/Header "
    "* Clinical indication "
    "* Technique/Method "
    "* Results/Findings (use bullet points for lists) "
    "* Conclusion "
    "- Ensure proper spacing and line breaks "
    '- Replace "la ligne" and "à la ligne" with appropriate line breaks '
    '- Replace "point" with periods followed by line breaks when contextually '
    "appropriate "
    '- Format measurements consistently (e.g., "72 x 40 mm") '
    "- Ensure proper capitalization of medical terms and proper nouns "
    "QUALITY ASSURANCE: "
    "- Verify that medical terminology is accurate and consistent "
    "- Ensure dates are properly formatted (DD/MM/YYYY) "
    "- Check that medical classifications are correct (BI-RADS, ACR, etc.) "
    "- Maintain professional medical language throughout "
    "- Ensure logical flow and coherence "
    "PRESERVATION RULES: "
    "- Maintain all original medical content and findings "
    "- Do not add clinical interpretations not present in the original "
    "- Preserve all measurements, dates, and technical details "
    "- Keep the original meaning and medical context intact "
    "Return the refined text as a properly formatted medical report without "
    "explanations."
)

_CORRECTOR_HUMAN_PROMPT = (
    "Refine and format the following preprocessed medical transcription: "
    "{transcription} "
    "Focus on professional medical report structure and formatting while "
    "preserving all original medical content."
)

_ANALYZER_SYSTEM_PROMPT = (
    "You are a medical information extractor with expertise in French medical "
    "terminology. Extract and categorize ONLY the medical information that is "
    "explicitly mentioned in the transcription. The text has been preprocessed "
    "with ML models for better medical term accuracy. "
    "EXTRACTION CATEGORIES: "
    "1. Procedure/Examination type "
    "2. Clinical indication "
    "3. Technique/Method used "
    "4. Anatomical structures examined "
    "5. Measurements and dimensions "
    "6. Pathological findings "
    "7. Normal findings "
    "8. Medical classifications (BI-RADS, ACR, etc.) "
    "9. Recommendations/Follow-up "
    "10. Conclusion stated in the report "
    "EXTRACTION RULES: "
    "- Focus on clinical findings, measurements, and observations "
    "- Extract exact measurements with units "
    "- Identify medical procedures and techniques "
    "- Note anatomical structures and their conditions "
    "- Include any pathological or normal findings "
    "- Preserve medical classifications and scores "
    "- DO NOT add interpretations beyond what's stated "
    "- DO NOT make clinical assumptions "
    "Organize the extracted information clearly under each category."
)

_ANALYZER_HUMAN_PROMPT = (
    "Extract and organize the medical information from this transcription: "
    "{corrected_transcription} "
    "Provide a structured medical analysis with clear categorization."
)

_TITLE_SYSTEM_PROMPT = (
    "You are a medical report title generator with expertise in French medical "
    "terminology. Generate a professional medical report title in FRENCH based "
    "on the medical data and findings. "
    "TITLE GUIDELINES: "
    "- Be specific to the examination type (IRM, mammographie, échographie, "
    "TDM, etc.) "
    "- Include the anatomical region examined "
    "- Use standard French medical terminology "
    "- Keep titles concise but informative (5-10 words) "
    "- Follow French medical report conventions "
    "- Consider the primary purpose (dépistage, surveillance, diagnostic, etc.) "
    "EXAMPLES: "
    '- "Mammographie de dépistage bilatérale" '
    '- "IRM pelvienne - Exploration utérine" '
    '- "Échographie abdominale - Surveillance hépatique" '
    '- "TDM thoracique avec injection de contraste" '
    '- "Radiographie pulmonaire - Contrôle post-opératoire" '
    "Return ONLY the title in French."
)

_TITLE_HUMAN_PROMPT = (
    "Generate a professional medical report title in French for: "
    "{medical_data} "
    "Create a concise, specific title that reflects the examination type and "
    "focus."
)


def create_transcription_corrector_chain(llm):
    """Create the enhanced transcription corrector chain with ML preprocessing.

    Returns:
        A runnable (prompt piped into ``llm``) that expects the input key
        ``transcription``.
    """
    transcription_corrector_prompt = ChatPromptTemplate.from_messages([
        ("system", _CORRECTOR_SYSTEM_PROMPT),
        ("human", _CORRECTOR_HUMAN_PROMPT),
    ])
    return transcription_corrector_prompt | llm


def create_medical_analyzer_chain(llm):
    """Create the medical analyzer chain with ML enhancement.

    Returns:
        A runnable (prompt piped into ``llm``) that expects the input key
        ``corrected_transcription``.
    """
    medical_analyzer_prompt = ChatPromptTemplate.from_messages([
        ("system", _ANALYZER_SYSTEM_PROMPT),
        ("human", _ANALYZER_HUMAN_PROMPT),
    ])
    return medical_analyzer_prompt | llm


def create_title_generator_chain(llm):
    """Create the title generator chain with medical context.

    Returns:
        A runnable (prompt piped into ``llm``) that expects the input key
        ``medical_data``.
    """
    title_generator_prompt = ChatPromptTemplate.from_messages([
        ("system", _TITLE_SYSTEM_PROMPT),
        ("human", _TITLE_HUMAN_PROMPT),
    ])
    return title_generator_prompt | llm


def _message_text(result: Any) -> Any:
    """Return the text carried by a chain result.

    Results exposing a string ``content`` attribute (as chat-model message
    objects do) yield that string; anything else is returned unchanged.
    """
    content = getattr(result, "content", None)
    return content if isinstance(content, str) else result


def process_medical_transcription(transcription_text: str, llm) -> Dict[str, Any]:
    """Complete ML-enhanced processing pipeline for medical transcription.

    Steps: ML preprocessing, LLM formatting, LLM information extraction,
    LLM title generation, then NER over the formatted text.

    Args:
        transcription_text: Raw transcription.
        llm: Language model (or runnable) the three prompt chains pipe into.

    Returns:
        A dict with the original and preprocessed text, the three chain
        outputs exactly as returned by ``invoke``, the extracted entities and
        a ``processing_info`` summary.
    """
    # Initialize ML corrector if not already done
    if medical_corrector is None:
        initialize_medical_corrector()

    # Step 1: ML-based preprocessing
    logger.info("Applying ML-based medical term correction...")
    preprocessed_text = preprocess_medical_text(transcription_text)

    # Step 2: LLM-based structure and formatting correction
    logger.info("Applying LLM-based formatting and structure correction...")
    corrector_chain = create_transcription_corrector_chain(llm)
    corrected_text = corrector_chain.invoke({"transcription": preprocessed_text})
    # Downstream steps need plain text. Passing a message object along would
    # interpolate its repr into the next prompt and break the NER call.
    corrected_plain = _message_text(corrected_text)

    # Step 3: Medical content analysis
    logger.info("Extracting medical information...")
    analyzer_chain = create_medical_analyzer_chain(llm)
    medical_analysis = analyzer_chain.invoke(
        {"corrected_transcription": corrected_plain}
    )

    # Step 4: Generate appropriate title
    logger.info("Generating medical report title...")
    title_chain = create_title_generator_chain(llm)
    title = title_chain.invoke({"medical_data": _message_text(medical_analysis)})

    # Step 5: Extract entities for validation
    entities = (
        medical_corrector.extract_medical_entities(corrected_plain)
        if medical_corrector else []
    )

    # The chain outputs are returned as produced by invoke() so existing
    # callers keep receiving the same objects as before.
    return {
        "original_transcription": transcription_text,
        "preprocessed_transcription": preprocessed_text,
        "corrected_transcription": corrected_text,
        "medical_analysis": medical_analysis,
        "title": title,
        "extracted_entities": entities,
        "processing_info": {
            "ml_corrections_applied": True,
            "entities_extracted": len(entities),
            "preprocessing_successful": True,
        },
    }


def validate_medical_transcription(corrected_text: str, entities: List[Dict]) -> List[str]:
    """Validate the corrected medical transcription using ML insights.

    Args:
        corrected_text: The formatted transcription text.
        entities: NER results; each item is expected to carry an
            ``entity_group`` key.

    Returns:
        A list of human-readable warnings (empty when nothing is flagged).
    """
    issues = []

    # Check entity consistency. Labels are compared case-insensitively so a
    # model that emits lowercase group names is still recognised.
    # NOTE(review): confirm this label set against the configured NER model;
    # it may use different group names.
    relevant_groups = {'MEDICATION', 'DISEASE', 'TREATMENT'}
    medical_entities = [
        e for e in (entities or [])
        if str(e.get('entity_group', '')).upper() in relevant_groups
    ]
    if len(medical_entities) < 2:
        issues.append("Low medical entity density detected - review for missing medical terms")

    # Check for measurement formatting. Units must follow a digit: a bare
    # substring test for "mm" also fired on words like "comme" or "gramme".
    measurements = re.findall(r'\d+\s*[x×]\s*\d+\s*mm', corrected_text)
    has_units = re.search(r'\d\s*(?:mm|cm)\b', corrected_text) is not None
    if not measurements and has_units:
        issues.append("Measurement formatting may need attention")

    # Check for date consistency
    dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', corrected_text)
    if not dates and re.search(r'\d{4}', corrected_text):
        issues.append("Date formatting may need standardization")

    return issues