Spaces:
Running
Running
| """ | |
| Beautiful Medical NER Demo using OpenMed Models | |
| A comprehensive Named Entity Recognition demo for medical professionals | |
| featuring multiple specialized medical models with beautiful entity visualization. | |
| """ | |
| import gradio as gr | |
| import spacy | |
| from spacy import displacy | |
| from transformers import pipeline | |
| import warnings | |
| import logging | |
| import re | |
| from typing import Dict, List, Tuple | |
| import random | |
# Keep the Space logs focused on this app's own progress messages: silence
# Python warnings globally and lower transformers' logger to ERROR.
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Model registry: UI display name -> Hugging Face model id and a one-line
# description shown in the model card. Only the entries enabled here are loaded
# at startup; uncomment another entry to load it as well.
MODELS = {
    "Oncology Detection": dict(
        model_id="OpenMed/OpenMed-NER-OncologyDetect-SuperMedical-355M",
        description="Specialized in cancer, genetics, and oncology entities",
    ),
    # "Pharmaceutical Detection": dict(
    #     model_id="OpenMed/OpenMed-NER-PharmaDetect-SuperClinical-434M",
    #     description="Detects drugs, chemicals, and pharmaceutical entities",
    # ),
    # "Disease Detection": dict(
    #     model_id="OpenMed/OpenMed-NER-DiseaseDetect-SuperClinical-434M",
    #     description="Identifies diseases, conditions, and pathologies",
    # ),
    # "Genome Detection": dict(
    #     model_id="OpenMed/OpenMed-NER-GenomeDetect-ModernClinical-395M",
    #     description="Recognizes genes, proteins, and genomic entities",
    # ),
}

# Sample clinical passages offered by the "Example N" buttons, keyed by the
# same display names as MODELS. Entries for currently disabled models are kept
# so they work as soon as the model is re-enabled.
EXAMPLES = {
    "Oncology Detection": [
        (
            "The patient presented with metastatic adenocarcinoma of the lung with mutations in EGFR and KRAS genes. "
            "Treatment with erlotinib was initiated, targeting the epidermal growth factor receptor pathway."
        ),
        (
            "Histological examination revealed invasive ductal carcinoma with high-grade nuclear features. "
            "The tumor showed positive estrogen receptor and HER2 amplification, indicating potential for targeted therapy."
        ),
        (
            "The oncologist recommended adjuvant chemotherapy with doxorubicin and cyclophosphamide, "
            "followed by paclitaxel, to target rapidly dividing cancer cells in the breast tissue."
        ),
    ],
    "Pharmaceutical Detection": [
        (
            "The patient was prescribed metformin 500mg twice daily for diabetes management, "
            "along with lisinopril 10mg for hypertension control and atorvastatin 20mg for cholesterol reduction."
        ),
        (
            "Administration of morphine sulfate provided effective pain relief, while ondansetron prevented chemotherapy-induced nausea. "
            "The patient also received dexamethasone as an anti-inflammatory agent."
        ),
        (
            "The pharmacokinetic study evaluated the absorption of ibuprofen and its interaction with warfarin, "
            "monitoring plasma concentrations and potential bleeding risks."
        ),
    ],
    "Disease Detection": [
        (
            "The patient was diagnosed with type 2 diabetes mellitus, hypertension, and coronary artery disease. "
            "Additional findings included diabetic nephropathy and peripheral neuropathy."
        ),
        (
            "Clinical presentation was consistent with acute myocardial infarction complicated by cardiogenic shock. "
            "The patient also had a history of chronic obstructive pulmonary disease and atrial fibrillation."
        ),
        (
            "Laboratory results confirmed the diagnosis of rheumatoid arthritis with elevated inflammatory markers. "
            "The patient also exhibited symptoms of Sjögren's syndrome and osteoporosis."
        ),
    ],
    "Genome Detection": [
        (
            "Genetic analysis revealed mutations in the BRCA1 and BRCA2 genes, significantly increasing the risk of hereditary breast and ovarian cancer. "
            "The p53 tumor suppressor gene also showed alterations."
        ),
        (
            "Expression profiling identified upregulation of MYC oncogene and downregulation of PTEN tumor suppressor. "
            "The mTOR signaling pathway showed significant activation in the tumor samples."
        ),
        (
            "Whole genome sequencing detected variants in CFTR gene associated with cystic fibrosis, "
            "along with polymorphisms in CYP2D6 affecting drug metabolism and APOE influencing Alzheimer's risk."
        ),
    ],
}
def ner_filtered(text, *, pipe, min_score=0.60, min_length=1, remove_punctuation=True):
    """Run ``pipe`` on ``text`` and keep only plausible entity predictions.

    A predicted entity survives when all of the following hold:

    * its ``"score"`` is at least ``min_score``;
    * its ``"word"``, ignoring surrounding whitespace, has at least
      ``min_length`` characters;
    * if ``remove_punctuation`` is true, its ``"word"`` contains at least one
      ASCII letter or digit, so punctuation-only predictions are dropped.

    Args:
        text: Input passed unchanged to ``pipe``.
        pipe: Callable (e.g. a transformers NER pipeline) returning an
            iterable of dicts that carry at least ``"score"`` and ``"word"``.
        min_score: Inclusive confidence cut-off.
        min_length: Minimum stripped length of ``"word"``.
        remove_punctuation: Whether to apply the letter/digit content check.

    Returns:
        The surviving entity dicts (not copied), in the order ``pipe`` produced them.
    """
    content_pattern = re.compile(r"[A-Za-z0-9]")

    def _passes(candidate):
        # Checks run cheapest-first, in the same order as the documented rules.
        if candidate["score"] < min_score:
            return False
        if len(candidate["word"].strip()) < min_length:
            return False
        if not remove_punctuation:
            return True
        return content_pattern.search(candidate["word"]) is not None

    return [candidate for candidate in pipe(text) if _passes(candidate)]
