#!/usr/bin/env python3
"""
Enhanced Medical Transcription Processor with ML-based correction
Uses pretrained models and embeddings for dynamic medical term correction
"""

import os
import json
import re
import numpy as np
from typing import Dict, Any, Tuple, List, Optional
from langchain.tools import tool
from langchain.prompts import ChatPromptTemplate
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import spacy
from spacy.matcher import Matcher
import torch
from transformers import (
    AutoTokenizer, AutoModelForMaskedLM, 
    AutoModelForTokenClassification, pipeline
)
from difflib import SequenceMatcher
import pickle
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MedicalTermCorrector:
    """ML-based medical term corrector using pretrained models and embeddings"""
    
    def __init__(self, cache_dir: str = "./model_cache"):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        
        # Initialize models
        self.sentence_transformer = None
        self.nlp = None
        self.medical_ner = None
        self.masked_lm_model = None
        self.masked_lm_tokenizer = None
        self.medical_embeddings = None
        self.medical_terms = None
        
        self._load_models()
        self._load_medical_knowledge()
    
    def _load_models(self):
        """Load pretrained models for medical text processing"""
        try:
            # Load sentence transformer for semantic similarity
            logger.info("Loading sentence transformer model...")
            self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
            
            # Load spaCy model for NER and linguistic processing
            logger.info("Loading spaCy model...")
            try:
                self.nlp = spacy.load("fr_core_news_sm")
            except OSError:
                logger.warning("French spaCy model not found. Install with: python -m spacy download fr_core_news_sm")
                self.nlp = spacy.load("en_core_web_sm")  # Fallback
            
            # Load medical NER model
            logger.info("Loading medical NER model...")
            self.medical_ner = pipeline(
                "ner",
                model="samrawal/bert-base-uncased_clinical-ner",
                tokenizer="samrawal/bert-base-uncased_clinical-ner",
                aggregation_strategy="simple"
            )
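            # Note: this clinical NER model is English-oriented; recall on French reports may be limited.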
            
            # Load masked language model for context-aware corrections
            logger.info("Loading masked language model...")
            self.masked_lm_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
            self.masked_lm_model = AutoModelForMaskedLM.from_pretrained("camembert-base")
            
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            raise
    
    def _load_medical_knowledge(self):
        """Load medical terminology from external sources"""
        # Check if cached embeddings exist
        embeddings_path = os.path.join(self.cache_dir, "medical_embeddings.pkl")
        terms_path = os.path.join(self.cache_dir, "medical_terms.pkl")
        
        if os.path.exists(embeddings_path) and os.path.exists(terms_path):
            logger.info("Loading cached medical knowledge...")
            with open(embeddings_path, 'rb') as f:
                self.medical_embeddings = pickle.load(f)
            with open(terms_path, 'rb') as f:
                self.medical_terms = pickle.load(f)
        else:
            logger.info("Building medical knowledge base...")
            self._build_medical_knowledge()
    
    def _build_medical_knowledge(self):
        """Build medical knowledge base from various sources"""
        # Medical terms from various domains
        medical_terms = [
            # Radiology
            "mammographie", "échographie", "IRM", "TDM", "radiographie",
            "scintigraphie", "angiographie", "arthrographie",
            
            # Anatomy
            "utérus", "ovaires", "myomètre", "endomètre", "cervix",
            "ganglions", "axillaires", "mammaire", "pelvien",
            "thyroïde", "pancréas", "foie", "rate", "reins",
            
            # Pathology
            "adénomyose", "endométriose", "fibrome", "kyste",
            "carcinome", "adénome", "métastase", "tumeur",
            "inflammation", "nécrose", "hémorragie", "œdème",
            
            # Classifications
            "BI-RADS", "ACR", "TNM", "WHO", "BIRADS",
            
            # Procedures
            "biopsie", "dépistage", "surveillance", "contrôle",
            "ponction", "drainage", "résection", "ablation",
            
            # Measurements
            "millimètre", "centimètre", "millilitre", "gramme",
            "pourcentage", "degré", "unité", "concentration"
        ]
        
        # Generate embeddings for medical terms
        logger.info("Generating embeddings for medical terms...")
        self.medical_terms = medical_terms
        self.medical_embeddings = self.sentence_transformer.encode(medical_terms)
        
        # Cache the embeddings
        embeddings_path = os.path.join(self.cache_dir, "medical_embeddings.pkl")
        terms_path = os.path.join(self.cache_dir, "medical_terms.pkl")
        
        with open(embeddings_path, 'wb') as f:
            pickle.dump(self.medical_embeddings, f)
        with open(terms_path, 'wb') as f:
            pickle.dump(self.medical_terms, f)
    
    def find_similar_medical_term(self, word: str, threshold: float = 0.7) -> Optional[str]:
        """Find the most similar medical term using embeddings"""
        if self.medical_embeddings is None or len(self.medical_embeddings) == 0:
            return None
            
        try:
            word_embedding = self.sentence_transformer.encode([word])
            similarities = cosine_similarity(word_embedding, self.medical_embeddings)[0]
            
            max_similarity_idx = np.argmax(similarities)
            max_similarity = similarities[max_similarity_idx]
            
            if max_similarity > threshold:
                return self.medical_terms[max_similarity_idx]
            
        except Exception as e:
            logger.warning(f"Error finding similar term for '{word}': {e}")
        
        return None
    
    def correct_with_context(self, sentence: str, target_word: str) -> str:
        """Use masked language model for context-aware correction"""
        try:
            # Replace the first occurrence of the target word with the mask token
            masked_sentence = sentence.replace(target_word, self.masked_lm_tokenizer.mask_token, 1)
            
            # Tokenize and get predictions
            inputs = self.masked_lm_tokenizer(masked_sentence, return_tensors="pt")
            
            with torch.no_grad():
                outputs = self.masked_lm_model(**inputs)
                predictions = outputs.logits
            
            # Find mask token position
            mask_token_index = torch.where(inputs["input_ids"] == self.masked_lm_tokenizer.mask_token_id)[1]
            
            if len(mask_token_index) > 0:
                mask_token_logits = predictions[0, mask_token_index, :]
                top_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
                
                # Get top predictions
                candidates = [self.masked_lm_tokenizer.decode([token]) for token in top_tokens]
                
                # Filter for medical relevance
                for candidate in candidates:
                    candidate = candidate.strip()
                    if len(candidate) > 2 and candidate.isalpha():
                        # Check if candidate is medically relevant
                        similar_term = self.find_similar_medical_term(candidate, threshold=0.6)
                        if similar_term:
                            return similar_term
                        
                        # Check string similarity with original
                        if SequenceMatcher(None, target_word.lower(), candidate.lower()).ratio() > 0.6:
                            return candidate
            
        except Exception as e:
            logger.warning(f"Error in context correction for '{target_word}': {e}")
        
        return target_word
    
    def extract_medical_entities(self, text: str) -> List[Dict]:
        """Extract medical entities using NER"""
        try:
            entities = self.medical_ner(text)
            return entities
        except Exception as e:
            logger.warning(f"Error in medical NER: {e}")
            return []
    
    def correct_medical_text(self, text: str) -> str:
        """Main method to correct medical text using ML models"""
        corrected_text = text
        
        # Extract potential medical entities (informational only; the correction loop
        # below relies on spaCy tokens and embedding similarity rather than these entities)
        entities = self.extract_medical_entities(text)
        
        # Process with spaCy for linguistic analysis
        if self.nlp:
            doc = self.nlp(text)
            
            # Find words that might be medical terms but are misspelled
            for token in doc:
                if (token.is_alpha and len(token.text) > 3 and 
                    not token.is_stop and not token.like_url):
                    
                    # Check if it's a potential medical term
                    similar_term = self.find_similar_medical_term(token.text)
                    
                    if similar_term and similar_term != token.text:
                        # Use context-aware correction
                        context_correction = self.correct_with_context(text, token.text)
                        
                        # Choose the best correction
                        final_correction = context_correction if context_correction != token.text else similar_term
                        
                        # Apply correction with word boundaries
                        pattern = r'\b' + re.escape(token.text) + r'\b'
                        corrected_text = re.sub(pattern, final_correction, corrected_text, flags=re.IGNORECASE)
                        
                        logger.info(f"Corrected: '{token.text}' -> '{final_correction}'")
        
        return corrected_text


class DateTimeNormalizer:
    """Normalize dates and times in medical texts using regex patterns"""
    
    def __init__(self):
        self.date_patterns = [
            # French date patterns
            (r'(\d{1,2})\s+(\d{1,2})\s+(\d{4})', r'\1/\2/\3'),
            (r'(\d{1,2})\s+(\d{1,2})\s+(\d{2})\s+(\d{2})', r'\1/\2/\3\4'),
            (r'(\d{1,2})\s+(janvier|février|mars|avril|mai|juin|juillet|août|septembre|octobre|novembre|décembre)\s+(\d{4})', 
             self._convert_french_date),
        ]
        
        self.time_patterns = [
            (r'(\d{1,2})\s+heures?\s+(\d{1,2})', r'\1:\2'),
            (r'(\d{1,2})\s+h\s+(\d{1,2})', r'\1:\2'),
            (r'midi\s+(\d{1,2})', r'12:\1'),
            (r'minuit\s+(\d{1,2})', r'00:\1'),
        ]
    
    def _convert_french_date(self, match):
        """Convert French month names to numbers"""
        months = {
            'janvier': '01', 'février': '02', 'mars': '03', 'avril': '04',
            'mai': '05', 'juin': '06', 'juillet': '07', 'août': '08',
            'septembre': '09', 'octobre': '10', 'novembre': '11', 'décembre': '12'
        }
        day, month, year = match.groups()
        return f"{day}/{months.get(month.lower(), month)}/{year}"
    
    def normalize_dates_times(self, text: str) -> str:
        """Normalize all dates and times in the text"""
        result = text
        
        for pattern, replacement in self.date_patterns:
            if callable(replacement):
                result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)
            else:
                result = re.sub(pattern, replacement, result)
        
        for pattern, replacement in self.time_patterns:
            result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)
        
        return result


# Initialize global corrector instance
medical_corrector = None
datetime_normalizer = DateTimeNormalizer()


def initialize_medical_corrector(cache_dir: str = "./model_cache"):
    """Initialize the medical corrector (call once at startup)"""
    global medical_corrector
    if medical_corrector is None:
        medical_corrector = MedicalTermCorrector(cache_dir)


@tool
def load_transcription(transcription_path: str) -> str:
    """Load and return the raw transcription text from a file."""
    if not os.path.exists(transcription_path):
        raise FileNotFoundError(f"Transcription file not found: {transcription_path}")

    with open(transcription_path, 'r', encoding='utf-8') as f:
        return f.read().strip()


def load_transcription_with_user_id(transcription_path: str) -> Tuple[str, str]:
    """Load transcription text and user_id from a JSON file."""
    if not os.path.exists(transcription_path):
        raise FileNotFoundError(f"Transcription file not found: {transcription_path}")

    with open(transcription_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    transcription_text = data.get('transcription', '')
    user_id = data.get('user_id', 'unknown')

    return transcription_text, user_id


def preprocess_medical_text(text: str) -> str:
    """Preprocess medical text using ML models"""
    global medical_corrector
    
    if medical_corrector is None:
        initialize_medical_corrector()
    
    # Apply ML-based medical term correction
    corrected_text = medical_corrector.correct_medical_text(text)
    
    # Normalize dates and times
    corrected_text = datetime_normalizer.normalize_dates_times(corrected_text)
    
    # Fix common formatting issues
    corrected_text = re.sub(r'(\d+)\s*x\s*(\d+)\s*mm', r'\1 x \2 mm', corrected_text)
    corrected_text = re.sub(r'(\d+)\s*mm\s*sur\s*(\d+)\s*mm', r'\1 mm x \2 mm', corrected_text)
    
    return corrected_text


def create_transcription_corrector_chain(llm):
    """Create the enhanced transcription corrector chain with ML preprocessing"""
    transcription_corrector_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an expert medical transcriptionist with deep knowledge of French medical terminology. 
        The text you receive has already been preprocessed with ML models for medical term correction.

        Your task is to refine the document structure and ensure professional medical report formatting:

        FORMATTING AND STRUCTURE:
        - Organize content into clear medical report sections:
          * Title/Header
          * Clinical indication
          * Technique/Method
          * Results/Findings (use bullet points for lists)
          * Conclusion
        - Ensure proper spacing and line breaks
        - Replace "la ligne" and "à la ligne" with appropriate line breaks
        - Replace "point" with periods followed by line breaks when contextually appropriate
        - Format measurements consistently (e.g., "72 x 40 mm")
        - Ensure proper capitalization of medical terms and proper nouns

        QUALITY ASSURANCE:
        - Verify that medical terminology is accurate and consistent
        - Ensure dates are properly formatted (DD/MM/YYYY)
        - Check that medical classifications are correct (BI-RADS, ACR, etc.)
        - Maintain professional medical language throughout
        - Ensure logical flow and coherence

        PRESERVATION RULES:
        - Maintain all original medical content and findings
        - Do not add clinical interpretations not present in the original
        - Preserve all measurements, dates, and technical details
        - Keep the original meaning and medical context intact

        Return the refined text as a properly formatted medical report without explanations."""),
        ("human", """Refine and format the following preprocessed medical transcription:

{transcription}

Focus on professional medical report structure and formatting while preserving all original medical content.""")
    ])

    return transcription_corrector_prompt | llm


def create_medical_analyzer_chain(llm):
    """Create the medical analyzer chain with ML enhancement"""
    medical_analyzer_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a medical information extractor with expertise in French medical terminology.
        
        Extract and categorize ONLY the medical information that is explicitly mentioned in the transcription.
        The text has been preprocessed with ML models for better medical term accuracy.
        
        EXTRACTION CATEGORIES:
        1. Procedure/Examination type
        2. Clinical indication
        3. Technique/Method used
        4. Anatomical structures examined
        5. Measurements and dimensions
        6. Pathological findings
        7. Normal findings
        8. Medical classifications (BI-RADS, ACR, etc.)
        9. Recommendations/Follow-up
        10. Conclusion stated in the report
        
        EXTRACTION RULES:
        - Focus on clinical findings, measurements, and observations
        - Extract exact measurements with units
        - Identify medical procedures and techniques
        - Note anatomical structures and their conditions
        - Include any pathological or normal findings
        - Preserve medical classifications and scores
        - DO NOT add interpretations beyond what's stated
        - DO NOT make clinical assumptions
        
        Organize the extracted information clearly under each category."""),
        ("human", """Extract and organize the medical information from this transcription:

{corrected_transcription}

Provide a structured medical analysis with clear categorization.""")
    ])

    return medical_analyzer_prompt | llm


def create_title_generator_chain(llm):
    """Create the title generator chain with medical context"""
    title_generator_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a medical report title generator with expertise in French medical terminology.
        
        Generate a professional medical report title in FRENCH based on the medical data and findings.
        
        TITLE GUIDELINES:
        - Be specific to the examination type (IRM, mammographie, échographie, TDM, etc.)
        - Include the anatomical region examined
        - Use standard French medical terminology
        - Keep titles concise but informative (5-10 words)
        - Follow French medical report conventions
        - Consider the primary purpose (dépistage, surveillance, diagnostic, etc.)
        
        EXAMPLES:
        - "Mammographie de dépistage bilatérale"
        - "IRM pelvienne - Exploration utérine"
        - "Échographie abdominale - Surveillance hépatique"
        - "TDM thoracique avec injection de contraste"
        - "Radiographie pulmonaire - Contrôle post-opératoire"
        
        Return ONLY the title in French."""),
        ("human", """Generate a professional medical report title in French for:

{medical_data}

Create a concise, specific title that reflects the examination type and focus.""")
    ])

    return title_generator_prompt | llm


def process_medical_transcription(transcription_text: str, llm) -> Dict[str, Any]:
    """Complete ML-enhanced processing pipeline for medical transcription"""
    
    # Initialize ML corrector if not already done
    if medical_corrector is None:
        initialize_medical_corrector()
    
    # Step 1: ML-based preprocessing
    logger.info("Applying ML-based medical term correction...")
    preprocessed_text = preprocess_medical_text(transcription_text)
    
    # Step 2: LLM-based structure and formatting correction
    logger.info("Applying LLM-based formatting and structure correction...")
    corrector_chain = create_transcription_corrector_chain(llm)
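    # Note: with a chat model, invoke() typically returns a message object rather than a
    # plain string; append an output parser to the chain if raw text is required downstream.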
    corrected_text = corrector_chain.invoke({"transcription": preprocessed_text})
    
    # Step 3: Medical content analysis
    logger.info("Extracting medical information...")
    analyzer_chain = create_medical_analyzer_chain(llm)
    medical_analysis = analyzer_chain.invoke({"corrected_transcription": corrected_text})
    
    # Step 4: Generate appropriate title
    logger.info("Generating medical report title...")
    title_chain = create_title_generator_chain(llm)
    title = title_chain.invoke({"medical_data": medical_analysis})
    
    # Step 5: Extract entities for validation
    entities = medical_corrector.extract_medical_entities(corrected_text) if medical_corrector else []
    
    return {
        "original_transcription": transcription_text,
        "preprocessed_transcription": preprocessed_text,
        "corrected_transcription": corrected_text,
        "medical_analysis": medical_analysis,
        "title": title,
        "extracted_entities": entities,
        "processing_info": {
            "ml_corrections_applied": True,
            "entities_extracted": len(entities),
            "preprocessing_successful": True
        }
    }


def validate_medical_transcription(corrected_text: str, entities: List[Dict]) -> List[str]:
    """Validate the corrected medical transcription using ML insights"""
    issues = []
    
    # Check entity consistency
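    # (Expected labels depend on the NER model's tag set; adjust this list to match it.)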
    medical_entities = [e for e in entities if e.get('entity_group') in ['MEDICATION', 'DISEASE', 'TREATMENT']]
    if len(medical_entities) < 2:
        issues.append("Low medical entity density detected - review for missing medical terms")
    
    # Check for measurement formatting
    measurements = re.findall(r'\d+\s*[x×]\s*\d+\s*mm', corrected_text)
    if not measurements and ('mm' in corrected_text or 'cm' in corrected_text):
        issues.append("Measurement formatting may need attention")
    
    # Check for date consistency
    dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', corrected_text)
    if not dates and re.search(r'\d{4}', corrected_text):
        issues.append("Date formatting may need standardization")
    
    return issues
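

if __name__ == "__main__":
    # Minimal usage sketch (assumption: no LLM or pretrained models available here).
    # It exercises only the lightweight regex normalizer and the validation helper
    # on a hypothetical sample sentence.
    sample = "Échographie du 12 mars 2024 à 14 heures 30. Nodule de 12 x 8 mm."

    normalized = datetime_normalizer.normalize_dates_times(sample)
    print("Normalized:", normalized)

    issues = validate_medical_transcription(normalized, entities=[])
    print("Validation issues:", issues)

    # The full pipeline additionally needs the pretrained models and a LangChain LLM, e.g.:
    #   initialize_medical_corrector()
    #   result = process_medical_transcription(sample, llm=your_llm)  # `your_llm` is a placeholder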