Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import pipeline | |
| import html | |
| from collections import defaultdict | |
# Page setup: browser-tab title, icon, and the "wide" layout.
_PAGE_SETTINGS = {
    "page_title": "OpenMed NER Demo",
    "page_icon": "🏥",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
# Model mapping: domain label shown in the sidebar -> model identifier that is
# handed to `pipeline(model=...)`. Insertion order is the order of the selectbox.
MODELS = dict(
    [
        ("Pharmacology", "OpenMed/OpenMed-NER-PharmaDetect-SuperClinical-434M"),
        ("Oncology Genetics", "OpenMed/OpenMed-NER-OncologyDetect-SuperMedical-355M"),
        ("Species Detection", "OpenMed/OpenMed-NER-SpeciesDetect-PubMed-335M"),
        ("Chemical Detection", "OpenMed/OpenMed-NER-ChemicalDetect-PubMed-335M"),
        ("Anatomy Detection", "OpenMed/OpenMed-NER-AnatomyDetect-PubMed-335M"),
        ("Blood Cancer Detection", "OpenMed/OpenMed-NER-BloodCancerDetect-TinyMed-82M"),
        ("Disease Detection", "OpenMed/OpenMed-NER-DiseaseDetect-SuperClinical-434M"),
    ]
)
# Entity-type -> highlight colour. "default" is the fallback used for any
# entity type that has no entry of its own.
ENTITY_COLORS = dict(
    DRUG="#FF9999",      # drug - light red
    CHEMICAL="#FFCC99",  # chemical - light orange
    DISEASE="#FF99CC",   # disease - light pink
    ANATOMY="#99CCFF",   # anatomy - light blue
    SPECIES="#99FF99",   # species - light green
    GENE="#CC99FF",      # gene - light purple
    PROTEIN="#FFFF99",   # protein - light yellow
    CELL="#99FFFF",      # cell - light cyan
    default="#DDDDDD",   # fallback - light grey
)
# Seed session state: each key is created only if it does not exist yet, so
# values written by earlier reruns of the script are left alone.
_SESSION_DEFAULTS = (
    ("text_input", ""),
    ("entities", []),
    ("model_loaded", None),
)
for _state_key, _initial_value in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Cached model loading.
@st.cache_resource
def _build_ner_pipeline(model_name):
    """Build the token-classification pipeline for *model_name*.

    Decorated with ``st.cache_resource`` so the pipeline is constructed once
    per model name and reused by later reruns instead of being rebuilt on
    every widget interaction. Errors are deliberately not handled here: the
    exception propagates to the caller, so a failed load is not stored as a
    cached result.
    """
    return pipeline(
        "token-classification",
        model=model_name,
        aggregation_strategy="simple",
    )


def load_model(model_name):
    """Return the NER pipeline for *model_name*, or ``None`` if loading fails.

    Args:
        model_name: Model identifier passed to ``pipeline(model=...)``.

    Returns:
        The pipeline object, or ``None`` after the failure has been reported
        in the UI with ``st.error``.
    """
    try:
        return _build_ner_pipeline(model_name)
    except Exception as e:
        # UI boundary: report the problem on the page rather than crashing
        # the script; callers must cope with a ``None`` pipeline.
        st.error(f"Error loading model: {str(e)}")
        return None
# Highlight the entities inside the text.
def highlight_entities(text, entities, colors=None):
    """Return *text* as an HTML fragment with each entity wrapped in ``<mark>``.

    Entity offsets index into the raw *text*, so every piece is sliced from
    the raw text first and HTML-escaped afterwards. Escaping the whole text
    up front would lengthen it wherever ``&``, ``<``, ``>`` or a quote occurs
    and shift every later highlight.

    Args:
        text: The raw text that was analysed.
        entities: Iterable of dicts with the keys ``start``, ``end``,
            ``entity_group`` and ``score``.
        colors: Optional mapping of entity type to CSS colour; it must
            contain a ``"default"`` entry. When omitted, the module-level
            ``ENTITY_COLORS`` is used.

    Returns:
        An HTML-safe string. Text outside entities is escaped; each entity is
        a ``<mark>`` element whose ``title`` shows its type and confidence.
    """
    if not entities:
        # The result is rendered as HTML, so plain text must be escaped too.
        return html.escape(text)

    palette = ENTITY_COLORS if colors is None else colors
    pieces = []
    cursor = 0  # offset in the raw text up to which output has been produced

    for entity in sorted(entities, key=lambda item: item["start"]):
        # Clamp to what is still unrendered so overlapping spans never
        # duplicate text; spans that are fully covered are dropped.
        start = max(entity["start"], cursor)
        end = min(entity["end"], len(text))
        if end <= start:
            continue

        # Plain text between the previous entity and this one.
        pieces.append(html.escape(text[cursor:start]))

        entity_type = str(entity["entity_group"])
        color = palette.get(entity_type, palette["default"])
        pieces.append(
            f'<mark style="background-color: {color}; padding: 2px 4px; border-radius: 3px;" '
            f'title="{html.escape(entity_type)} (confidence: {entity["score"]:.3f})">'
            f'{html.escape(text[start:end])}'
            f'</mark>'
        )
        cursor = end

    # Whatever follows the last entity.
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
# Page heading and tagline.
st.title("🏥 OpenMed Named Entity Recognition Demo")
st.markdown("Using domain-specific pre-trained models for medical text analysis")

# Sidebar: choose the domain, which in turn selects the model.
st.sidebar.header("Model Selection")
selected_domain = st.sidebar.selectbox(label="Select Domain", options=list(MODELS))
model_name = MODELS[selected_domain]

# Switching to a different model discards the stored entity results and
# records the new model as the current one.
model_changed = st.session_state.model_loaded != model_name
if model_changed:
    st.session_state.model_loaded = model_name
    st.session_state.entities = []

ner_pipeline = load_model(model_name)

# Sidebar: show which domain/model is active (model shown without its
# "owner/" prefix).
short_model_name = model_name.split("/")[-1]
st.sidebar.header("Model Information")
st.sidebar.write(f"**Domain**: {selected_domain}")
st.sidebar.write(f"**Model**: {short_model_name}")
# Example texts (English), one per domain label; looked up by the selected
# domain when the example button is pressed.
example_texts = {}
example_texts["Pharmacology"] = (
    "The patient was prescribed aspirin and warfarin for anticoagulation therapy."
)
example_texts["Oncology Genetics"] = (
    "BRCA1 gene mutations are associated with increased risk of breast and ovarian cancer."
)
example_texts["Species Detection"] = (
    "Researchers tested the new drug in a mouse model and observed significant effects."
)
example_texts["Chemical Detection"] = (
    "Glucose and oxygen molecules play key roles in cellular respiration processes."
)
example_texts["Anatomy Detection"] = (
    "The patient reported pain in the right knee joint radiating to the thigh."
)
example_texts["Blood Cancer Detection"] = (
    "The patient was diagnosed with chronic lymphocytic leukemia and requires regular monitoring of lymphocyte counts."
)
example_texts["Disease Detection"] = (
    "Patients with diabetes mellitus often have increased risk of hypertension and cardiovascular disease."
)
# Main area: text input on the left, NER results on the right.
col1, col2 = st.columns([1, 1])

with col1:
    st.header("Text Input")

    # Example-text button. The example is written to the text area's own
    # session-state key *before* the widget is instantiated; driving a keyed
    # widget through a changing `value=` argument is fragile.
    if st.button("Load Example Text"):
        st.session_state["text_input_widget"] = example_texts[selected_domain]
        st.session_state.entities = []  # drop results of the previous text
        st.session_state["analysis"] = None

    # First run only: seed the widget from the stored text.
    if "text_input_widget" not in st.session_state:
        st.session_state["text_input_widget"] = st.session_state.text_input

    # Text input area (its value lives under the widget key).
    text = st.text_area(
        "Enter text to analyze:",
        height=200,
        help="Enter medical text for analysis",
        key="text_input_widget",
    )
    # Mirror the current text into session state.
    st.session_state.text_input = text

    # Analyze button.
    if st.button("Analyze Text", type="primary"):
        if not text.strip():
            st.warning("Please enter text to analyze")
        elif ner_pipeline is None:
            # load_model() returns None on failure; say so explicitly rather
            # than surfacing "'NoneType' object is not callable".
            st.error("The selected model could not be loaded, so the text cannot be analyzed.")
        else:
            with st.spinner("Analyzing..."):
                try:
                    entities = ner_pipeline(text)
                except Exception as e:
                    st.error(f"Error during analysis: {str(e)}")
                else:
                    st.session_state.entities = entities
                    # Remember exactly what was analysed (text + model) so the
                    # results are never rendered against different text, and
                    # so "analysed, nothing found" can be told apart from
                    # "not analysed yet".
                    st.session_state["analysis"] = {
                        "text": text,
                        "model": model_name,
                        "entities": entities,
                    }
                    st.success("Analysis completed!")

with col2:
    st.header("NER Results")

    analysis = st.session_state["analysis"] if "analysis" in st.session_state else None
    # Results are shown only when they belong to the text and model currently
    # on screen; entity offsets are meaningless for any other text.
    results_are_current = (
        analysis is not None
        and analysis["text"] == text
        and analysis["model"] == model_name
    )

    if not results_are_current:
        st.info("Please enter text and click 'Analyze Text'")
    elif not analysis["entities"]:
        st.info("No entities detected")
    else:
        entities = analysis["entities"]

        # Highlighted text.
        st.markdown("### Highlighted Text")
        highlighted_text = highlight_entities(analysis["text"], entities)
        st.markdown(highlighted_text, unsafe_allow_html=True)

        # Entity statistics: one coloured badge per entity type.
        st.markdown("### Entity Statistics")
        entity_counts = defaultdict(int)
        for entity in entities:
            entity_counts[entity["entity_group"]] += 1
        for entity_type, count in entity_counts.items():
            color = ENTITY_COLORS.get(entity_type, ENTITY_COLORS["default"])
            st.markdown(
                f'<span style="background-color: {color}; padding: 4px 8px; '
                f'border-radius: 4px; margin-right: 8px; color: black;">'
                f'{html.escape(str(entity_type))}: {count}'
                f'</span>',
                unsafe_allow_html=True,
            )

        # Detailed entity list. The word and the label are escaped because
        # they are rendered with unsafe_allow_html and the word comes from
        # the text the user entered.
        st.markdown("### Entity Details")
        for i, entity in enumerate(entities, start=1):
            color = ENTITY_COLORS.get(entity["entity_group"], ENTITY_COLORS["default"])
            st.markdown(
                f"{i}. **{html.escape(str(entity['word']))}** - "
                f"<span style='color: {color};'>{html.escape(str(entity['entity_group']))}</span> "
                f"(confidence: {entity['score']:.3f})",
                unsafe_allow_html=True,
            )
# Footer: a horizontal rule followed by the usage instructions.
st.markdown("---")
_instruction_lines = [
    "### Instructions",
    "1. Select a domain-specific NER model from the left sidebar",
    "2. Enter or paste medical text in the input box",
    "3. Click the 'Analyze Text' button to run the model",
    "4. View the entity recognition results on the right",
    "",  # blank line separating the numbered list from the closing note
    "Different colored highlights represent different entity types. Hover over entities to see type and confidence.",
]
st.markdown("\n".join(_instruction_lines))
# Hide Streamlit's default page elements (#MainMenu, footer, header) by
# injecting a small <style> block.
_HIDDEN_SELECTORS = ("#MainMenu", "footer", "header")
hide_st_style = "\n".join(
    ["", "<style>"]
    + [f"{selector} {{visibility: hidden;}}" for selector in _HIDDEN_SELECTORS]
    + ["</style>", ""]
)
st.markdown(hide_st_style, unsafe_allow_html=True)