def advanced_ner_filter(text, *, pipe, min_score=0.60, strip_edges=True, exclude_patterns=None):
    """Filter NER predictions with edge stripping and pattern exclusion.

    For each entity returned by ``pipe(text)``:

    1. Drop it when ``"score"`` is below ``min_score``.
    2. If ``strip_edges`` is true, strip punctuation and whitespace from both
       ends of ``"word"`` (on a copy of the entity). Drop it if nothing is
       left. Shift ``"start"``/``"end"`` by the number of characters removed
       from each edge so the offsets keep describing the stripped word.
    3. Drop it when any of ``exclude_patterns`` matches the start of ``"word"``
       (``re.match`` semantics).
    4. Keep it only if ``"word"`` contains at least one ASCII letter or digit.

    Args:
        text: Input passed unchanged to ``pipe``.
        pipe: Callable returning an iterable of entity dicts with at least
            ``"score"`` and ``"word"``; integer ``"start"``/``"end"`` offsets
            are adjusted when present.
        min_score: Inclusive confidence cut-off.
        strip_edges: Whether to perform step 2.
        exclude_patterns: Optional iterable of regex strings or compiled patterns.

    Returns:
        List of surviving entity dicts. They are copies when ``strip_edges``
        is true, and the original objects otherwise.
    """
    # Whitespace is included so a word like ' "BRCA1", ' is still cleaned;
    # with punctuation-only edge characters, leading/trailing spaces would
    # shield the punctuation behind them.
    edge_chars = ".,!?;:()[]{}\"'-_ \t\r\n"
    content_pattern = re.compile(r"[A-Za-z0-9]")
    exclusions = [re.compile(p) for p in (exclude_patterns or ())]

    filtered = []
    for entity in pipe(text):
        if entity["score"] < min_score:
            continue
        word = entity["word"]

        if strip_edges:
            stripped = word.strip(edge_chars)
            if not stripped:
                continue
            entity = entity.copy()
            entity["word"] = stripped
            # Keep the character offsets in sync with the stripped word.
            # NOTE(review): this assumes "word" mirrors text[start:end], which
            # holds for spans cut from the source text; decoded sub-word
            # strings may differ — confirm for other pipelines.
            lead = len(word) - len(word.lstrip(edge_chars))
            trail = len(word) - len(word.rstrip(edge_chars))
            if isinstance(entity.get("start"), int):
                entity["start"] += lead
            if isinstance(entity.get("end"), int):
                entity["end"] -= trail

        if any(pattern.match(entity["word"]) for pattern in exclusions):
            continue

        # Only keep entities with actual content.
        if content_pattern.search(entity["word"]):
            filtered.append(entity)
    return filtered
