# app.py ── Biomedical NER demo (full vs. LoRA/CRF) # -------------------------------------------------- from __future__ import annotations import html, logging, warnings from functools import lru_cache import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForTokenClassification from peft import PeftModel from huggingface_hub import hf_hub_download # ← missing import added # ─────────── silence library warnings ─────────── warnings.filterwarnings("ignore", category=UserWarning, module="peft") logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) # ─────────── constants ─────────── BASE = "dmis-lab/biobert-base-cased-v1.2" REPO = "vishalvaka/biobert-finetuned-ner" # one repo; variants in sub-folders VARIANTS: dict[str, tuple[str, str]] = { "Full fine-tune" : ("full", "full"), "LoRA-r32" : ("lora-r32", "lora"), "LoRA-r32-fast" : ("lora-r32-fast", "lora"), "LoRA-r16-CRF" : ("lora-r16-crf", "lora"), "LoRA-r16-CRF-long" : ("lora-r16-crf-long", "lora"), } LABELS = ["O", "B-Chemical", "I-Chemical", "B-Disease", "I-Disease"] id2label = {i: lab for i, lab in enumerate(LABELS)} label2id = {lab: i for i, lab in id2label.items()} # ─────────── model loader (cached) ─────────── @lru_cache(maxsize=None) def load_model(folder: str, mode: str): """ Returns (tokenizer, model) cached per variant. • mode == "full" → load full checkpoint from sub-folder • mode == "lora" → load BASE + LoRA adapter from sub-folder """ if mode == "full": model = AutoModelForTokenClassification.from_pretrained( REPO, subfolder=folder ) tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder) # ensure human-readable label maps model.config.id2label, model.config.label2id = id2label, label2id return tok, model.eval() # ---------- LoRA (with or without CRF) ---------- base = AutoModelForTokenClassification.from_pretrained( BASE, num_labels=len(LABELS), id2label=id2label, label2id=label2id ) model = PeftModel.from_pretrained(base, REPO, subfolder=folder) tok = AutoTokenizer.from_pretrained(BASE) # attach CRF/classifier weights **iff the file exists** try: if "crf" in folder: # only these have it head_path = hf_hub_download(REPO, "non_encoder_head.pth", subfolder=folder, repo_type="model") extra = torch.load(head_path, map_location="cpu") model.load_state_dict(extra, strict=False) except Exception as e: warnings.warn(f"[{folder}] couldn’t load CRF head: {e}") return tok, model.eval() # ─────────── helper to build HTML output ─────────── def build_html(tokens: list[str], labels: list[str]) -> str: """Merge WordPieces and contiguous I-tokens.""" segments: list[tuple[str | None, str]] = [] # (entity_tag, text) cur_tag, buf = None, "" for tok, lab in zip(tokens, labels): tag = None if lab == "O" else lab.split("-")[-1] # Chemical / Disease text = tok[2:] if tok.startswith("##") else tok # drop ## continuation = tok.startswith("##") or lab.startswith("I-") if tag == cur_tag and continuation: buf += text else: if buf: segments.append((cur_tag, buf)) buf, cur_tag = text, tag if buf: segments.append((cur_tag, buf)) html_out, first = "", True for tag, chunk in segments: spacer = "" if first else " " first = False chunk = html.escape(chunk) html_out += spacer + ( chunk if tag is None else f'{chunk}' ) return html_out # ─────────── Gradio inference fn ─────────── def ner(text: str, variant: str): folder, mode = VARIANTS[variant] tok, model = load_model(folder, mode) enc = tok(text, return_tensors="pt") with torch.no_grad(): logits = model(**enc).logits.squeeze(0) ids = logits.argmax(dim=-1).tolist()[1:-1] # drop CLS/SEP tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])[1:-1] labels = [model.config.id2label[i] for i in ids] return build_html(tokens, labels) # ─────────── UI definition ─────────── CSS = """ span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px} span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px} """ demo = gr.Interface( fn=ner, inputs=[ gr.Textbox(lines=7, label="Paste biomedical text"), gr.Radio(list(VARIANTS.keys()), value="LoRA-r32", label="Model variant"), ], outputs=gr.HTML(label="Tagged output"), examples=[ ["Intravenous administration of infliximab significantly reduced C-reactive protein levels and improved remission rates in Crohn's disease patients."], ], css=CSS, theme=gr.themes.Soft(), # quiet built-in theme, no 404 cache_examples=False, title="Biomedical NER — Full vs. LoRA / CRF", description="Toggle a variant and watch **Chemical** / **Disease** entities light up. " "Full checkpoints for CRF models, compact adapters for LoRA runs.", ) if __name__ == "__main__": demo.launch()