vishalvaka committed on
Commit
aaf7fac
·
1 Parent(s): 4913f36

made some fixes for CRF models

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +104 -44
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv/*
2
+ .venv
app.py CHANGED
@@ -1,59 +1,120 @@
1
- import gradio as gr, torch
 
 
 
 
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForTokenClassification
3
  from peft import PeftModel
 
 
 
 
 
4
 
5
- # --------------------------------------------------------
6
  BASE = "dmis-lab/biobert-base-cased-v1.2"
7
- REPO = "vishalvaka/biobert-finetuned-ner"
8
-
9
- VARIANTS = {
10
- "Full fine-tune" : ("full", "full"),
11
- "LoRA-r32" : ("lora-r32", "lora"),
12
- "LoRA-r32-fast" : ("lora-r32-fast", "lora"),
13
- "LoRA-r16-CRF" : ("lora-r16-crf", "lora"),
14
- "LoRA-r16-CRF-long" : ("lora-r16-crf-long", "lora"),
15
  }
16
- # (folder_name, loader_type)
17
 
18
- @gr.memoize()
19
- def load_model(folder, mode):
 
 
 
 
 
20
  """
21
- folder = sub-directory inside the model repo
22
- mode = "full" | "lora"
 
 
23
  """
24
  if mode == "full":
25
  model = AutoModelForTokenClassification.from_pretrained(
26
  REPO, subfolder=folder
27
  )
28
  tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
29
- else: # LoRA – reattach to BASE
30
- base = AutoModelForTokenClassification.from_pretrained(BASE)
31
- model = PeftModel.from_pretrained(base, REPO, subfolder=folder)
32
- tok = AutoTokenizer.from_pretrained(BASE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  return tok, model.eval()
34
 
35
- def ner(text, variant):
36
- folder, mode = VARIANTS[variant]
37
- tok, model = load_model(folder, mode)
 
 
38
 
39
- tokens = tok(text, return_tensors="pt")
40
- with torch.no_grad():
41
- logits = model(**tokens).logits.squeeze(0)
42
- ids = logits.argmax(dim=-1).tolist()[1:-1] # drop CLS/SEP
43
- words = tok.convert_ids_to_tokens(tokens["input_ids"][0])[1:-1]
44
-
45
- # colours: B-Chemical / B-Disease / I-* already encoded in labels
46
- label_map = model.config.id2label
47
- html = ""
48
- for w, i in zip(words, ids):
49
- lab = label_map[i]
50
- if lab == "O":
51
- html += " " + w
52
  else:
53
- entity = lab.split("-")[-1] # Chemical | Disease
54
- html += f' <span class="{entity}">{w}</span>'
55
- return html.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
57
  CSS = """
58
  span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
59
  span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px}
@@ -67,14 +128,13 @@ demo = gr.Interface(
67
  ],
68
  outputs=gr.HTML(label="Tagged output"),
69
  examples=[
70
- ["Airway inflammation was reduced by salbutamol in cystic fibrosis patients."],
71
  ],
72
  css=CSS,
73
- title="Biomedical NER – full vs. LoRA",
74
- description=(
75
- "Compare a full fine-tune with several parameter-efficient LoRA/CRF adapters.\n"
76
- "Entities are highlighted inline. Pick a variant and paste any PubMed-style text."
77
- ),
78
  )
79
 
80
  if __name__ == "__main__":
 
1
+ # app.py ── Biomedical NER demo (full vs. LoRA/CRF)
2
+ # --------------------------------------------------
3
+ from __future__ import annotations
4
+ import html, logging, warnings
5
+ from functools import lru_cache
6
+
7
+ import gradio as gr
8
+ import torch
9
  from transformers import AutoTokenizer, AutoModelForTokenClassification
10
  from peft import PeftModel
11
+ from huggingface_hub import hf_hub_download # ← missing import added
12
+
13
+ # ─────────── silence library warnings ───────────
14
+ warnings.filterwarnings("ignore", category=UserWarning, module="peft")
15
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
16
 
17
+ # ─────────── constants ───────────
18
  BASE = "dmis-lab/biobert-base-cased-v1.2"
19
+ REPO = "vishalvaka/biobert-finetuned-ner" # one repo; variants in sub-folders
20
+
21
+ VARIANTS: dict[str, tuple[str, str]] = {
22
+ "Full fine-tune" : ("full", "full"),
23
+ "LoRA-r32" : ("lora-r32", "lora"),
24
+ "LoRA-r32-fast" : ("lora-r32-fast", "lora"),
25
+ "LoRA-r16-CRF" : ("lora-r16-crf", "lora"),
26
+ "LoRA-r16-CRF-long" : ("lora-r16-crf-long", "lora"),
27
  }
 
28
 
29
+ LABELS = ["O", "B-Chemical", "I-Chemical", "B-Disease", "I-Disease"]
30
+ id2label = {i: lab for i, lab in enumerate(LABELS)}
31
+ label2id = {lab: i for i, lab in id2label.items()}
32
+
33
+ # ─────────── model loader (cached) ───────────
34
+ @lru_cache(maxsize=None)
35
+ def load_model(folder: str, mode: str):
36
  """
37
+ Returns (tokenizer, model) cached per variant.
38
+
39
+ • mode == "full" → load full checkpoint from sub-folder
40
+ • mode == "lora" → load BASE + LoRA adapter from sub-folder
41
  """
42
  if mode == "full":
43
  model = AutoModelForTokenClassification.from_pretrained(
44
  REPO, subfolder=folder
45
  )
46
  tok = AutoTokenizer.from_pretrained(REPO, subfolder=folder)
47
+ # ensure human-readable label maps
48
+ model.config.id2label, model.config.label2id = id2label, label2id
49
+ return tok, model.eval()
50
+
51
+ # ---------- LoRA (with or without CRF) ----------
52
+ base = AutoModelForTokenClassification.from_pretrained(
53
+ BASE, num_labels=len(LABELS), id2label=id2label, label2id=label2id
54
+ )
55
+ model = PeftModel.from_pretrained(base, REPO, subfolder=folder)
56
+ tok = AutoTokenizer.from_pretrained(BASE)
57
+
58
+ # attach CRF/classifier weights **iff the file exists**
59
+ try:
60
+ if "crf" in folder: # only these have it
61
+ head_path = hf_hub_download(REPO,
62
+ "non_encoder_head.pth",
63
+ subfolder=folder,
64
+ repo_type="model")
65
+ extra = torch.load(head_path, map_location="cpu")
66
+ model.load_state_dict(extra, strict=False)
67
+ except Exception as e:
68
+ warnings.warn(f"[{folder}] couldn’t load CRF head: {e}")
69
+
70
  return tok, model.eval()
71
 
72
+ # ─────────── helper to build HTML output ───────────
73
+ def build_html(tokens: list[str], labels: list[str]) -> str:
74
+ """Merge WordPieces and contiguous I-tokens."""
75
+ segments: list[tuple[str | None, str]] = [] # (entity_tag, text)
76
+ cur_tag, buf = None, ""
77
 
78
+ for tok, lab in zip(tokens, labels):
79
+ tag = None if lab == "O" else lab.split("-")[-1] # Chemical / Disease
80
+ text = tok[2:] if tok.startswith("##") else tok # drop ##
81
+ continuation = tok.startswith("##") or lab.startswith("I-")
82
+
83
+ if tag == cur_tag and continuation:
84
+ buf += text
 
 
 
 
 
 
85
  else:
86
+ if buf:
87
+ segments.append((cur_tag, buf))
88
+ buf, cur_tag = text, tag
89
+ if buf:
90
+ segments.append((cur_tag, buf))
91
+
92
+ html_out, first = "", True
93
+ for tag, chunk in segments:
94
+ spacer = "" if first else " "
95
+ first = False
96
+ chunk = html.escape(chunk)
97
+ html_out += spacer + (
98
+ chunk if tag is None else f'<span class="{tag}">{chunk}</span>'
99
+ )
100
+ return html_out
101
+
102
+ # ─────────── Gradio inference fn ───────────
103
+ def ner(text: str, variant: str):
104
+ folder, mode = VARIANTS[variant]
105
+ tok, model = load_model(folder, mode)
106
+
107
+ enc = tok(text, return_tensors="pt")
108
+ with torch.no_grad():
109
+ logits = model(**enc).logits.squeeze(0)
110
+
111
+ ids = logits.argmax(dim=-1).tolist()[1:-1] # drop CLS/SEP
112
+ tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])[1:-1]
113
+ labels = [model.config.id2label[i] for i in ids]
114
+
115
+ return build_html(tokens, labels)
116
 
117
+ # ─────────── UI definition ───────────
118
  CSS = """
119
  span.Chemical {background:#ffddff; padding:2px 4px; border-radius:4px}
120
  span.Disease {background:#ffdddd; padding:2px 4px; border-radius:4px}
 
128
  ],
129
  outputs=gr.HTML(label="Tagged output"),
130
  examples=[
131
+ ["Intravenous administration of infliximab significantly reduced C-reactive protein levels and improved remission rates in Crohn’s disease patients."],
132
  ],
133
  css=CSS,
134
+ theme=gr.themes.Soft(), # quiet built-in theme, no 404
135
+ title="Biomedical NER — Full vs. LoRA / CRF",
136
+ description="Toggle a variant and watch **Chemical** / **Disease** entities light up. "
137
+ "Full checkpoints for CRF models, compact adapters for LoRA runs.",
 
138
  )
139
 
140
  if __name__ == "__main__":