def merge_adjacent_entities(entities, original_text, max_gap=10):
    """Merge same-type entities separated only by a short connector.

    Two neighbouring entities merge when they share ``entity_group``, the gap
    between them is at most ``max_gap`` characters, and the gap text consists
    only of whitespace, ``,``, ``/``, ``-`` and at most one word "and"
    (case-insensitive). This covers "BRCA1 and BRCA2" and "HER2-positive".
    Overlapping spans (negative gap) also merge.

    Args:
        entities: Entity dicts with ``entity_group``, ``start``, ``end``,
            ``word`` and ``score``. They are processed in ``start`` order and
            are not mutated.
        original_text: The text the offsets refer to; merged ``word`` values
            are sliced from it.
        max_gap: Largest character gap still considered adjacent.

    Returns:
        New list of entity dicts. A merged entity spans from the first start
        to the furthest end and scores the mean of its members' scores.
    """
    if len(entities) < 2:
        return entities

    # The old character class [\s\-,/and] accepted any mix of the letters
    # a/n/d (e.g. " a ", "dan"); require the actual word "and" instead.
    connector = re.compile(r"[\s,/\-]*(?:and[\s,/\-]*)?")

    ordered = sorted(entities, key=lambda e: e["start"])
    merged = []
    current = dict(ordered[0])
    member_scores = [current["score"]]

    for nxt in ordered[1:]:
        if (current["entity_group"] == nxt["entity_group"]
                and nxt["start"] - current["end"] <= max_gap):
            # An empty slice (adjacent or overlapping spans) also matches.
            gap_text = original_text[current["end"]:nxt["start"]]
            if connector.fullmatch(gap_text.lower()):
                # max() keeps an overlapping, shorter span from shrinking the entity.
                current["end"] = max(current["end"], nxt["end"])
                current["word"] = original_text[current["start"]:current["end"]]
                # True mean: a running pairwise average would over-weight later members.
                member_scores.append(nxt["score"])
                current["score"] = sum(member_scores) / len(member_scores)
                continue
        merged.append(current)
        current = dict(nxt)
        member_scores = [current["score"]]

    merged.append(current)
    return merged
