|
|
|
|
|
""" |
|
|
Enhanced Medical Transcription Processor with ML-based correction |
|
|
Uses pretrained models and embeddings for dynamic medical term correction |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import re |
|
|
import numpy as np |
|
|
from typing import Dict, Any, Tuple, List, Optional |
|
|
from langchain.tools import tool |
|
|
from langchain.prompts import ChatPromptTemplate |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
import spacy |
|
|
from spacy.matcher import Matcher |
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoTokenizer, AutoModelForMaskedLM, |
|
|
AutoModelForTokenClassification, pipeline |
|
|
) |
|
|
from difflib import SequenceMatcher |
|
|
import pickle |
|
|
import logging |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time affects every
# application that imports this module; consider moving it to the entry point.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
class MedicalTermCorrector:
    """ML-based medical term corrector using pretrained models and embeddings.

    At construction time this loads (see ``_load_models``):

    * a sentence-transformer (``all-MiniLM-L6-v2``) that embeds words so they
      can be compared with a vocabulary of medical terms,
    * a spaCy pipeline (``fr_core_news_sm``, falling back to
      ``en_core_web_sm``) used to tokenize the text,
    * a Hugging Face NER pipeline (``samrawal/bert-base-uncased_clinical-ner``),
    * a masked language model (``camembert-base``) used for context-aware
      correction,

    and then the vocabulary plus its embeddings, cached as pickle files in
    ``cache_dir`` (see ``_load_medical_knowledge``).

    Args:
        cache_dir: Directory holding the cached vocabulary/embeddings; it is
            created if missing.
    """

    # Vocabulary embedded when no usable cache exists in ``cache_dir``.
    _DEFAULT_MEDICAL_TERMS = (
        # Imaging modalities
        "mammographie", "échographie", "IRM", "TDM", "radiographie",
        "scintigraphie", "angiographie", "arthrographie",
        # Anatomy
        "utérus", "ovaires", "myomètre", "endomètre", "cervix",
        "ganglions", "axillaires", "mammaire", "pelvien",
        "thyroïde", "pancréas", "foie", "rate", "reins",
        # Pathology
        "adénomyose", "endométriose", "fibrome", "kyste",
        "carcinome", "adénome", "métastase", "tumeur",
        "inflammation", "nécrose", "hémorragie", "œdème",
        # Classification systems
        "BI-RADS", "ACR", "TNM", "WHO", "BIRADS",
        # Procedures
        "biopsie", "dépistage", "surveillance", "contrôle",
        "ponction", "drainage", "résection", "ablation",
        # Units and quantities
        "millimètre", "centimètre", "millilitre", "gramme",
        "pourcentage", "degré", "unité", "concentration",
    )

    # A masked-LM prediction must be at least this similar (difflib ratio) to
    # the word it would replace, so a different term is never substituted.
    _CONTEXT_SPELLING_RATIO = 0.6

    def __init__(self, cache_dir: str = "./model_cache"):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        # Populated by _load_models().
        self.sentence_transformer = None
        self.nlp = None
        self.medical_ner = None
        self.masked_lm_model = None
        self.masked_lm_tokenizer = None
        # Populated by _load_medical_knowledge(): row i of medical_embeddings
        # is the embedding of medical_terms[i].
        self.medical_embeddings = None
        self.medical_terms = None

        self._load_models()
        self._load_medical_knowledge()

    def _load_models(self):
        """Load the pretrained models used for correction and NER.

        Any loading error (e.g. ``OSError`` when neither spaCy model is
        installed) is logged and re-raised.
        """
        try:
            logger.info("Loading sentence transformer model...")
            # NOTE(review): this model is believed to be English-trained while
            # the vocabulary is French — confirm similarity quality.
            self.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')

            logger.info("Loading spaCy model...")
            try:
                self.nlp = spacy.load("fr_core_news_sm")
            except OSError:
                logger.warning("French spaCy model not found. Install with: python -m spacy download fr_core_news_sm")
                self.nlp = spacy.load("en_core_web_sm")

            logger.info("Loading medical NER model...")
            self.medical_ner = pipeline(
                "ner",
                model="samrawal/bert-base-uncased_clinical-ner",
                tokenizer="samrawal/bert-base-uncased_clinical-ner",
                aggregation_strategy="simple"
            )

            logger.info("Loading masked language model...")
            self.masked_lm_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
            self.masked_lm_model = AutoModelForMaskedLM.from_pretrained("camembert-base")
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            raise

    def _knowledge_paths(self) -> Tuple[str, str]:
        """Return the ``(embeddings_path, terms_path)`` cache files in ``cache_dir``."""
        return (
            os.path.join(self.cache_dir, "medical_embeddings.pkl"),
            os.path.join(self.cache_dir, "medical_terms.pkl"),
        )

    def _load_medical_knowledge(self):
        """Load the vocabulary and embeddings from the cache, or build them.

        A cache that cannot be unpickled, or whose term list and embedding
        matrix differ in length, is ignored and rebuilt rather than crashing
        start-up.
        """
        embeddings_path, terms_path = self._knowledge_paths()

        if os.path.exists(embeddings_path) and os.path.exists(terms_path):
            logger.info("Loading cached medical knowledge...")
            try:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # cache_dir at a directory you trust.
                with open(embeddings_path, 'rb') as f:
                    embeddings = pickle.load(f)
                with open(terms_path, 'rb') as f:
                    terms = pickle.load(f)
                consistent = len(embeddings) == len(terms)
            except Exception as e:
                logger.warning(f"Could not read cached medical knowledge ({e}); rebuilding...")
            else:
                if consistent:
                    self.medical_embeddings = embeddings
                    self.medical_terms = terms
                    return
                logger.warning("Cached medical terms and embeddings differ in length; rebuilding...")

        logger.info("Building medical knowledge base...")
        self._build_medical_knowledge()

    def _build_medical_knowledge(self):
        """Embed the default vocabulary and write both parts to the cache."""
        logger.info("Generating embeddings for medical terms...")
        self.medical_terms = list(self._DEFAULT_MEDICAL_TERMS)
        self.medical_embeddings = self.sentence_transformer.encode(self.medical_terms)

        embeddings_path, terms_path = self._knowledge_paths()
        self._atomic_pickle_dump(self.medical_embeddings, embeddings_path)
        self._atomic_pickle_dump(self.medical_terms, terms_path)

    @staticmethod
    def _atomic_pickle_dump(obj, path: str):
        """Pickle ``obj`` to ``path`` via a temporary file so readers never see a partial file."""
        tmp_path = path + ".tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)

    def find_similar_medical_term(self, word: str, threshold: float = 0.7) -> Optional[str]:
        """Return the vocabulary term most similar to ``word``, if close enough.

        Args:
            word: Word to look up.
            threshold: Minimum cosine similarity (exclusive) between the
                embedding of ``word`` and the best vocabulary embedding.

        Returns:
            The best-matching term, or ``None`` when the vocabulary is empty or
            not loaded, no term exceeds ``threshold``, or encoding fails.
        """
        if self.medical_embeddings is None or len(self.medical_embeddings) == 0:
            return None

        try:
            word_embedding = self.sentence_transformer.encode([word])
            similarities = cosine_similarity(word_embedding, self.medical_embeddings)[0]
            best_idx = int(np.argmax(similarities))
            if similarities[best_idx] > threshold:
                return self.medical_terms[best_idx]
        except Exception as e:
            logger.warning(f"Error finding similar term for '{word}': {e}")

        return None

    @staticmethod
    def _mask_first_occurrence(text: str, word: str, mask_token: str,
                               context_chars: int) -> Optional[str]:
        """Mask the first whole-word ``word`` in ``text``, keeping ``context_chars`` on each side.

        Returns ``None`` when ``word`` does not occur as a whole word.
        """
        match = re.search(r'\b' + re.escape(word) + r'\b', text)
        if match is None:
            return None
        start = max(0, match.start() - context_chars)
        end = match.end() + context_chars
        return text[start:match.start()] + mask_token + text[match.end():end]

    def correct_with_context(self, sentence: str, target_word: str,
                             context_chars: int = 200) -> str:
        """Use the masked language model for a context-aware correction.

        Only the first whole-word occurrence of ``target_word`` is masked, and
        only ``context_chars`` characters on each side of it are sent to the
        model, so long transcriptions stay within its input limit. The top-5
        predictions are tried in order. A prediction is accepted only if it is
        longer than two characters, alphabetic and spelled similarly to
        ``target_word``. The accepted prediction is then mapped to its
        closest vocabulary term when one matches (threshold 0.6), or returned
        as is.

        Returns:
            The correction, or ``target_word`` unchanged when nothing is
            accepted or an error occurs (the error is logged).
        """
        try:
            masked_sentence = self._mask_first_occurrence(
                sentence, target_word, self.masked_lm_tokenizer.mask_token, context_chars)
            if masked_sentence is None:
                return target_word

            inputs = self.masked_lm_tokenizer(masked_sentence, return_tensors="pt", truncation=True)
            with torch.no_grad():
                logits = self.masked_lm_model(**inputs).logits

            mask_positions = torch.where(
                inputs["input_ids"] == self.masked_lm_tokenizer.mask_token_id)[1]
            if len(mask_positions) == 0:
                return target_word

            top_tokens = torch.topk(logits[0, mask_positions[0], :], 5).indices.tolist()
            target_lower = target_word.lower()
            for token_id in top_tokens:
                candidate = self.masked_lm_tokenizer.decode([token_id]).strip()
                if len(candidate) <= 2 or not candidate.isalpha():
                    continue
                # Reject predictions that are a different word altogether:
                # replacing e.g. one organ name with another is worse than
                # leaving the transcription untouched.
                ratio = SequenceMatcher(None, target_lower, candidate.lower()).ratio()
                if ratio <= self._CONTEXT_SPELLING_RATIO:
                    continue
                return self.find_similar_medical_term(candidate, threshold=0.6) or candidate
        except Exception as e:
            logger.warning(f"Error in context correction for '{target_word}': {e}")

        return target_word

    def extract_medical_entities(self, text: str) -> List[Dict]:
        """Run the NER pipeline over ``text``.

        Returns the pipeline's list of entity dicts. Any ``score`` value is
        converted to a built-in ``float`` so the result is JSON-serializable.
        Returns ``[]`` (after logging a warning) if the pipeline fails.
        """
        try:
            entities = self.medical_ner(text)
        except Exception as e:
            logger.warning(f"Error in medical NER: {e}")
            return []

        for entity in entities:
            if isinstance(entity, dict) and 'score' in entity:
                entity['score'] = float(entity['score'])
        return entities

    def correct_medical_text(self, text: str, threshold: float = 0.7) -> str:
        """Correct likely mis-transcribed medical terms in ``text``.

        Each distinct alphabetic, non-stop-word, non-URL token longer than
        three characters is compared with the vocabulary. When a different
        term scores above ``threshold``, a context-aware correction is tried
        first and the vocabulary term is used as the fallback. Every
        whole-word occurrence of the token is then replaced
        (case-insensitively).

        Args:
            text: Text to correct.
            threshold: Embedding-similarity threshold passed to
                ``find_similar_medical_term``.

        Returns:
            The corrected text; ``text`` unchanged when no spaCy pipeline is
            loaded.
        """
        corrected_text = text
        if not self.nlp:
            return corrected_text

        seen_words = set()
        for token in self.nlp(text):
            word = token.text
            if not token.is_alpha or len(word) <= 3 or token.is_stop or token.like_url:
                continue
            # A repeated token gives the same result and all its occurrences
            # are already replaced, so run the models once per distinct word.
            if word in seen_words:
                continue
            seen_words.add(word)

            similar_term = self.find_similar_medical_term(word, threshold=threshold)
            if not similar_term or similar_term == word:
                continue

            context_correction = self.correct_with_context(text, word)
            final_correction = context_correction if context_correction != word else similar_term

            # A callable replacement is inserted literally (no backslash or
            # group-reference interpretation).
            pattern = r'\b' + re.escape(word) + r'\b'
            corrected_text = re.sub(pattern, lambda _m: final_correction,
                                    corrected_text, flags=re.IGNORECASE)
            logger.info(f"Corrected: '{word}' -> '{final_correction}'")

        return corrected_text
|
|
|
|
|
|
|
|
class DateTimeNormalizer:
    """Normalize spoken-style dates and times in French medical texts using regex patterns.

    Dates become ``DD/MM/YYYY`` with day and month zero-padded:

    * ``"5 3 2024"`` -> ``"05/03/2024"``
    * ``"12 03 20 24"`` (year dictated as two pairs) -> ``"12/03/2024"``
    * ``"1er mars 2024"`` / ``"1 mars 2024"`` -> ``"01/03/2024"``

    Times keep the hour as dictated and zero-pad the minutes:

    * ``"10 heures 5"`` / ``"10 h 5"`` -> ``"10:05"``
    * ``"midi 30"`` -> ``"12:30"``, ``"minuit 5"`` -> ``"00:05"``

    A match whose day (1-31), month (1-12), hour (0-23) or minute (0-59) is
    out of range is left unchanged. Every pattern is anchored on word
    boundaries so it cannot start or end inside a longer number.
    """

    # French month name (lower case) -> month number.
    _FRENCH_MONTHS = {
        'janvier': 1, 'février': 2, 'mars': 3, 'avril': 4,
        'mai': 5, 'juin': 6, 'juillet': 7, 'août': 8,
        'septembre': 9, 'octobre': 10, 'novembre': 11, 'décembre': 12,
    }

    def __init__(self):
        month_names = '|'.join(self._FRENCH_MONTHS)

        # (pattern, replacement) pairs applied in order by normalize_dates_times.
        self.date_patterns = [
            (r'\b(\d{1,2})\s+(\d{1,2})\s+(\d{4})\b', self._convert_numeric_date),
            (r'\b(\d{1,2})\s+(\d{1,2})\s+(\d{2})\s+(\d{2})\b', self._convert_split_year_date),
            (r'\b(\d{1,2})(?:er)?\s+(' + month_names + r')\s+(\d{4})\b',
             self._convert_french_date),
        ]

        self.time_patterns = [
            (r'\b(\d{1,2})\s+heures?\s+(\d{1,2})\b', self._convert_time),
            (r'\b(\d{1,2})\s+h\s+(\d{1,2})\b', self._convert_time),
            (r'\bmidi\s+(\d{1,2})\b',
             lambda match: self._format_time('12', match.group(1), match.group(0))),
            (r'\bminuit\s+(\d{1,2})\b',
             lambda match: self._format_time('00', match.group(1), match.group(0))),
        ]

    @staticmethod
    def _format_date(day: str, month, year: str, original: str) -> str:
        """Return ``DD/MM/year``, or ``original`` when day/month are out of range."""
        day_num, month_num = int(day), int(month)
        if 1 <= day_num <= 31 and 1 <= month_num <= 12:
            return f"{day_num:02d}/{month_num:02d}/{year}"
        return original

    @staticmethod
    def _format_time(hour: str, minute: str, original: str) -> str:
        """Return ``hour:MM``, or ``original`` when hour/minute are out of range."""
        minute_num = int(minute)
        if int(hour) <= 23 and minute_num <= 59:
            return f"{hour}:{minute_num:02d}"
        return original

    def _convert_numeric_date(self, match):
        """Convert ``"D M YYYY"``."""
        day, month, year = match.groups()
        return self._format_date(day, month, year, match.group(0))

    def _convert_split_year_date(self, match):
        """Convert ``"D M YY YY"``, joining the two year halves."""
        day, month, century, year_end = match.groups()
        return self._format_date(day, month, century + year_end, match.group(0))

    def _convert_french_date(self, match):
        """Convert French month names to numbers (``"5 mars 2024"`` -> ``"05/03/2024"``)."""
        day, month, year = match.groups()
        month_num = self._FRENCH_MONTHS.get(month.lower())
        if month_num is None:
            return match.group(0)
        return self._format_date(day, month_num, year, match.group(0))

    def _convert_time(self, match):
        """Convert ``"H heures M"`` / ``"H h M"``."""
        hour, minute = match.groups()
        return self._format_time(hour, minute, match.group(0))

    def normalize_dates_times(self, text: str) -> str:
        """Normalize all dates and times in ``text``.

        Date patterns run before time patterns. Callable replacements are
        matched case-insensitively, so "Mars", "MIDI" or "1ER" are handled.
        """
        result = text

        for pattern, replacement in self.date_patterns:
            if callable(replacement):
                result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)
            else:
                result = re.sub(pattern, replacement, result)

        for pattern, replacement in self.time_patterns:
            result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

        return result
|
|
|
|
|
|
|
|
|
|
|
# Shared corrector, created on first use by initialize_medical_corrector().
medical_corrector: Optional["MedicalTermCorrector"] = None
# Shared date/time normalizer used by preprocess_medical_text().
datetime_normalizer: "DateTimeNormalizer" = DateTimeNormalizer()
|
|
|
|
|
|
|
|
def initialize_medical_corrector(cache_dir: str = "./model_cache"):
    """Create the shared ``MedicalTermCorrector`` (call once at startup).

    Later calls are no-ops. ``cache_dir`` is therefore only used by the call
    that actually creates the corrector.
    """
    global medical_corrector
    if medical_corrector is not None:
        return
    medical_corrector = MedicalTermCorrector(cache_dir)
|
|
|
|
|
|
|
|
@tool
def load_transcription(transcription_path: str) -> str:
    """Load and return the raw transcription text from a file."""
    # The docstring above is the tool description that langchain's @tool
    # exposes to the agent, so it is kept short. The file is read as UTF-8
    # and surrounding whitespace is stripped.
    if os.path.exists(transcription_path):
        with open(transcription_path, 'r', encoding='utf-8') as source:
            contents = source.read()
        return contents.strip()
    raise FileNotFoundError(f"Transcription file not found: {transcription_path}")
|
|
|
|
|
|
|
|
def load_transcription_with_user_id(transcription_path: str) -> Tuple[str, str]:
    """Load transcription text and user_id from a JSON file.

    The file must contain a JSON object. Missing or ``null`` fields fall back
    to defaults: ``""`` for ``transcription`` and ``"unknown"`` for
    ``user_id``. Other values are returned as stored in the JSON (e.g. a
    numeric ``user_id`` is not converted).

    Args:
        transcription_path: Path to a UTF-8 encoded JSON file.

    Returns:
        A ``(transcription_text, user_id)`` tuple.

    Raises:
        FileNotFoundError: If ``transcription_path`` does not exist.
        ValueError: If the file is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``) or its top-level value is not an object.
    """
    try:
        with open(transcription_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Transcription file not found: {transcription_path}") from None

    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a JSON object in {transcription_path}, got {type(data).__name__}"
        )

    transcription_text = data.get('transcription')
    user_id = data.get('user_id')
    return (
        '' if transcription_text is None else transcription_text,
        'unknown' if user_id is None else user_id,
    )
|
|
|
|
|
|
|
|
# A number with an optional decimal part, e.g. "72", "7,2" or "7.2".
_MEASUREMENT_NUMBER = r'\d+(?:[.,]\d+)?'
# "72 mm sur 40 mm" (same unit on both sides); groups: value, unit, value.
_MEASUREMENT_SUR_PATTERN = re.compile(
    r'(' + _MEASUREMENT_NUMBER + r')\s*(mm|cm)\s*sur\s*(' + _MEASUREMENT_NUMBER + r')\s*\2'
)
# "72x40mm" / "72 × 40 mm"; groups: value, value, unit.
_MEASUREMENT_X_PATTERN = re.compile(
    r'(' + _MEASUREMENT_NUMBER + r')\s*[x×]\s*(' + _MEASUREMENT_NUMBER + r')\s*(mm|cm)'
)


def _normalize_measurements(text: str) -> str:
    """Rewrite two-dimensional measurements in the canonical ``"A x B unit"`` form."""
    text = _MEASUREMENT_SUR_PATTERN.sub(r'\1 x \3 \2', text)
    return _MEASUREMENT_X_PATTERN.sub(r'\1 x \2 \3', text)


def preprocess_medical_text(text: str) -> str:
    """Preprocess medical text using ML models and deterministic rules.

    Steps, in order:

    1. ML-based medical term correction with the shared
       ``MedicalTermCorrector``, which is created on first use.
    2. Date/time normalization with the shared ``DateTimeNormalizer``.
    3. Measurement normalization. ``"72x40mm"``, ``"72 × 40 mm"`` and
       ``"72 mm sur 40 mm"`` all become ``"72 x 40 mm"`` (likewise for
       ``cm``). Decimals such as ``7,2`` are kept. This is the format the
       downstream formatting prompt asks for.
    """
    if medical_corrector is None:
        initialize_medical_corrector()

    corrected_text = medical_corrector.correct_medical_text(text)
    corrected_text = datetime_normalizer.normalize_dates_times(corrected_text)
    return _normalize_measurements(corrected_text)
|
|
|
|
|
|
|
|
def create_transcription_corrector_chain(llm):
    """Build the chain that restructures preprocessed text into a formatted report.

    The prompt expects a ``transcription`` variable holding the
    ML-preprocessed text.

    Args:
        llm: Model composed after the prompt with ``|``.

    Returns:
        The ``prompt | llm`` runnable.
    """
    system_instructions = (
        "You are an expert medical transcriptionist with deep knowledge of French medical terminology.\n"
        "The text you receive has already been preprocessed with ML models for medical term correction.\n"
        "\n"
        "Your task is to refine the document structure and ensure professional medical report formatting:\n"
        "\n"
        "FORMATTING AND STRUCTURE:\n"
        "- Organize content into clear medical report sections:\n"
        "* Title/Header\n"
        "* Clinical indication\n"
        "* Technique/Method\n"
        "* Results/Findings (use bullet points for lists)\n"
        "* Conclusion\n"
        "- Ensure proper spacing and line breaks\n"
        '- Replace "la ligne" and "à la ligne" with appropriate line breaks\n'
        '- Replace "point" with periods followed by line breaks when contextually appropriate\n'
        '- Format measurements consistently (e.g., "72 x 40 mm")\n'
        "- Ensure proper capitalization of medical terms and proper nouns\n"
        "\n"
        "QUALITY ASSURANCE:\n"
        "- Verify that medical terminology is accurate and consistent\n"
        "- Ensure dates are properly formatted (DD/MM/YYYY)\n"
        "- Check that medical classifications are correct (BI-RADS, ACR, etc.)\n"
        "- Maintain professional medical language throughout\n"
        "- Ensure logical flow and coherence\n"
        "\n"
        "PRESERVATION RULES:\n"
        "- Maintain all original medical content and findings\n"
        "- Do not add clinical interpretations not present in the original\n"
        "- Preserve all measurements, dates, and technical details\n"
        "- Keep the original meaning and medical context intact\n"
        "\n"
        "Return the refined text as a properly formatted medical report without explanations."
    )
    user_request = (
        "Refine and format the following preprocessed medical transcription:\n"
        "\n"
        "{transcription}\n"
        "\n"
        "Focus on professional medical report structure and formatting while preserving all original medical content."
    )
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_instructions), ("human", user_request)]
    )
    return prompt | llm
|
|
|
|
|
|
|
|
def create_medical_analyzer_chain(llm):
    """Build the chain that extracts categorized medical information.

    The prompt expects a ``corrected_transcription`` variable.

    Args:
        llm: Model composed after the prompt with ``|``.

    Returns:
        The ``prompt | llm`` runnable.
    """
    system_instructions = (
        "You are a medical information extractor with expertise in French medical terminology.\n"
        "\n"
        "Extract and categorize ONLY the medical information that is explicitly mentioned in the transcription.\n"
        "The text has been preprocessed with ML models for better medical term accuracy.\n"
        "\n"
        "EXTRACTION CATEGORIES:\n"
        "1. Procedure/Examination type\n"
        "2. Clinical indication\n"
        "3. Technique/Method used\n"
        "4. Anatomical structures examined\n"
        "5. Measurements and dimensions\n"
        "6. Pathological findings\n"
        "7. Normal findings\n"
        "8. Medical classifications (BI-RADS, ACR, etc.)\n"
        "9. Recommendations/Follow-up\n"
        "10. Conclusion stated in the report\n"
        "\n"
        "EXTRACTION RULES:\n"
        "- Focus on clinical findings, measurements, and observations\n"
        "- Extract exact measurements with units\n"
        "- Identify medical procedures and techniques\n"
        "- Note anatomical structures and their conditions\n"
        "- Include any pathological or normal findings\n"
        "- Preserve medical classifications and scores\n"
        "- DO NOT add interpretations beyond what's stated\n"
        "- DO NOT make clinical assumptions\n"
        "\n"
        "Organize the extracted information clearly under each category."
    )
    user_request = (
        "Extract and organize the medical information from this transcription:\n"
        "\n"
        "{corrected_transcription}\n"
        "\n"
        "Provide a structured medical analysis with clear categorization."
    )
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_instructions), ("human", user_request)]
    )
    return prompt | llm
|
|
|
|
|
|
|
|
def create_title_generator_chain(llm):
    """Build the chain that produces a French report title.

    The prompt expects a ``medical_data`` variable holding the extracted
    medical information.

    Args:
        llm: Model composed after the prompt with ``|``.

    Returns:
        The ``prompt | llm`` runnable.
    """
    system_instructions = (
        "You are a medical report title generator with expertise in French medical terminology.\n"
        "\n"
        "Generate a professional medical report title in FRENCH based on the medical data and findings.\n"
        "\n"
        "TITLE GUIDELINES:\n"
        "- Be specific to the examination type (IRM, mammographie, échographie, TDM, etc.)\n"
        "- Include the anatomical region examined\n"
        "- Use standard French medical terminology\n"
        "- Keep titles concise but informative (5-10 words)\n"
        "- Follow French medical report conventions\n"
        "- Consider the primary purpose (dépistage, surveillance, diagnostic, etc.)\n"
        "\n"
        "EXAMPLES:\n"
        '- "Mammographie de dépistage bilatérale"\n'
        '- "IRM pelvienne - Exploration utérine"\n'
        '- "Échographie abdominale - Surveillance hépatique"\n'
        '- "TDM thoracique avec injection de contraste"\n'
        '- "Radiographie pulmonaire - Contrôle post-opératoire"\n'
        "\n"
        "Return ONLY the title in French."
    )
    user_request = (
        "Generate a professional medical report title in French for:\n"
        "\n"
        "{medical_data}\n"
        "\n"
        "Create a concise, specific title that reflects the examination type and focus."
    )
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_instructions), ("human", user_request)]
    )
    return prompt | llm
|
|
|
|
|
|
|
|
def _chain_output_text(result: Any) -> str:
    """Return the text of a ``prompt | llm`` result.

    Chat models return a message object whose text is in ``.content``.
    Completion-style LLMs return a plain string. Anything else is
    stringified.
    """
    content = getattr(result, "content", result)
    return content if isinstance(content, str) else str(content)


def process_medical_transcription(transcription_text: str, llm) -> Dict[str, Any]:
    """Complete ML-enhanced processing pipeline for one medical transcription.

    Steps:

    1. ML preprocessing (``preprocess_medical_text``).
    2. LLM formatting.
    3. LLM information extraction.
    4. LLM title generation.
    5. NER over the formatted text.

    Each LLM step and the NER step receive the *text* of the previous chain's
    output. Chat models return message objects rather than strings, and
    passing those through would inject their repr into the next prompt.

    Args:
        transcription_text: Raw transcription.
        llm: Model used by all three LLM chains.

    Returns:
        A dict containing:

        * the original and preprocessed text;
        * the three chain outputs (``corrected_transcription``,
          ``medical_analysis``, ``title``), returned exactly as the chains
          produced them;
        * the extracted entities;
        * a ``processing_info`` summary.
    """
    if medical_corrector is None:
        initialize_medical_corrector()

    logger.info("Applying ML-based medical term correction...")
    preprocessed_text = preprocess_medical_text(transcription_text)

    logger.info("Applying LLM-based formatting and structure correction...")
    corrector_chain = create_transcription_corrector_chain(llm)
    corrected_result = corrector_chain.invoke({"transcription": preprocessed_text})
    corrected_text = _chain_output_text(corrected_result)

    logger.info("Extracting medical information...")
    analyzer_chain = create_medical_analyzer_chain(llm)
    medical_analysis = analyzer_chain.invoke({"corrected_transcription": corrected_text})

    logger.info("Generating medical report title...")
    title_chain = create_title_generator_chain(llm)
    title = title_chain.invoke({"medical_data": _chain_output_text(medical_analysis)})

    entities = medical_corrector.extract_medical_entities(corrected_text) if medical_corrector else []

    return {
        "original_transcription": transcription_text,
        "preprocessed_transcription": preprocessed_text,
        "corrected_transcription": corrected_result,
        "medical_analysis": medical_analysis,
        "title": title,
        "extracted_entities": entities,
        "processing_info": {
            "ml_corrections_applied": True,
            "entities_extracted": len(entities),
            "preprocessing_successful": True
        }
    }
|
|
|
|
|
|
|
|
# Entity groups counted towards the entity-density check (compared
# case-insensitively). NOTE(review): the clinical NER model loaded by
# MedicalTermCorrector appears to emit i2b2-style groups (problem/treatment/
# test) rather than MEDICATION/DISEASE — confirm against its model card.
_MEDICAL_ENTITY_GROUPS = frozenset({'MEDICATION', 'DISEASE', 'TREATMENT', 'PROBLEM', 'TEST'})

# A number with an optional French or English decimal part ("7,2" / "7.2").
_VALIDATION_NUMBER = r'\d+(?:[.,]\d+)?'


def validate_medical_transcription(corrected_text: str, entities: List[Dict]) -> List[str]:
    """Return heuristic quality warnings for a corrected medical transcription.

    Checks:

    * Fewer than two entities whose ``entity_group`` is a medical group
      (case-insensitive) -> low entity density warning.
    * A value followed by ``mm``/``cm`` is present but no ``A x B mm`` /
      ``A x B cm`` measurement -> measurement formatting warning.
    * A four-digit number is present but no ``D/M/YYYY`` date -> date
      formatting warning.

    Args:
        corrected_text: Text to check.
        entities: NER output dicts; only their ``entity_group`` key is read.

    Returns:
        The warning messages, in the order above (empty when none apply).
    """
    issues = []

    medical_entities = [
        entity for entity in entities
        if str(entity.get('entity_group', '')).upper() in _MEDICAL_ENTITY_GROUPS
    ]
    if len(medical_entities) < 2:
        issues.append("Low medical entity density detected - review for missing medical terms")

    # Look for units as whole tokens after a digit. A bare substring test
    # would fire on ordinary words such as "comme" or "mammographie".
    has_length_unit = re.search(r'\d\s*(?:mm|cm)\b', corrected_text)
    has_formatted_measurement = re.search(
        _VALIDATION_NUMBER + r'\s*[x×]\s*' + _VALIDATION_NUMBER + r'\s*(?:mm|cm)\b',
        corrected_text,
    )
    if has_length_unit and not has_formatted_measurement:
        issues.append("Measurement formatting may need attention")

    dates = re.findall(r'\d{1,2}/\d{1,2}/\d{4}', corrected_text)
    if not dates and re.search(r'\d{4}', corrected_text):
        issues.append("Date formatting may need standardization")

    return issues
|
|
|
|
|
|