Spaces:
Sleeping
Sleeping
Commit
·
aaf7fac
1
Parent(s):
4913f36
made some fixes for CRF models
Browse files- .gitignore +2 -0
- app.py +104 -44
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/*
|
| 2 |
+
.venv
|
app.py
CHANGED
|
@@ -1,59 +1,120 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
| 3 |
from peft import PeftModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
#
|
| 6 |
BASE = "dmis-lab/biobert-base-cased-v1.2"
|
| 7 |
-
REPO = "vishalvaka/biobert-finetuned-ner"
|
| 8 |
-
|
| 9 |
-
VARIANTS = {
|
| 10 |
-
"Full fine-tune"
|
| 11 |
-
"LoRA-r32"
|
| 12 |
-
"LoRA-r32-fast"
|
| 13 |
-
"LoRA-r16-CRF"
|
| 14 |
-
"LoRA-r16-CRF-long"
|
| 15 |
}
|
| 16 |
-
# (folder_name, loader_type)
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
"""
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
if mode == "full":
|
| 25 |
model = AutoModelForTokenClassification.from_pretrained(
|
| 26 |
REPO, subfolder=folder
|
| 27 |
)
|
| 28 |
tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
return tok, model.eval()
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
label_map = model.config.id2label
|
| 47 |
-
html = ""
|
| 48 |
-
for w, i in zip(words, ids):
|
| 49 |
-
lab = label_map[i]
|
| 50 |
-
if lab == "O":
|
| 51 |
-
html += " " + w
|
| 52 |
else:
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
|
|
|
| 57 |
CSS = """
|
| 58 |
span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
|
| 59 |
span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px}
|
|
@@ -67,14 +128,13 @@ demo = gr.Interface(
|
|
| 67 |
],
|
| 68 |
outputs=gr.HTML(label="Tagged output"),
|
| 69 |
examples=[
|
| 70 |
-
["
|
| 71 |
],
|
| 72 |
css=CSS,
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
),
|
| 78 |
)
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
# app.py — Biomedical NER demo (full vs. LoRA/CRF)
|
| 2 |
+
# --------------------------------------------------
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
import html, logging, warnings
|
| 5 |
+
from functools import lru_cache
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import torch
|
| 9 |
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
| 10 |
from peft import PeftModel
|
| 11 |
+
from huggingface_hub import hf_hub_download  # ← missing import added
|
| 12 |
+
|
# ----------- silence noisy library warnings -----------
warnings.filterwarnings("ignore", category=UserWarning, module="peft")
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# ----------- constants -----------
# Backbone checkpoint shared by every LoRA variant.
BASE = "dmis-lab/biobert-base-cased-v1.2"
REPO = "vishalvaka/biobert-finetuned-ner"  # one repo; variants in sub-folders

# UI name -> (repo sub-folder, loader type). Loader type selects the branch
# in load_model: "full" = complete checkpoint, "lora" = BASE + adapter.
VARIANTS: dict[str, tuple[str, str]] = {
    "Full fine-tune": ("full", "full"),
    "LoRA-r32": ("lora-r32", "lora"),
    "LoRA-r32-fast": ("lora-r32-fast", "lora"),
    "LoRA-r16-CRF": ("lora-r16-crf", "lora"),
    "LoRA-r16-CRF-long": ("lora-r16-crf-long", "lora"),
}

# BIO tag set used by every variant; index order must match training.
LABELS = ["O", "B-Chemical", "I-Chemical", "B-Disease", "I-Disease"]
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in enumerate(LABELS)}
# ----------- model loader (cached) -----------
@lru_cache(maxsize=None)
def load_model(folder: str, mode: str):
    """Return a cached ``(tokenizer, model)`` pair for one variant.

    * ``mode == "full"``  -- load a complete checkpoint from the repo
      sub-folder *folder*.
    * ``mode == "lora"``  -- load BASE, attach the LoRA adapter from the
      sub-folder, and (for CRF variants) best-effort load the extra
      classifier/CRF head weights.

    The model is returned in eval mode; lru_cache keeps one instance per
    (folder, mode) so Gradio callbacks never re-download.
    """
    if mode == "full":
        model = AutoModelForTokenClassification.from_pretrained(
            REPO, subfolder=folder
        )
        tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
        # ensure human-readable label maps
        model.config.id2label, model.config.label2id = id2label, label2id
        return tok, model.eval()

    # ---------- LoRA (with or without CRF) ----------
    backbone = AutoModelForTokenClassification.from_pretrained(
        BASE, num_labels=len(LABELS), id2label=id2label, label2id=label2id
    )
    model = PeftModel.from_pretrained(backbone, REPO, subfolder=folder)
    tok = AutoTokenizer.from_pretrained(BASE)

    # Attach the CRF/classifier weights only when the variant ships them;
    # failure is non-fatal (the adapter alone still produces logits).
    try:
        if "crf" in folder:  # only these variants publish a head file
            head_path = hf_hub_download(
                REPO,
                "non_encoder_head.pth",
                subfolder=folder,
                repo_type="model",
            )
            # NOTE(review): torch.load unpickles — acceptable only because
            # the repo is the author's own; confirm before pointing elsewhere.
            head_state = torch.load(head_path, map_location="cpu")
            model.load_state_dict(head_state, strict=False)
    except Exception as err:
        warnings.warn(f"[{folder}] couldn’t load CRF head: {err}")

    return tok, model.eval()
|
| 71 |
|
# ----------- helper to build HTML output -----------
def build_html(tokens: list[str], labels: list[str]) -> str:
    """Render (token, BIO-label) pairs as HTML, merging sub-words.

    A WordPiece starting with ``##`` or a label starting with ``I-`` extends
    the current run when its entity type matches; otherwise a new run starts.
    Entity runs are wrapped in ``<span class="...">`` (Chemical / Disease),
    plain runs are emitted escaped, and runs are joined with single spaces.
    """
    segments = []      # completed (entity-or-None, merged text) runs
    run_tag = None     # entity type of the run being built, None for "O"
    run_text = ""

    for piece, label in zip(tokens, labels):
        entity = None if label == "O" else label.split("-")[-1]
        word = piece[2:] if piece.startswith("##") else piece  # drop ##
        extends = piece.startswith("##") or label.startswith("I-")

        if entity == run_tag and extends:
            run_text += word
        else:
            if run_text:
                segments.append((run_tag, run_text))
            run_text, run_tag = word, entity

    if run_text:
        segments.append((run_tag, run_text))

    rendered = []
    for entity, chunk in segments:
        safe = html.escape(chunk)
        rendered.append(
            safe if entity is None else f'<span class="{entity}">{safe}</span>'
        )
    return " ".join(rendered)
# ----------- Gradio inference fn -----------
def ner(text: str, variant: str):
    """Tag *text* with the chosen model variant and return HTML markup.

    Parameters
    ----------
    text : raw sentence(s) from the textbox.
    variant : key into ``VARIANTS`` selected in the UI.

    Returns the entity-highlighted HTML produced by ``build_html``.
    """
    folder, mode = VARIANTS[variant]
    tok, model = load_model(folder, mode)

    # FIX: truncate to the model's maximum sequence length (512 for BERT
    # backbones). Without this, longer pastes crash the position-embedding
    # lookup instead of degrading gracefully.
    enc = tok(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**enc).logits.squeeze(0)

    ids = logits.argmax(dim=-1).tolist()[1:-1]  # drop CLS/SEP predictions
    tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])[1:-1]
    labels = [model.config.id2label[i] for i in ids]

    return build_html(tokens, labels)
|
| 116 |
|
| 117 |
+
# βββββββββββ UI definition βββββββββββ
|
| 118 |
CSS = """
|
| 119 |
span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
|
| 120 |
span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px}
|
|
|
|
| 128 |
],
|
| 129 |
outputs=gr.HTML(label="Tagged output"),
|
| 130 |
examples=[
|
| 131 |
+
["Intravenous administration of infliximab significantly reduced C-reactive protein levels and improved remission rates in Crohnβs disease patients."],
|
| 132 |
],
|
| 133 |
css=CSS,
|
| 134 |
+
theme=gr.themes.Soft(), # quiet built-in theme, no 404
|
| 135 |
+
title="Biomedical NER β Full vs. LoRA / CRF",
|
| 136 |
+
description="Toggle a variant and watch **Chemical** / **Disease** entities light up. "
|
| 137 |
+
"Full checkpoints for CRF models, compact adapters for LoRA runs.",
|
|
|
|
| 138 |
)
|
| 139 |
|
| 140 |
if __name__ == "__main__":
|