class MedicalNERApp:
    """Gradio-facing wrapper around the OpenMed NER pipelines listed in ``MODELS``.

    Construction loads one transformers token-classification pipeline per
    ``MODELS`` entry. :meth:`predict_entities` then groups raw token
    predictions into entities, filters them by confidence, and returns
    displaCy HTML plus a Markdown summary.
    """

    # Fixed display colours for labels known to appear in the OpenMed models.
    # Any other label gets a colour derived from its name (see
    # create_spacy_visualization).
    _COLOR_PALETTE = {
        "DISEASE": "#FF5733",
        "CHEM": "#33FF57",
        "GENE/PROTEIN": "#3357FF",
        "Cancer": "#FF33F6",
        "Cell": "#33FFF6",
        "Organ": "#F6FF33",
        "Tissue": "#FF8333",
        "Simple_chemical": "#8333FF",
        "Gene_or_gene_product": "#33FF83",
        "Organism": "#FF6B33",
    }

    def __init__(self):
        # model name -> pipeline, or None when that model failed to load
        self.pipelines = {}
        # Blank English spaCy pipeline: only its tokenizer is used, to turn
        # character offsets into displaCy-renderable spans.
        self.nlp = spacy.blank("en")
        self.load_models()

    @staticmethod
    def _default_device():
        """Return 0 (first CUDA device) if torch reports CUDA, else -1 (CPU)."""
        try:
            import torch
        except ImportError:
            return -1
        return 0 if torch.cuda.is_available() else -1

    def load_models(self):
        """Load every model in ``MODELS`` into ``self.pipelines``.

        Pipelines use ``aggregation_strategy=None``, so the raw per-token BIO
        predictions reach :meth:`smart_group_entities`. A model that fails to
        load is stored as ``None`` and reported on stdout; the others still load.
        """
        print("🏥 Loading Medical NER Models...")
        # Pick the GPU only when one is actually available. Selecting device 0
        # based on __name__ made CPU-only hosts fail to load every model.
        device = self._default_device()
        for model_name, config in MODELS.items():
            print(f"Loading {model_name}...")
            try:
                ner_pipeline = pipeline(
                    "token-classification",
                    model=config["model_id"],
                    aggregation_strategy=None,  # raw tokens; grouped by smart_group_entities
                    device=device,
                )
                self.pipelines[model_name] = ner_pipeline
                print(f"✅ {model_name} loaded successfully with custom entity grouping")
            except Exception as e:
                print(f"❌ Error loading {model_name}: {str(e)}")
                self.pipelines[model_name] = None
        loaded = sum(p is not None for p in self.pipelines.values())
        print(f"🎉 {loaded}/{len(MODELS)} models loaded and cached!")

    def smart_group_entities(self, tokens, text):
        """Merge raw per-token BIO predictions into entity spans.

        ``tokens`` are the dicts produced by a token-classification pipeline
        with ``aggregation_strategy=None``: keys ``entity``, ``score``,
        ``start``, ``end`` and, when available, ``index``.

        A new entity starts when any of the following holds:

        * no entity is open. An ``I-`` tag, or a label without a BIO prefix,
          therefore opens an entity instead of being dropped;
        * the tag is ``B-``;
        * the entity type changes;
        * the token is not adjacent to the open entity.

        transformers drops ``O`` tokens by default (``ignore_labels=["O"]``),
        so adjacency is checked via consecutive ``index`` values. Without
        indices it falls back to whitespace-only text between the spans. This
        keeps two separate mentions of one type from fusing across the
        omitted tokens. Tokens without character offsets are skipped.

        Returns:
            List of dicts with ``entity_group``, ``score`` (mean of the token
            scores), ``word`` (exact source text), ``start`` and ``end``.
        """
        if not tokens:
            return []

        entities = []

        def _close(entity, scores):
            entity["score"] = sum(scores) / len(scores)
            entities.append(entity)

        current = None
        current_scores = []
        last_index = None

        for token in tokens:
            label = token["entity"]
            start, end = token["start"], token["end"]
            index = token.get("index")

            # Without offsets the token cannot be placed in the text
            # (text[None:None] would be the whole input).
            if start is None or end is None:
                continue

            if label == "O":
                if current is not None:
                    _close(current, current_scores)
                    current = None
                last_index = index
                continue

            # Strip only a leading BIO prefix; str.replace would also hit
            # "B-"/"I-" occurring inside the label name.
            if label[:2] in ("B-", "I-"):
                tag, entity_type = label[0], label[2:]
            else:
                tag, entity_type = "", label

            if current is None:
                adjacent = False
            elif index is not None and last_index is not None:
                adjacent = index == last_index + 1
            else:
                adjacent = not text[current["end"]:start].strip()

            starts_new = (
                current is None
                or tag == "B"
                or entity_type != current["entity_group"]
                or not adjacent
            )

            if starts_new:
                if current is not None:
                    _close(current, current_scores)
                current = {
                    "entity_group": entity_type,
                    "score": token["score"],
                    "word": text[start:end],  # use actual text from the source
                    "start": start,
                    "end": end,
                }
                current_scores = [token["score"]]
            else:
                current["end"] = max(current["end"], end)
                current["word"] = text[current["start"]:current["end"]]
                current_scores.append(token["score"])
            last_index = index

        if current is not None:
            _close(current, current_scores)
        return entities

    def create_spacy_visualization(self, text: str, entities: List[Dict], model_name: str) -> str:
        """Render ``entities`` over ``text`` as displaCy "ent" HTML.

        Each entity's ``start``/``end`` offsets are mapped onto spaCy tokens,
        first strictly and then with ``alignment_mode="expand"``. Entities that
        still cannot be aligned are skipped and reported in a notice appended
        to the HTML. Overlaps are resolved with ``spacy.util.filter_spans``.
        Diagnostic details are printed to stdout.
        """
        print(f"\n🔍 VISUALIZATION DEBUG for {model_name}")
        print(f"Input text length: {len(text)} chars")
        print(f"Total entities to visualize: {len(entities)}")
        print("\n📋 ENTITIES TO VISUALIZE:")
        entity_by_type = {}
        for i, ent in enumerate(entities):
            entity_type = ent['entity_group']
            entity_by_type.setdefault(entity_type, []).append(ent)
            print(f" {i+1:2d}. [{ent['start']:3d}:{ent['end']:3d}] '{ent['word']:25}' -> {entity_type:20} (score: {ent['score']:.3f})")
        print(f"\n📊 ENTITY COUNTS BY TYPE:")
        for entity_type, ents in entity_by_type.items():
            print(f" {entity_type}: {len(ents)} instances")

        doc = self.nlp(text)
        spacy_ents = []
        failed_entities = []
        print(f"\n🔧 CREATING SPACY SPANS:")
        for i, entity in enumerate(entities):
            try:
                start = entity['start']
                end = entity['end']
                label = entity['entity_group']
                entity_text = entity['word']
                print(f" {i+1:2d}. Trying span [{start}:{end}] '{entity_text}' -> {label}")
                # Strict alignment first; fall back to expanding to token boundaries.
                span = doc.char_span(start, end, label=label)
                if span is not None:
                    spacy_ents.append(span)
                    print(f" ✅ SUCCESS: '{span.text}' -> {label}")
                else:
                    span = doc.char_span(start, end, label=label, alignment_mode="expand")
                    if span is not None:
                        spacy_ents.append(span)
                        print(f" ✅ SUCCESS (expand): '{span.text}' -> {label}")
                    else:
                        failed_entities.append(entity)
                        print(f" ❌ FAILED: Could not create span for '{entity_text}' -> {label}")
            except Exception as e:
                failed_entities.append(entity)
                print(f" 💥 EXCEPTION: {str(e)}")

        print(f"\n📈 SPAN CREATION RESULTS:")
        print(f" ✅ Successful spans: {len(spacy_ents)}")
        print(f" ❌ Failed spans: {len(failed_entities)}")
        print(f"\n🔄 FILTERING OVERLAPPING SPANS...")
        print(f" Before filtering: {len(spacy_ents)} spans")
        spacy_ents = spacy.util.filter_spans(spacy_ents)
        print(f" After filtering: {len(spacy_ents)} spans")
        doc.ents = spacy_ents

        print(f"\n🎨 FINAL VISUALIZATION ENTITIES:")
        for ent in doc.ents:
            print(f" '{ent.text}' ({ent.label_}) [{ent.start_char}:{ent.end_char}]")

        unique_labels = sorted({ent.label_ for ent in doc.ents})
        colors = {}
        for label in unique_labels:
            if label in self._COLOR_PALETTE:
                colors[label] = self._COLOR_PALETTE[label]
            else:
                # Seed a private RNG with the label so an unknown label keeps
                # the same colour across requests and restarts. This also
                # leaves the global random state untouched.
                rng = random.Random(label)
                colors[label] = "#" + "".join(f"{rng.randint(100, 255):02x}" for _ in range(3))

        options = {
            "ents": unique_labels,
            "colors": colors,
            "style": "max-width: 100%; line-height: 2.5; direction: ltr;"
        }
        print(f"\n🎨 VISUALIZATION CONFIG:")
        print(f" Entity types for display: {unique_labels}")
        print(f" Color mapping: {colors}")

        # Tell the user when some predictions could not be drawn.
        debug_info = ""
        if failed_entities:
            debug_info = f"""
            <div style="margin-top: 15px; padding: 10px; background: #fff3cd; border: 1px solid #ffeaa7; border-radius: 5px; font-size: 12px;">
            <strong>⚠️ Visualization Info:</strong><br>
            {len(failed_entities)} entities could not be visualized due to text alignment issues.<br>
            All entities are still counted in the summary below.
            </div>
            """
        displacy_html = displacy.render(doc, style="ent", options=options, page=False)
        return displacy_html + debug_info

    def predict_entities(self, text: str, model_name: str, confidence_threshold: float = 0.60) -> Tuple[str, str]:
        """Run ``model_name`` on ``text`` and return ``(html, markdown_summary)``.

        Raw token predictions are grouped with :meth:`smart_group_entities`.
        Entities scoring below ``confidence_threshold``, or containing no ASCII
        letter or digit, are dropped. Empty input, an unavailable model, no
        detections, and any exception are all reported through the returned
        pair rather than raised.
        """
        if not text.strip():
            return "<p>Please enter medical text to analyze.</p>", "No text provided"
        if self.pipelines.get(model_name) is None:
            return f"<p>❌ Model {model_name} is not available.</p>", "Model not available"
        try:
            print(f"\nDEBUG: Processing text with {model_name}")
            # Log only the size: the input is clinical text and may contain
            # patient information that should not end up in server logs.
            print(f"Text length: {len(text)} chars")
            print(f"Confidence threshold: {confidence_threshold}")

            raw_tokens = self.pipelines[model_name](text)
            print(f"Got {len(raw_tokens)} raw tokens from pipeline")
            if not raw_tokens:
                return "<p>No entities detected.</p>", "No entities found"

            grouped_entities = self.smart_group_entities(raw_tokens, text)
            print(f"Smart grouping created {len(grouped_entities)} entities")

            filtered_entities = [
                entity
                for entity in grouped_entities
                if entity["score"] >= confidence_threshold
                and entity["word"].strip()
                and re.search(r"[A-Za-z0-9]", entity["word"])
            ]
            print(f"✅ After confidence filtering: {len(filtered_entities)} high-quality entities")
            if not filtered_entities:
                return f"<p>No entities found with confidence ≥ {confidence_threshold:.0%}. Try lowering the threshold.</p>", "No entities found"

            html_output = self.create_spacy_visualization(text, filtered_entities, model_name)
            wrapped_html = self.wrap_displacy_output(html_output, model_name, len(filtered_entities), confidence_threshold)
            summary = self.create_summary(filtered_entities, model_name, confidence_threshold)
            return wrapped_html, summary
        except Exception as e:
            import html
            import traceback
            print(f"ERROR in predict_entities: {str(e)}")
            traceback.print_exc()
            error_msg = f"Error during prediction: {str(e)}"
            # Exception text can contain '<' (e.g. "<class ...>"); escape it for the HTML pane.
            return f"<p>❌ {html.escape(error_msg)}</p>", error_msg

    def wrap_displacy_output(self, displacy_html: str, model_name: str, entity_count: int, confidence_threshold: float) -> str:
        """Wrap displaCy HTML in a styled card naming the model, count and threshold."""
        return f"""
        <div style="font-family: 'Segoe UI', Arial, sans-serif;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
        overflow: hidden;">
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white; padding: 15px; text-align: center;">
        <h3 style="margin: 0; font-size: 18px;">{model_name}</h3>
        <p style="margin: 5px 0 0 0; opacity: 0.9; font-size: 14px;">
        Found {entity_count} high-confidence medical entities (≥{confidence_threshold:.0%})
        </p>
        <div style="margin-top: 8px; font-size: 12px; opacity: 0.8;">
        ✅ Smart BIO token grouping + confidence threshold
        </div>
        </div>
        <div style="padding: 20px; margin: 0; line-height: 2.5;">
        {displacy_html}
        </div>
        </div>
        """

    def create_summary(self, entities: List[Dict], model_name: str, confidence_threshold: float) -> str:
        """Build the Markdown summary.

        The summary lists per-label counts, the mean confidence per label, up
        to three unique example words per label, the filtering steps applied,
        and the ``MODELS`` metadata for ``model_name``.
        """
        if not entities:
            return "No entities detected."
        entity_counts = {}
        for entity in entities:
            entity_counts.setdefault(entity["entity_group"], []).append(entity)

        summary_parts = [f"📊 **{model_name} Analysis Results**\n"]
        summary_parts.append(f"**Total high-confidence entities**: {len(entities)} (threshold ≥{confidence_threshold:.0%})\n")
        for label, ents in sorted(entity_counts.items()):
            avg_confidence = sum(e["score"] for e in ents) / len(ents)
            unique_texts = sorted({e["word"] for e in ents})
            summary_parts.append(
                f"• **{label}**: {len(ents)} instances "
                f"(avg confidence: {avg_confidence:.2f})\n"
                f" Examples: {', '.join(unique_texts[:3])}"
                f"{'...' if len(unique_texts) > 3 else ''}\n"
            )

        summary_parts.append("\n🎯 **Accuracy Improvements Applied**\n")
        summary_parts.append("✅ Smart BIO token grouping - Properly merges sub-tokens into complete entities\n")
        summary_parts.append(f"✅ Confidence threshold filtering - Only entities ≥ {confidence_threshold:.0%} confidence\n")
        summary_parts.append("✅ Content validation - Excludes empty or punctuation-only predictions\n")
        summary_parts.append("✅ Precise span alignment - Improved text-to-visual mapping\n")

        summary_parts.append(f"\n🔬 **Model Information**\n")
        summary_parts.append(f"Model: `{MODELS[model_name]['model_id']}`\n")
        summary_parts.append(f"Description: {MODELS[model_name]['description']}\n")
        return "\n".join(summary_parts)
