import os
import re
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForTokenClassification,
    TrainingArguments, Trainer, pipeline
)

import chromadb
from chromadb.config import Settings

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class MedicalEntity:
    """Container for the medical entities extracted by NER."""
    # Each field holds (entity_text, confidence) pairs.
    exam_types: List[Tuple[str, float]]
    specialties: List[Tuple[str, float]]
    anatomical_regions: List[Tuple[str, float]]
    pathologies: List[Tuple[str, float]]
    medical_procedures: List[Tuple[str, float]]
    measurements: List[Tuple[str, float]]
    medications: List[Tuple[str, float]]
    symptoms: List[Tuple[str, float]]


class AdvancedMedicalNER:
    """Advanced medical NER based on a fine-tuned CamemBERT-Bio model."""

    def __init__(self, model_name: str = "auto", cache_dir: str = "./models_cache"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)

        self.model_name = self._select_best_model(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # BIO label set; defined before _load_ner_model(), which needs it to
        # size the token-classification head.
        self.entity_labels = [
            "O",
            "B-EXAM", "I-EXAM",
            "B-SPECIALTY", "I-SPECIALTY",
            "B-ANATOMY", "I-ANATOMY",
            "B-PATHOLOGY", "I-PATHOLOGY",
            "B-PROCEDURE", "I-PROCEDURE",
            "B-MEASURE", "I-MEASURE",
            "B-MEDICATION", "I-MEDICATION",
            "B-SYMPTOM", "I-SYMPTOM"
        ]
        self.id2label = {i: label for i, label in enumerate(self.entity_labels)}
        self.label2id = {label: i for i, label in enumerate(self.entity_labels)}

        self._load_ner_model()
    def _select_best_model(self, model_name: str) -> str:
        """Automatically select the best available medical NER backbone."""
        if model_name != "auto":
            return model_name

        # Candidate checkpoints, ordered from most to least preferred.
        preferred_models = [
            "almanach/camembert-bio-base",
            "Dr-BERT/DrBERT-7GB",
            "emilyalsentzer/Bio_ClinicalBERT",
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
            "dmis-lab/biobert-base-cased-v1.2",
            "camembert-base"
        ]

        for model in preferred_models:
            try:
                # Probe availability by loading the tokenizer only.
                AutoTokenizer.from_pretrained(model, cache_dir=self.cache_dir)
                logger.info(f"Selected model: {model}")
                return model
            except Exception:
                continue

        logger.warning("Falling back to the base model camembert-base")
        return "camembert-base"
    def _load_ner_model(self):
        """Load an existing fine-tuned NER model, or create a fresh one."""
        fine_tuned_path = self.cache_dir / "medical_ner_model"

        if fine_tuned_path.exists():
            logger.info("Loading existing fine-tuned NER model")
            self.tokenizer = AutoTokenizer.from_pretrained(fine_tuned_path)
            self.ner_model = AutoModelForTokenClassification.from_pretrained(fine_tuned_path)
        else:
            logger.info("Creating a new medical NER model")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir)
            self.ner_model = AutoModelForTokenClassification.from_pretrained(
                self.model_name,
                num_labels=len(self.entity_labels),
                id2label=self.id2label,
                label2id=self.label2id,
                cache_dir=self.cache_dir
            )

        self.ner_model.to(self.device)

        self.ner_pipeline = pipeline(
            "token-classification",
            model=self.ner_model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            aggregation_strategy="simple"
        )
    def extract_entities(self, text: str) -> MedicalEntity:
        """Extract entities from a text with the fine-tuned NER model."""
        try:
            ner_results = self.ner_pipeline(text)
        except Exception as e:
            logger.error(f"NER error: {e}")
            return MedicalEntity([], [], [], [], [], [], [], [])

        entities = {
            "EXAM": [],
            "SPECIALTY": [],
            "ANATOMY": [],
            "PATHOLOGY": [],
            "PROCEDURE": [],
            "MEASURE": [],
            "MEDICATION": [],
            "SYMPTOM": []
        }

        for result in ner_results:
            # aggregation_strategy="simple" already strips BIO prefixes; the
            # replaces below are defensive in case a raw label slips through.
            entity_type = result['entity_group'].replace('B-', '').replace('I-', '')
            entity_text = result['word']
            confidence = result['score']

            # Keep only confident predictions.
            if entity_type in entities and confidence > 0.7:
                entities[entity_type].append((entity_text, confidence))

        return MedicalEntity(
            exam_types=entities["EXAM"],
            specialties=entities["SPECIALTY"],
            anatomical_regions=entities["ANATOMY"],
            pathologies=entities["PATHOLOGY"],
            medical_procedures=entities["PROCEDURE"],
            measurements=entities["MEASURE"],
            medications=entities["MEDICATION"],
            symptoms=entities["SYMPTOM"]
        )
    def fine_tune_on_templates(self, templates_data: List[Dict],
                               output_dir: Optional[str] = None,
                               epochs: int = 3):
        """Fine-tune the NER model on medical templates."""
        if output_dir is None:
            output_dir = self.cache_dir / "medical_ner_model"

        logger.info("Starting NER fine-tuning on medical templates")

        train_dataset = self._prepare_training_data(templates_data)
        has_eval = train_dataset.get('eval') is not None

        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=epochs,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=100,
            weight_decay=0.01,
            logging_dir=f"{output_dir}/logs",
            save_strategy="epoch",
            evaluation_strategy="epoch" if has_eval else "no",
            # load_best_model_at_end requires an evaluation set.
            load_best_model_at_end=has_eval,
            metric_for_best_model="eval_loss" if has_eval else None,
        )

        trainer = Trainer(
            model=self.ner_model,
            args=training_args,
            train_dataset=train_dataset['train'],
            eval_dataset=train_dataset.get('eval'),
            tokenizer=self.tokenizer,
        )

        trainer.train()

        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        logger.info(f"Fine-tuning finished, model saved to {output_dir}")
    def _prepare_training_data(self, templates_data: List[Dict]) -> Dict:
        """Prepare the NER training data (intelligent auto-annotation).

        Placeholder: returns an empty dataset until real auto-annotation
        (e.g. gazetteer- or rule-based labelling of the templates) is wired in.
        """
        class EmptyDataset(Dataset):
            def __len__(self):
                return 0

            def __getitem__(self, idx):
                return {}

        return {'train': EmptyDataset()}
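

# --- Illustrative sketch, not part of the original pipeline -----------------
# One way the auto-annotation stub above could be filled in: label template
# text with a small hand-built gazetteer (exact word matches), then align the
# word-level BIO tags to subword tokens. The class name `GazetteerNERDataset`
# and the gazetteer format {"échographie": "EXAM", ...} are assumptions for
# illustration only. Requires a fast (Rust-backed) tokenizer, since word_ids()
# is only available there.
class GazetteerNERDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, label2id: Dict[str, int],
                 gazetteer: Dict[str, str], max_length: int = 256):
        self.examples = []
        for text in texts:
            words = text.split()
            # Word-level BIO tags from exact, lowercased gazetteer matches.
            tags = ["O"] * len(words)
            for i, word in enumerate(words):
                entity_type = gazetteer.get(word.lower().strip(".,;:"))
                if entity_type:
                    tags[i] = f"B-{entity_type}"
            enc = tokenizer(words, is_split_into_words=True, truncation=True,
                            max_length=max_length, padding="max_length")
            # Align word tags to subwords: special tokens and continuation
            # subwords get the ignore index (-100), so only the first subword
            # of each word contributes to the loss.
            labels = []
            prev_word_id = None
            for word_id in enc.word_ids():
                if word_id is None or word_id == prev_word_id:
                    labels.append(-100)
                else:
                    labels.append(label2id[tags[word_id]])
                prev_word_id = word_id
            enc["labels"] = labels
            self.examples.append({k: torch.tensor(v) for k, v in enc.items()})

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]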


class AdvancedMedicalEmbedding:
    """Advanced medical embedding generator with cross-encoder reranking."""

    def __init__(self,
                 base_model: str = "almanach/camembert-bio-base",
                 cross_encoder_model: str = "auto"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.base_model_name = base_model

        self._load_base_model()
        self._load_cross_encoder(cross_encoder_model)

    def _load_base_model(self):
        """Load the base embedding model."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
            self.base_model = AutoModel.from_pretrained(self.base_model_name)
            self.base_model.to(self.device)
            logger.info(f"Base model loaded: {self.base_model_name}")
        except Exception as e:
            logger.error(f"Error loading base model: {e}")
            raise
    def _load_cross_encoder(self, model_name: str):
        """Load the cross-encoder used for reranking."""
        if model_name == "auto":
            # Candidates, ordered by preference; the base model is the fallback.
            cross_encoders = [
                "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                "emilyalsentzer/Bio_ClinicalBERT",
                self.base_model_name
            ]

            for model in cross_encoders:
                try:
                    self.cross_tokenizer = AutoTokenizer.from_pretrained(model)
                    self.cross_model = AutoModel.from_pretrained(model)
                    self.cross_model.to(self.device)
                    logger.info(f"Cross-encoder loaded: {model}")
                    break
                except Exception:
                    continue
        else:
            self.cross_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.cross_model = AutoModel.from_pretrained(model_name)
            self.cross_model.to(self.device)
    def generate_embedding(self, text: str, entities: Optional[MedicalEntity] = None) -> np.ndarray:
        """Generate an enriched embedding for a medical text."""
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.base_model(**inputs)

        # Attention-masked mean pooling over the token embeddings.
        attention_mask = inputs['attention_mask']
        token_embeddings = outputs.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        if entities:
            embedding = self._enrich_with_ner_entities(embedding, entities)

        return embedding.cpu().numpy().flatten().astype(np.float32)
    def _enrich_with_ner_entities(self, base_embedding: torch.Tensor, entities: MedicalEntity) -> torch.Tensor:
        """Enrich the embedding with the entities extracted by NER."""
        entity_texts = []
        confidence_weights = []

        for entity_list in [entities.exam_types, entities.specialties,
                            entities.anatomical_regions, entities.pathologies]:
            for entity_text, confidence in entity_list:
                entity_texts.append(entity_text)
                confidence_weights.append(confidence)

        if not entity_texts:
            return base_embedding

        entity_text_combined = " [SEP] ".join(entity_texts)
        entity_inputs = self.tokenizer(
            entity_text_combined,
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            entity_outputs = self.base_model(**entity_inputs)
            entity_embedding = torch.mean(entity_outputs.last_hidden_state, dim=1)

        # Weight the entity contribution by the mean NER confidence, capped at
        # 0.4 so the base text embedding always dominates.
        avg_confidence = np.mean(confidence_weights) if confidence_weights else 0.5
        fusion_weight = min(0.4, avg_confidence)

        enriched_embedding = (1 - fusion_weight) * base_embedding + fusion_weight * entity_embedding

        return enriched_embedding
    def cross_encoder_rerank(self,
                             query: str,
                             candidates: List[Dict],
                             top_k: int = 3) -> List[Dict]:
        """Rerank candidates with the cross-encoder to refine the selection."""
        if len(candidates) <= top_k:
            return candidates

        reranked_candidates = []

        for candidate in candidates:
            # Encode query and candidate jointly, as a cross-encoder does.
            pair_text = f"{query} [SEP] {candidate['document']}"

            inputs = self.cross_tokenizer(
                pair_text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.cross_model(**inputs)

            # Heuristic relevance score: the loaded encoder has no trained
            # relevance head, so squash the mean of the [CLS] embedding.
            cls_embedding = outputs.last_hidden_state[:, 0, :]
            similarity_score = torch.sigmoid(torch.mean(cls_embedding)).item()

            candidate_copy = candidate.copy()
            candidate_copy['cross_encoder_score'] = similarity_score
            # Blend vector similarity (60%) with the cross-encoder score (40%).
            candidate_copy['final_score'] = (
                0.6 * candidate['similarity_score'] +
                0.4 * similarity_score
            )

            reranked_candidates.append(candidate_copy)

        reranked_candidates.sort(key=lambda x: x['final_score'], reverse=True)

        return reranked_candidates[:top_k]
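

# --- Illustrative alternative, assuming the `sentence-transformers` package
# is installed. The heuristic [CLS]-based score above comes from an encoder
# with no trained relevance head; a purpose-trained cross-encoder scores each
# (query, document) pair directly. The checkpoint name is an example, not one
# mandated by this pipeline (and it is English-centric; a French or
# multilingual cross-encoder would fit this use case better).
def rerank_with_trained_cross_encoder(query: str, candidates: List[Dict],
                                      top_k: int = 3,
                                      model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> List[Dict]:
    from sentence_transformers import CrossEncoder

    model = CrossEncoder(model_name)
    # One joint forward pass per (query, document) pair.
    scores = model.predict([(query, c['document']) for c in candidates])
    reranked = []
    for candidate, score in zip(candidates, scores):
        candidate_copy = candidate.copy()
        candidate_copy['cross_encoder_score'] = float(score)
        reranked.append(candidate_copy)
    reranked.sort(key=lambda c: c['cross_encoder_score'], reverse=True)
    return reranked[:top_k]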


class MedicalTemplateVectorDB:
    """Vector database optimized for medical templates."""

    def __init__(self, db_path: str = "./medical_vector_db", collection_name: str = "medical_templates"):
        self.db_path = db_path
        self.collection_name = collection_name

        self.client = chromadb.PersistentClient(
            path=db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )

        try:
            self.collection = self.client.get_collection(collection_name)
            logger.info(f"Collection '{collection_name}' loaded")
        except Exception:
            self.collection = self.client.create_collection(
                name=collection_name,
                metadata={
                    "hnsw:space": "cosine",
                    "hnsw:M": 32,
                    "hnsw:ef_construction": 200,
                    "hnsw:ef_search": 50
                }
            )
            logger.info(f"Collection '{collection_name}' created with HNSW optimizations")
    def add_template(self,
                     template_id: str,
                     template_text: str,
                     embedding: np.ndarray,
                     entities: MedicalEntity,
                     metadata: Optional[Dict[str, Any]] = None):
        """Add a template with NER-enriched metadata."""
        # ChromaDB metadata values must be scalars (str/int/float/bool), so
        # the entity lists are serialized to comma-separated strings.
        auto_metadata = {
            "exam_types": ", ".join(entity[0] for entity in entities.exam_types),
            "specialties": ", ".join(entity[0] for entity in entities.specialties),
            "anatomical_regions": ", ".join(entity[0] for entity in entities.anatomical_regions),
            "pathologies": ", ".join(entity[0] for entity in entities.pathologies),
            "procedures": ", ".join(entity[0] for entity in entities.medical_procedures),
            "text_length": len(template_text),
            "entity_confidence_avg": float(np.mean([
                entity[1] for entity_list in [
                    entities.exam_types, entities.specialties,
                    entities.anatomical_regions, entities.pathologies
                ] for entity in entity_list
            ])) if any([entities.exam_types, entities.specialties,
                        entities.anatomical_regions, entities.pathologies]) else 0.0
        }

        if metadata:
            auto_metadata.update(metadata)

        self.collection.add(
            embeddings=[embedding.tolist()],
            documents=[template_text],
            metadatas=[auto_metadata],
            ids=[template_id]
        )

        logger.info(f"Template {template_id} added with automatic NER metadata")
    def advanced_search(self,
                        query_embedding: np.ndarray,
                        n_results: int = 10,
                        entity_filters: Optional[Dict[str, List[str]]] = None,
                        confidence_threshold: float = 0.0) -> List[Dict]:
        """Advanced search with filters based on NER entities."""
        # Note: $in matches exact metadata values, so entity filters behave as
        # exact-match filters on the serialized entity strings.
        conditions = []

        if entity_filters:
            for entity_type, entity_values in entity_filters.items():
                if entity_values:
                    conditions.append({entity_type: {"$in": entity_values}})

        if confidence_threshold > 0:
            conditions.append({"entity_confidence_avg": {"$gte": confidence_threshold}})

        # ChromaDB expects a single condition, or several combined with $and.
        if not conditions:
            where_clause = None
        elif len(conditions) == 1:
            where_clause = conditions[0]
        else:
            where_clause = {"$and": conditions}

        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=n_results,
            where=where_clause,
            include=["documents", "metadatas", "distances"]
        )

        formatted_results = []
        for i in range(len(results['ids'][0])):
            formatted_results.append({
                'id': results['ids'][0][i],
                'document': results['documents'][0][i],
                'metadata': results['metadatas'][0][i],
                # Cosine distance converted to a similarity score.
                'similarity_score': 1 - results['distances'][0][i],
                'distance': results['distances'][0][i]
            })

        return formatted_results
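

# Illustrative usage sketch: the filter value "angiologie" and the assumption
# of an already-populated collection are for demonstration only; the query
# embedding would come from AdvancedMedicalEmbedding.generate_embedding.
def demo_advanced_search(db: MedicalTemplateVectorDB, query_embedding: np.ndarray) -> None:
    hits = db.advanced_search(
        query_embedding=query_embedding,
        n_results=5,
        entity_filters={"specialties": ["angiologie"]},
        confidence_threshold=0.6,
    )
    for hit in hits:
        print(hit['id'], round(hit['similarity_score'], 3))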


class AdvancedMedicalTemplateProcessor:
    """Advanced processor with fine-tuned NER and cross-encoder reranking."""

    def __init__(self,
                 base_model: str = "almanach/camembert-bio-base",
                 db_path: str = "./advanced_medical_vector_db"):
        self.ner_extractor = AdvancedMedicalNER()
        self.embedding_generator = AdvancedMedicalEmbedding(base_model)
        self.vector_db = MedicalTemplateVectorDB(db_path)

        logger.info("Advanced medical processor initialized with fine-tuned NER and cross-encoder reranking")
    def process_templates_batch(self,
                                templates: List[Dict[str, str]],
                                batch_size: int = 8,
                                fine_tune_ner: bool = False) -> None:
        """Batch processing with optional NER fine-tuning."""
        if fine_tune_ner:
            logger.info("Fine-tuning the NER model on the templates...")
            self.ner_extractor.fine_tune_on_templates(templates)

        logger.info(f"Advanced processing of {len(templates)} templates")

        for i in tqdm(range(0, len(templates), batch_size), desc="Advanced processing"):
            batch = templates[i:i+batch_size]

            for template in batch:
                try:
                    template_id = template['id']
                    template_text = template['text']
                    metadata = template.get('metadata', {})

                    # 1. Extract entities, 2. build the enriched embedding,
                    # 3. index the template with its NER metadata.
                    entities = self.ner_extractor.extract_entities(template_text)
                    embedding = self.embedding_generator.generate_embedding(template_text, entities)
                    self.vector_db.add_template(
                        template_id=template_id,
                        template_text=template_text,
                        embedding=embedding,
                        entities=entities,
                        metadata=metadata
                    )
                except Exception as e:
                    logger.error(f"Error processing template {template.get('id', 'unknown')}: {e}")
                    continue
    def find_best_template_with_reranking(self,
                                          transcription: str,
                                          initial_candidates: int = 10,
                                          final_results: int = 3) -> List[Dict]:
        """Optimal search with cross-encoder reranking."""
        # Extract entities from the query and build its enriched embedding.
        query_entities = self.ner_extractor.extract_entities(transcription)
        query_embedding = self.embedding_generator.generate_embedding(transcription, query_entities)

        # Build metadata filters from the detected entities.
        entity_filters = {}
        if query_entities.exam_types:
            entity_filters['exam_types'] = [entity[0] for entity in query_entities.exam_types]
        if query_entities.specialties:
            entity_filters['specialties'] = [entity[0] for entity in query_entities.specialties]
        if query_entities.anatomical_regions:
            entity_filters['anatomical_regions'] = [entity[0] for entity in query_entities.anatomical_regions]

        # Stage 1: fast vector search for a pool of candidates.
        initial_candidates_results = self.vector_db.advanced_search(
            query_embedding=query_embedding,
            n_results=initial_candidates,
            entity_filters=entity_filters,
            confidence_threshold=0.6
        )

        # Stage 2: cross-encoder reranking of the candidate pool.
        if len(initial_candidates_results) > final_results:
            final_results_reranked = self.embedding_generator.cross_encoder_rerank(
                query=transcription,
                candidates=initial_candidates_results,
                top_k=final_results
            )
        else:
            final_results_reranked = initial_candidates_results

        # Attach the query entities to each result for downstream display.
        for result in final_results_reranked:
            result['query_entities'] = {
                'exam_types': query_entities.exam_types,
                'specialties': query_entities.specialties,
                'anatomical_regions': query_entities.anatomical_regions,
                'pathologies': query_entities.pathologies
            }

        return final_results_reranked


def main():
    """Example usage of the advanced system."""
    processor = AdvancedMedicalTemplateProcessor()

    # Sample templates (the French medical text is kept as-is: the models
    # targeted by this pipeline are French).
    sample_templates = [
        {
            'id': 'angio_001',
            'text': """Échographie et doppler artério-veineux des membres inférieurs.
            Exploration de l'incontinence veineuse superficielle...""",
            'metadata': {'source': 'angiologie', 'version': '2024'}
        }
    ]

    processor.process_templates_batch(sample_templates, fine_tune_ner=False)

    # Sample transcription to match against the indexed templates.
    transcription = """madame bacon nicole bilan œdème droit gonalgies ostéophytes
    incontinence veineuse modérée portions surale droite crurale gauche saphéniennes"""

    best_matches = processor.find_best_template_with_reranking(
        transcription=transcription,
        initial_candidates=15,
        final_results=3
    )

    for i, match in enumerate(best_matches):
        print(f"\n=== Match {i+1} ===")
        print(f"Template ID: {match['id']}")
        print(f"Final score: {match.get('final_score', match['similarity_score']):.4f}")
        print(f"Cross-encoder score: {match.get('cross_encoder_score', 'N/A')}")
        print("Entities detected in the query:")
        for entity_type, entities in match.get('query_entities', {}).items():
            if entities:
                print(f"  - {entity_type}: {[f'{e[0]} ({e[1]:.2f})' for e in entities]}")


if __name__ == "__main__":
    main()