import gradio as gr import torch from gliner import GLiNER import pandas as pd import warnings import random import re import time warnings.filterwarnings('ignore') # Common NER entity types (using full names) STANDARD_ENTITIES = [ 'DATE', 'EVENT', 'FACILITY', 'GEOPOLITICAL ENTITY', 'LANGUAGE', 'LOCATION', 'MISCELLANEOUS', 'NATIONALITIES/GROUPS', 'ORGANIZATION', 'PERSON', 'PRODUCT', 'WORK OF ART' ] # Colour schemes (updated to match full names) STANDARD_COLORS = { 'DATE': '#FF6B6B', # Red 'EVENT': '#4ECDC4', # Teal 'FACILITY': '#45B7D1', # Blue 'GEOPOLITICAL ENTITY': '#F9CA24', # Yellow 'LANGUAGE': '#6C5CE7', # Purple 'LOCATION': '#A0E7E5', # Light Cyan 'MISCELLANEOUS': '#FD79A8', # Pink 'NATIONALITIES/GROUPS': '#8E8E93', # Grey 'ORGANIZATION': '#55A3FF', # Light Blue 'PERSON': '#00B894', # Green 'PRODUCT': '#E17055', # Orange-Red 'WORK OF ART': '#DDA0DD' # Plum } # Entity definitions for glossary (alphabetically ordered with full name (abbreviation) format) ENTITY_DEFINITIONS = { 'DATE': 'Date (DATE): Absolute or relative dates or periods', 'EVENT': 'Event (EVENT): Named hurricanes, battles, wars, sports events, etc.', 'FACILITY': 'Facility (FAC): Buildings, airports, highways, bridges, etc.', 'GEOPOLITICAL ENTITY': 'Geopolitical Entity (GPE): Countries, cities, states', 'LANGUAGE': 'Language (LANG): Any named language', 'LOCATION': 'Location (LOC): Non-GPE locations - Mountain ranges, bodies of water', 'MISCELLANEOUS': 'Miscellaneous (MISC): Entities that don\'t fit elsewhere', 'NATIONALITIES/GROUPS': 'Nationalities/Groups (NORP): Nationalities or religious or political groups', 'ORGANIZATION': 'Organization (ORG): Companies, agencies, institutions, etc.', 'PERSON': 'Person (PER): People, including fictional characters', 'PRODUCT': 'Product (PRODUCT): Objects, vehicles, foods, etc. (Not services)', 'WORK OF ART': 'Work of Art (Work of Art): Titles of books, songs, movies, paintings, etc.' } # Additional colours for custom entities CUSTOM_COLOR_PALETTE = [ '#FF9F43', '#10AC84', '#EE5A24', '#0FBC89', '#5F27CD', '#FF3838', '#2F3640', '#3742FA', '#2ED573', '#FFA502', '#FF6348', '#1E90FF', '#FF1493', '#32CD32', '#FFD700', '#FF4500', '#DA70D6', '#00CED1', '#FF69B4', '#7B68EE' ] class HybridNERManager: def __init__(self): self.gliner_model = None self.spacy_model = None self.flair_models = {} self.all_entity_colors = {} self.model_names = [ 'flair_ner-large', 'spacy_en_core_web_trf', 'flair_ner-ontonotes-large', 'gliner_knowledgator/modern-gliner-bi-large-v1.0' ] # Mapping from full names to abbreviations for model compatibility self.entity_mapping = { 'DATE': 'DATE', 'EVENT': 'EVENT', 'FACILITY': 'FAC', 'GEOPOLITICAL ENTITY': 'GPE', 'LANGUAGE': 'LANG', 'LOCATION': 'LOC', 'MISCELLANEOUS': 'MISC', 'NATIONALITIES/GROUPS': 'NORP', 'ORGANIZATION': 'ORG', 'PERSON': 'PER', 'PRODUCT': 'PRODUCT', 'WORK OF ART': 'Work of Art' } # Reverse mapping for display self.abbrev_to_full = {v: k for k, v in self.entity_mapping.items()} def load_model(self, model_name): """Load the specified model""" try: if 'spacy' in model_name: return self.load_spacy_model() elif 'flair' in model_name: return self.load_flair_model(model_name) elif 'gliner' in model_name: return self.load_gliner_model() except Exception as e: print(f"Error loading {model_name}: {str(e)}") return None def load_spacy_model(self): """Load spaCy model for common NER""" if self.spacy_model is None: try: import spacy try: # Try transformer model first, fallback to small model self.spacy_model = spacy.load("en_core_web_trf") print("✓ spaCy transformer model loaded successfully") except OSError: try: self.spacy_model = spacy.load("en_core_web_sm") print("✓ spaCy common model loaded successfully") except OSError: print("spaCy model not found. Using GLiNER for all entity types.") return None except Exception as e: print(f"Error loading spaCy model: {str(e)}") return None return self.spacy_model def load_flair_model(self, model_name): """Load Flair models""" if model_name not in self.flair_models: try: from flair.models import SequenceTagger if 'ontonotes' in model_name: model = SequenceTagger.load("flair/ner-english-ontonotes-large") print("✓ Flair OntoNotes model loaded successfully") else: model = SequenceTagger.load("flair/ner-english-large") print("✓ Flair large model loaded successfully") self.flair_models[model_name] = model except Exception as e: print(f"Error loading {model_name}: {str(e)}") # Fallback to GLiNER return self.load_gliner_model() return self.flair_models[model_name] def load_gliner_model(self): """Load GLiNER model for custom entities""" if self.gliner_model is None: try: # Try the modern GLiNER model first, fallback to stable model self.gliner_model = GLiNER.from_pretrained("knowledgator/gliner-bi-large-v1.0") print("✓ GLiNER knowledgator model loaded successfully") except Exception as e: print(f"Primary GLiNER model failed: {str(e)}") try: # Fallback to stable model self.gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1") print("✓ GLiNER fallback model loaded successfully") except Exception as e2: print(f"Error loading GLiNER model: {str(e2)}") return None return self.gliner_model def assign_colours(self, standard_entities, custom_entities): """Assign colours to all entity types""" self.all_entity_colors = {} # Assign common colours for entity in standard_entities: self.all_entity_colors[entity.upper()] = STANDARD_COLORS.get(entity.upper(), '#CCCCCC') # Assign custom colours for i, entity in enumerate(custom_entities): if i < len(CUSTOM_COLOR_PALETTE): self.all_entity_colors[entity.upper()] = CUSTOM_COLOR_PALETTE[i] else: # Generate random colour if we run out self.all_entity_colors[entity.upper()] = f"#{random.randint(0, 0xFFFFFF):06x}" return self.all_entity_colors def extract_entities_by_model(self, text, entity_types, model_name, threshold=0.3): """Extract entities using the specified model""" # Convert full names to abbreviations for model processing abbrev_types = [] for entity in entity_types: if entity in self.entity_mapping: abbrev_types.append(self.entity_mapping[entity]) else: abbrev_types.append(entity) if 'spacy' in model_name: return self.extract_spacy_entities(text, abbrev_types) elif 'flair' in model_name: return self.extract_flair_entities(text, abbrev_types, model_name) elif 'gliner' in model_name: return self.extract_gliner_entities(text, abbrev_types, threshold, is_custom=False) else: return [] def extract_spacy_entities(self, text, entity_types): """Extract entities using spaCy""" model = self.load_spacy_model() if model is None: return [] try: doc = model(text) entities = [] for ent in doc.ents: if ent.label_ in entity_types: # Convert abbreviation back to full name for display display_label = self.abbrev_to_full.get(ent.label_, ent.label_) entities.append({ 'text': ent.text, 'label': display_label, 'start': ent.start_char, 'end': ent.end_char, 'confidence': 1.0, # spaCy doesn't provide confidence scores 'source': 'spaCy' }) return entities except Exception as e: print(f"Error with spaCy extraction: {str(e)}") return [] def extract_flair_entities(self, text, entity_types, model_name): """Extract entities using Flair""" model = self.load_flair_model(model_name) if model is None: return [] try: from flair.data import Sentence sentence = Sentence(text) model.predict(sentence) entities = [] for entity in sentence.get_spans('ner'): # Map Flair labels to our common set label = entity.tag if label == 'PERSON': label = 'PER' elif label == 'ORGANIZATION': label = 'ORG' elif label == 'LOCATION': label = 'LOC' elif label == 'MISCELLANEOUS': label = 'MISC' if label in entity_types: # Convert abbreviation back to full name for display display_label = self.abbrev_to_full.get(label, label) entities.append({ 'text': entity.text, 'label': display_label, 'start': entity.start_position, 'end': entity.end_position, 'confidence': entity.score, 'source': f'Flair-{model_name.split("-")[-1]}' }) return entities except Exception as e: print(f"Error with Flair extraction: {str(e)}") return [] def extract_gliner_entities(self, text, entity_types, threshold=0.3, is_custom=True): """Extract entities using GLiNER""" model = self.load_gliner_model() if model is None: return [] try: entities = model.predict_entities(text, entity_types, threshold=threshold) result = [] for entity in entities: # Convert abbreviation back to full name for display if not custom if not is_custom: display_label = self.abbrev_to_full.get(entity['label'].upper(), entity['label'].upper()) else: display_label = entity['label'].upper() result.append({ 'text': entity['text'], 'label': display_label, 'start': entity['start'], 'end': entity['end'], 'confidence': entity.get('score', 0.0), 'source': 'GLiNER-Custom' if is_custom else 'GLiNER-Common' }) return result except Exception as e: print(f"Error with GLiNER extraction: {str(e)}") return [] def find_overlapping_entities(entities): """Find and share overlapping entities - specifically entities found by BOTH common NER models AND custom entities""" if not entities: return [] # Sort entities by start position sorted_entities = sorted(entities, key=lambda x: x['start']) shared_entities = [] i = 0 while i < len(sorted_entities): current_entity = sorted_entities[i] overlapping_entities = [current_entity] # Find all entities that overlap with current entity j = i + 1 while j < len(sorted_entities): next_entity = sorted_entities[j] # Check if entities overlap (same text span or overlapping positions) if (current_entity['start'] <= next_entity['start'] < current_entity['end'] or next_entity['start'] <= current_entity['start'] < current_entity['end'] or current_entity['text'].lower() == next_entity['text'].lower()): overlapping_entities.append(next_entity) sorted_entities.pop(j) else: j += 1 # Create shared entity only if we have BOTH common and custom entities if len(overlapping_entities) == 1: shared_entities.append(overlapping_entities[0]) else: # Check if this is a true "shared" entity (common + custom) has_common = False has_custom = False for entity in overlapping_entities: source = entity.get('source', '') if source in ['spaCy', 'GLiNER-Common'] or source.startswith('Flair-'): has_common = True elif source == 'GLiNER-Custom': has_custom = True if has_common and has_custom: # This is a true shared entity (common + custom) shared_entity = share_entities(overlapping_entities) shared_entities.append(shared_entity) else: # These are just overlapping entities from the same source type, keep separate shared_entities.extend(overlapping_entities) i += 1 return shared_entities def share_entities(entity_list): """Share multiple overlapping entities into one""" if len(entity_list) == 1: return entity_list[0] # Use the entity with the longest text span as the base base_entity = max(entity_list, key=lambda x: len(x['text'])) # Collect all labels and sources labels = [entity['label'] for entity in entity_list] sources = [entity['source'] for entity in entity_list] confidences = [entity['confidence'] for entity in entity_list] return { 'text': base_entity['text'], 'start': base_entity['start'], 'end': base_entity['end'], 'labels': labels, 'sources': sources, 'confidences': confidences, 'is_shared': True, 'entity_count': len(entity_list) } def create_highlighted_html(text, entities, entity_colors): """Create HTML with highlighted entities""" if not entities: return f"
{text}
| Entity Text | All Labels | Sources | Count |
|---|---|---|---|
| {entity['text']} | {labels_text} | {sources_text} | {entity['entity_count']} |
| Entity Text | Confidence | Type | Source |
|---|---|---|---|
| {entity['text']} | {confidence:.3f} | {entity['label']} | {source_badge} |
No entities found.
" # Share overlapping entities shared_entities = find_overlapping_entities(entities) # Group entities by type entity_groups = {} for entity in shared_entities: if entity.get('is_shared', False): key = 'SHARED_ENTITIES' else: key = entity['label'] if key not in entity_groups: entity_groups[key] = [] entity_groups[key].append(entity) if not entity_groups: return "No entities found.
" # Create container with all tables all_tables_html = """The confidence threshold controls how certain the model needs to be before identifying an entity:
Start with 0.3 for comprehensive detection, then adjust based on your needs.