# Build the single shared application object at import time, so every Gradio
# request reuses the pipelines loaded here.
print("🚀 Initializing Medical NER Application...")
ner_app = MedicalNERApp()

# Warm-up: run one short prediction through each successfully loaded model
# before the UI starts serving (presumably to absorb first-call overhead).
# Models that failed to load (stored as None) are skipped.
print("🔥 Warming up models...")
warmup_text = "The patient has diabetes and takes metformin."
for model_name in MODELS:
    if ner_app.pipelines[model_name] is None:
        continue
    try:
        print(f"Warming up {model_name}...")
        _ = ner_app.predict_entities(warmup_text, model_name, 0.60)
        print(f"✅ {model_name} warmed up successfully")
    except Exception as e:
        print(f"⚠️ Warmup failed for {model_name}: {str(e)}")
print("🎉 Model warmup complete!")
def predict_wrapper(text: str, model_name: str, confidence_threshold: float):
    """Gradio callback for the Analyze button.

    Delegates to the module-level ``ner_app`` and returns its
    ``(html, summary)`` pair, which feeds the results HTML and summary
    Markdown components.
    """
    rendered, overview = ner_app.predict_entities(text, model_name, confidence_threshold)
    return rendered, overview
def load_example(model_name: str, example_idx: int):
    """Return sample text ``example_idx`` for ``model_name`` from ``EXAMPLES``.

    Returns an empty string when the model has no examples or the index is
    out of range (negative indices are rejected rather than wrapped).
    """
    samples = EXAMPLES.get(model_name, [])
    if not 0 <= example_idx < len(samples):
        return ""
    return samples[example_idx]
# Create Gradio interface
# The first enabled model is the default selection, so enabling/disabling
# entries in MODELS can never leave the UI pointing at a missing key.
DEFAULT_MODEL = next(iter(MODELS))


def update_model_info(model_name):
    """Return the HTML info card for ``model_name``, or "" for unknown names.

    Used both for the card's initial value and on every dropdown change, so
    the card has the same contents (including the model id) before and after
    the user switches models.
    """
    if model_name in MODELS:
        return f"""
        <div class="model-info">
        <strong>{model_name}</strong><br>
        {MODELS[model_name]["description"]}<br>
        <small>Model: {MODELS[model_name]["model_id"]}</small>
        </div>
        """
    return ""


with gr.Blocks(
    title="🏥 Medical NER Expert",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .main-header {
        text-align: center;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 2rem;
        border-radius: 15px;
        margin-bottom: 2rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
    }
    .model-info {
        padding: 1rem;
        border-radius: 10px;
        border-left: 4px solid #667eea;
        margin: 1rem 0;
    }
    .accuracy-badge {
        background: #28a745;
        color: white;
        padding: 4px 8px;
        border-radius: 12px;
        font-size: 12px;
        font-weight: bold;
    }
    """,
) as demo:
    # Header. The app uses aggregation_strategy=None plus its own BIO
    # grouping, so the banner describes that rather than "simple" aggregation.
    gr.HTML(
        """
        <div class="main-header">
        <h1>🏥 Medical NER Expert</h1>
        <p>Advanced Named Entity Recognition for Medical Professionals</p>
        <div style="margin-top: 10px;">
        <span class="accuracy-badge">✅ HIGH ACCURACY MODE</span>
        </div>
        <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
        Powered by OpenMed models + smart BIO token grouping + confidence thresholds
        </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            # Model selection
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="🔬 Select Medical NER Model",
                info="Choose the specialized model for your analysis",
            )
            # Model info display, rendered by the same helper the change handler uses
            model_info = gr.HTML(value=update_model_info(DEFAULT_MODEL))
            # Confidence threshold slider
            confidence_slider = gr.Slider(
                minimum=0.30,
                maximum=0.95,
                value=0.60,
                step=0.05,
                label="🎯 Confidence Threshold",
                info="Higher values = fewer but more confident predictions"
            )
            # Text input, pre-filled with the default model's first example
            text_input = gr.Textbox(
                lines=8,
                placeholder="Enter medical text here for entity recognition...",
                label="📝 Medical Text Input",
                value=load_example(DEFAULT_MODEL, 0),
            )
            # Example buttons
            with gr.Row():
                example_buttons = []
                for i in range(3):
                    btn = gr.Button(f"Example {i+1}", size="sm", variant="secondary")
                    example_buttons.append(btn)
            # Analyze button
            analyze_btn = gr.Button("🔍 Analyze Text", variant="primary", size="lg")
        with gr.Column(scale=3):
            # Results
            results_html = gr.HTML(
                label="🎯 Entity Recognition Results",
                value="<p>Select a model and enter text to see entity recognition results.</p>",
            )
            # Summary
            summary_output = gr.Markdown(
                value="Analysis summary will appear here...",
                label="📊 Analysis Summary",
            )

    # Update model info when model changes
    model_dropdown.change(
        update_model_info, inputs=[model_dropdown], outputs=[model_info]
    )
    # Example button handlers; idx=i binds the loop value at definition time
    for i, btn in enumerate(example_buttons):
        btn.click(
            lambda model_name, idx=i: load_example(model_name, idx),
            inputs=[model_dropdown],
            outputs=[text_input],
        )
    # Main analysis function
    analyze_btn.click(
        predict_wrapper,
        inputs=[text_input, model_dropdown, confidence_slider],
        outputs=[results_html, summary_output],
    )
    # Auto-update when model changes (load first example)
    model_dropdown.change(
        lambda model_name: load_example(model_name, 0),
        inputs=[model_dropdown],
        outputs=[text_input],
    )
if __name__ == "__main__":
    # Serve on every network interface at port 7860, without a public share
    # link, and with errors surfaced in the UI.
    launch_options = {
        "share": False,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)