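# NOTE: the lines below are Hugging Face file-viewer residue captured with the
# source; they are not part of app.py and are kept only as comments.
#   manu02's picture
#   Initial commit
#   8b72e45 verified
#   raw / history blame
#   18.5 kB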
# app.py
"""
Gradio word-level attention visualizer with:
- Paragraph-style wrapping and semi-transparent backgrounds per word
- Proper detokenization to words (regex)
- Ability to pick from many causal LMs
- Trailing EOS/PAD special tokens removed (no <|endoftext|> shown)
- FIX: safely reset Radio with value=None to avoid Gradio choices error
"""
import re
from typing import List, Tuple
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import numpy as np
# =========================
# Config
# =========================
# Model identifiers offered in the "Causal LM" dropdown. The first entry is
# the default that gets loaded at start-up.
ALLOWED_MODELS = [
    # ---- GPT-2 family
    "gpt2", "distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
    # ---- EleutherAI (Neo/J/NeoX/Pythia)
    "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B",
    "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neox-20b",
    "EleutherAI/pythia-70m", "EleutherAI/pythia-160m", "EleutherAI/pythia-410m",
    "EleutherAI/pythia-1b", "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b",
    "EleutherAI/pythia-6.9b", "EleutherAI/pythia-12b",
    # ---- Meta OPT
    "facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b",
    "facebook/opt-6.7b", "facebook/opt-13b", "facebook/opt-30b",
    # ---- Mistral
    "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.3", "mistralai/Mistral-7B-Instruct-v0.2",
    # ---- TinyLlama / OpenLLaMA
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
    "openlm-research/open_llama_3b", "openlm-research/open_llama_7b",
    # ---- Microsoft Phi
    "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2",
    # ---- Qwen
    "Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B", "Qwen/Qwen1.5-7B",
    "Qwen/Qwen2-1.5B", "Qwen/Qwen2-7B",
    # ---- MPT
    "mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct",
    # ---- Falcon
    "tiiuae/falcon-7b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b",
    # ---- Cerebras GPT
    "cerebras/Cerebras-GPT-111M", "cerebras/Cerebras-GPT-256M",
    "cerebras/Cerebras-GPT-590M", "cerebras/Cerebras-GPT-1.3B", "cerebras/Cerebras-GPT-2.7B",
]
# Compute device: "cuda" when torch reports an available GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
# Process-wide model/tokenizer pair; both start empty and are filled in by
# load_model().
model = None
tokenizer = None
# Word regex (words + punctuation): a run of word characters with an optional
# apostrophe suffix (e.g. "don't"), or any single character that is neither a
# word character nor whitespace.
WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")
# =========================
# Model loading
# =========================
def _safe_set_attn_impl(m):
    """Best-effort: set ``m.config._attn_implementation`` to ``"eager"``.

    Any exception raised while reading ``m.config`` or assigning the
    attribute is swallowed on purpose, so objects without a writable config
    are simply left untouched. Always returns ``None``.
    """
    try:
        cfg = m.config
        setattr(cfg, "_attn_implementation", "eager")
    except Exception:
        # Deliberately ignored: this is an optional tweak, not a requirement.
        pass
def load_model(model_name: str):
    """Load the tokenizer and causal LM for ``model_name`` into the module globals.

    The previously loaded pair is released first so its memory can be
    reclaimed before the new weights are allocated. The new tokenizer and
    model are built in local variables and only published to the globals
    once both are fully initialised; if loading raises, ``model`` and
    ``tokenizer`` are left as ``None`` (never unbound, never mismatched).

    Args:
        model_name: Hub identifier or local path passed to ``from_pretrained``.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` /
        ``AutoModelForCausalLM.from_pretrained`` raise for an unknown,
        unavailable or too-large model.
    """
    global model, tokenizer
    import gc  # local import: only needed here, keeps file-level imports untouched

    # Rebind instead of `del`: deleting a global leaves the name unbound, so a
    # failed load would turn every later access into a NameError.
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    new_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # Ensure pad token id
    if new_tokenizer.pad_token_id is None:
        if new_tokenizer.eos_token_id is not None:
            new_tokenizer.pad_token_id = new_tokenizer.eos_token_id
        else:
            new_tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    # Ask for the eager attention implementation at construction time so
    # attention weights can be returned; library versions that do not know the
    # keyword reject it with TypeError, in which case fall back to a plain load.
    try:
        new_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
    except TypeError:
        new_model = AutoModelForCausalLM.from_pretrained(model_name)
    _safe_set_attn_impl(new_model)

    # Grow the embedding matrix only when the pad token id does not fit in it.
    if (
        hasattr(new_model, "resize_token_embeddings")
        and new_tokenizer.pad_token_id >= new_model.get_input_embeddings().num_embeddings
    ):
        new_model.resize_token_embeddings(len(new_tokenizer))
    new_model.eval()
    new_model.to(device)

    # Publish both together so the globals always describe a matching pair.
    tokenizer = new_tokenizer
    model = new_model
def model_heads_layers():
    """Return ``(num_layers, num_heads)`` for the globally loaded model.

    Each value is read from ``model.config`` (``num_hidden_layers`` and
    ``num_attention_heads``). If the attribute is missing, cannot be
    converted to ``int``, or the lookup fails for any other reason, 12 is
    used instead. Both numbers are clamped to be at least 1.
    """
    def _config_int(attr_name):
        # Any failure (no model, no config, non-numeric value) falls back to 12.
        try:
            return int(getattr(model.config, attr_name, 12))
        except Exception:
            return 12

    n_layers = _config_int("num_hidden_layers")
    n_heads = _config_int("num_attention_heads")
    return max(1, n_layers), max(1, n_heads)
# =========================
# Attention utils
# =========================
def get_attention_for_token_layer(
    attentions,
    token_index,
    layer_index,
    batch_index=0,
    head_index=0,
    mean_across_layers=True,
    mean_across_heads=True,
):
    """
    Return the attention row for one generation step as a 1-D tensor (k,).

    attentions: tuple length = #generated tokens
    attentions[t] -> tuple of len = num_layers, each: (batch, heads, q, k)

    Args:
        attentions: per-step, per-layer attention tensors as described above.
        token_index: index of the generation step to read.
        layer_index: layer to read when ``mean_across_layers`` is False.
        batch_index: batch element to read.
        head_index: head to read when ``mean_across_heads`` is False.
        mean_across_layers: average over all layers instead of picking one.
        mean_across_heads: average over all heads instead of picking one.

    Returns:
        The attention of the LAST query position over all k key positions.
        When q == 1 this is identical to squeezing the query axis; when a
        step carries several query positions (q > 1), only the last row is
        returned so callers always receive a 1-D tensor.
    """
    token_attention = attentions[token_index]
    if mean_across_layers:
        layer_attention = torch.stack(token_attention).mean(dim=0)  # (batch, heads, q, k)
    else:
        layer_attention = token_attention[int(layer_index)]  # (batch, heads, q, k)
    batch_attention = layer_attention[int(batch_index)]  # (heads, q, k)
    if mean_across_heads:
        head_attention = batch_attention.mean(dim=0)  # (q, k)
    else:
        head_attention = batch_attention[int(head_index)]  # (q, k)
    # squeeze(0) only dropped the query axis when q == 1 and silently returned
    # a 2-D tensor otherwise; indexing the last query row handles both cases.
    return head_attention[-1]  # (k,)
# =========================
# Tokens -> words mapping
# =========================
def _words_and_map_from_tokens(gen_token_ids: List[int]) -> Tuple[List[str], List[int]]:
    """
    Split the generated tokens into words and map each word to a token index.

    The ids are converted back to text with the global tokenizer, the text is
    split with ``WORD_RE``, and the same text is re-tokenized with offset
    mapping so that every word can be matched to the tokens overlapping it.

    Returns:
        (words, word2tok) where ``word2tok[i]`` is the index (relative to the
        generated tokens) of the LAST token overlapping word ``i``. When no
        token overlaps a word, the token just before the current scan
        position is used, clamped to the valid range.
    """
    if not gen_token_ids:
        return [], []

    token_strings = tokenizer.convert_ids_to_tokens(gen_token_ids)
    text = tokenizer.convert_tokens_to_string(token_strings)

    # One regex pass provides both the word strings and their character spans.
    matches = list(WORD_RE.finditer(text))
    words = [m.group(0) for m in matches]

    encoding = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    offsets = encoding["offset_mapping"]
    # Re-tokenization may not reproduce the original token count; only trust
    # indices that exist in both sequences.
    limit = min(len(offsets), len(gen_token_ids))

    word2tok: List[int] = []
    cursor = 0  # shared across words: tokens are consumed left to right
    for match in matches:
        w_start, w_end = match.span()
        hit = None
        while cursor < limit:
            t_start, t_end = offsets[cursor]
            if t_end > w_start and t_start < w_end:
                # Token overlaps the word: remember it and keep scanning.
                hit = cursor
                cursor += 1
            elif t_end <= w_start:
                # Token lies entirely before the word: skip it.
                cursor += 1
            else:
                # Token starts at/after the word end: it belongs to a later word.
                break
        if hit is None:
            hit = max(0, min(limit - 1, cursor - 1))
        word2tok.append(int(hit))
    return words, word2tok
# =========================
# Helpers
# =========================
def _strip_trailing_special(ids: List[int]) -> List[int]:
    """Return ``ids`` without its trailing run of special tokens.

    Special ids come from the global tokenizer's ``all_special_ids`` (treated
    as empty when missing or falsy). Only the tail is trimmed; special ids
    that appear before a regular token are kept.
    """
    special_ids = set(getattr(tokenizer, "all_special_ids", []) or [])
    keep = len(ids)
    for token_id in reversed(ids):
        if token_id not in special_ids:
            break
        keep -= 1
    return ids[:keep]
def clamp01(x: float) -> float:
    """Convert ``x`` to ``float`` and clamp it to the closed interval [0, 1]."""
    value = float(x)
    if value < 0:
        return 0.0
    if value > 1:
        return 1.0
    return value
# =========================
# Visualization (WORD-LEVEL)
# =========================
def generate_word_visualization(words: List[str],
                                abs_word_ends: List[int],
                                attention_values: np.ndarray,
                                selected_token_abs_idx: int) -> str:
    """
    Paragraph-style visualization over words.
    For each word, aggregate attention over its composing tokens (sum),
    normalize across words, and render opacity as a semi-transparent background.

    Args:
        words: words to render, in order.
        abs_word_ends: for each word, the index into ``attention_values`` of
            its last token. Word ``i`` covers the indices after the previous
            word's end up to and including its own end.
        attention_values: 1-D array of attention weights.
        selected_token_abs_idx: index of the selected token; the first word
            whose end is >= this index gets a highlighted border (the last
            word if none qualifies).

    Returns:
        An HTML fragment. Word text is HTML-escaped, so generated characters
        such as ``<``, ``>`` and ``&`` are shown literally instead of being
        interpreted as markup.
    """
    from html import escape  # local import: keeps the file-level imports untouched

    if not words or attention_values is None or len(attention_values) == 0:
        return (
            "<div style='width:100%;'>"
            " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
            " <div style='color:#ddd;'>No attention values.</div>"
            " </div>"
            "</div>"
        )

    last_valid = len(attention_values) - 1

    # Sum attention per word over [start, end], both clamped to the array.
    # NOTE(review): the first word starts at index 0, so every index before
    # its end is counted towards it — confirm that this is intended when the
    # ends are offset by a prompt length.
    word_scores = []
    prev_end = None
    for end in abs_word_ends:
        start = 0 if prev_end is None else min(prev_end + 1, end)
        if start > end:
            start = end
        s = max(0, min(start, last_valid))
        e = max(0, min(end, last_valid))
        if e < s:
            s, e = e, s
        word_scores.append(float(attention_values[s:e + 1].sum()))
        prev_end = end

    # The 0.1 floor keeps uniformly tiny scores from being scaled to full opacity.
    max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)

    # Which word holds the selected token?
    selected_word_idx = next(
        (i for i, end in enumerate(abs_word_ends) if selected_token_abs_idx <= end),
        len(abs_word_ends) - 1 if abs_word_ends else None,
    )

    spans = []
    # zip() stops at the shorter sequence, so a length mismatch between words
    # and abs_word_ends renders the common prefix instead of raising IndexError.
    for i, (w, score) in enumerate(zip(words, word_scores)):
        alpha = min(1.0, score / max_attn) if max_attn > 0 else 0.0
        bg = f"rgba(66,133,244,{alpha:.3f})"
        border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
        spans.append(
            f"<span style='display:inline-block;background:{bg};border:{border};"
            f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>"
            f"{escape(str(w))}</span>"
        )
    return (
        "<div style='width:100%;'>"
        " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
        " <div style='white-space:normal;line-height:1.8;'>"
        f" {''.join(spans)}"
        " </div>"
        " </div>"
        "</div>"
    )
# =========================
# Core functions
# =========================
def run_generation(prompt, max_new_tokens, temperature, top_p):
    """Generate and prepare word-level selector + initial visualization.

    Args:
        prompt: prompt text; ``None`` is treated as an empty string.
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature. Values <= 0 select greedy decoding,
            because sampling requires a strictly positive temperature.
        top_p: nucleus-sampling threshold. Values <= 0 also select greedy
            decoding.

    Returns:
        A dict keyed by the Gradio components defined in the UI section of
        this module (states, word selector, HTML panel) with their new values.
    """
    inputs = tokenizer(prompt or "", return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]

    temp = float(temperature)
    nucleus = float(top_p)
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": tokenizer.pad_token_id,
        "output_attentions": True,
        "return_dict_in_generate": True,
    }
    if temp > 0.0 and nucleus > 0.0:
        gen_kwargs.update(do_sample=True, temperature=temp, top_p=nucleus)
    else:
        # The UI sliders allow 0.0; sampling with a non-positive temperature
        # is rejected by generate(), so decode greedily instead of crashing.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    all_token_ids = outputs.sequences[0].tolist()
    generated_token_ids = _strip_trailing_special(all_token_ids[prompt_len:])
    # Words and map (word -> last generated token index)
    words, word2tok = _words_and_map_from_tokens(generated_token_ids)
    display_choices = [(w, i) for i, w in enumerate(words)]
    if not display_choices:
        return {
            state_attentions: None,
            state_all_token_ids: None,
            state_prompt_len: 0,
            state_words: None,
            state_word2tok: None,
            # SAFE RADIO RESET
            radio_word_selector: gr.update(choices=[], value=None),
            html_visualization: "<div style='text-align:center;padding:20px;'>No new tokens generated.</div>",
        }

    # Pre-render the view for the first word, averaged over layers and heads.
    first_word_idx = 0
    html_init = update_visualization(
        first_word_idx,
        outputs.attentions,
        all_token_ids,
        prompt_len,
        0, 0, True, True,
        words,
        word2tok,
    )
    return {
        state_attentions: outputs.attentions,
        state_all_token_ids: all_token_ids,
        state_prompt_len: prompt_len,
        state_words: words,
        state_word2tok: word2tok,
        radio_word_selector: gr.update(choices=display_choices, value=first_word_idx),
        html_visualization: html_init,
    }
def update_visualization(
    selected_word_index,
    attentions,
    all_token_ids,
    prompt_len,
    layer,
    head,
    mean_layers,
    mean_heads,
    words,
    word2tok,
):
    """Recompute visualization for the chosen word (maps to its last token).

    Returns a placeholder HTML message when nothing has been generated yet
    (no selection, attentions or word map) or when the selected index falls
    outside ``word2tok``; otherwise returns the HTML produced by
    ``generate_word_visualization`` for the selected word's attention row.
    """
    nothing_generated = (
        selected_word_index is None or attentions is None or word2tok is None
    )
    if nothing_generated:
        return "<div style='text-align:center;padding:20px;'>Generate text first.</div>"

    word_pos = int(selected_word_index)
    if word_pos < 0 or word_pos >= len(word2tok):
        return "<div style='text-align:center;padding:20px;'>Invalid selection.</div>"

    rel_token = int(word2tok[word_pos])
    abs_token = int(prompt_len) + rel_token

    attn_row = get_attention_for_token_layer(
        attentions,
        token_index=rel_token,
        layer_index=int(layer),
        head_index=int(head),
        mean_across_layers=bool(mean_layers),
        mean_across_heads=bool(mean_heads),
    ).detach().cpu().numpy()

    # Pad attention to full (prompt + generated) length
    full_attn = np.zeros(len(all_token_ids), dtype=float)
    if attn_row.ndim == 2:
        # Several query rows: keep only the last one.
        attn_row = attn_row[-1]
    full_attn[: len(attn_row)] = attn_row

    # Absolute word ends (prompt offset + relative token index)
    word_ends = [int(prompt_len) + int(tok) for tok in (word2tok or [])]
    return generate_word_visualization(words, word_ends, full_attn, abs_token)
def toggle_slider(is_mean):
    """Return a Gradio update that makes a slider interactive only when ``is_mean`` is falsy."""
    slider_enabled = not is_mean
    return gr.update(interactive=slider_enabled)
# =========================
# Gradio UI
# =========================
# Builds the whole Gradio interface: layout, per-session state and event wiring.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Word-Level Attention Visualizer — choose a model & explore")
    gr.Markdown(
        "Pick a model, generate text, then select a **generated word** to see where it attends. "
        "Words wrap in a paragraph; opacity is the summed attention over the word’s tokens. "
        "EOS tokens are stripped so `<|endoftext|>` doesn’t appear."
    )

    # States (per-session values carried between events)
    state_attentions = gr.State(None)
    state_all_token_ids = gr.State(None)
    state_prompt_len = gr.State(None)
    state_words = gr.State(None)
    state_word2tok = gr.State(None)
    state_model_name = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 0) Model")
            dd_model = gr.Dropdown(
                ALLOWED_MODELS, value=ALLOWED_MODELS[0], label="Causal LM",
                info="Models that work with AutoModelForCausalLM + attentions"
            )
            btn_load = gr.Button("Load / Switch Model", variant="secondary")
            gr.Markdown("### 1) Generation")
            txt_prompt = gr.Textbox("In a distant future, humanity", label="Prompt")
            btn_generate = gr.Button("Generate", variant="primary")
            slider_max_tokens = gr.Slider(10, 200, value=50, step=10, label="Max New Tokens")
            slider_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature")
            slider_top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P")
            gr.Markdown("### 2) Attention")
            check_mean_layers = gr.Checkbox(True, label="Mean Across Layers")
            check_mean_heads = gr.Checkbox(True, label="Mean Across Heads")
            slider_layer = gr.Slider(0, 11, value=0, step=1, label="Layer", interactive=False)
            slider_head = gr.Slider(0, 11, value=0, step=1, label="Head", interactive=False)
        with gr.Column(scale=3):
            radio_word_selector = gr.Radio(
                [], label="Select Generated Word to Visualize",
                info="Click Generate to populate"
            )
            html_visualization = gr.HTML(
                "<div style='text-align:center;padding:20px;color:#888;border:1px dashed #888;border-radius:8px;'>"
                "Attention visualization will appear here.</div>"
            )

    # Load/switch model
    def on_load_model(selected_name, mean_layers, mean_heads):
        """Load ``selected_name`` and reset every control that depended on the old model."""
        load_model(selected_name)
        n_layers, n_heads = model_heads_layers()
        return (
            selected_name,  # state_model_name
            gr.update(minimum=0, maximum=n_layers - 1, value=0, interactive=not bool(mean_layers)),
            gr.update(minimum=0, maximum=n_heads - 1, value=0, interactive=not bool(mean_heads)),
            # SAFE RADIO RESET (avoid Value: [] not in choices)
            gr.update(choices=[], value=None),
            "<div style='text-align:center;padding:20px;'>Model loaded. Generate to visualize.</div>",
            # Drop the previous generation: it belongs to the old model and the
            # stored attention tensors would otherwise stay alive in the session.
            None,  # state_attentions
            None,  # state_all_token_ids
            None,  # state_prompt_len
            None,  # state_words
            None,  # state_word2tok
        )

    btn_load.click(
        fn=on_load_model,
        inputs=[dd_model, check_mean_layers, check_mean_heads],
        outputs=[
            state_model_name,
            slider_layer,
            slider_head,
            radio_word_selector,
            html_visualization,
            state_attentions,
            state_all_token_ids,
            state_prompt_len,
            state_words,
            state_word2tok,
        ],
    )

    # Prepare the UI when a page is opened
    def _init_model():
        """Sync a freshly opened page with the model that is currently loaded.

        The default model is loaded only when nothing is loaded yet, so
        opening the page does not reload weights or replace a model that was
        selected earlier in this process.
        """
        # globals().get treats a missing binding the same as None.
        if globals().get("model") is None or globals().get("tokenizer") is None:
            load_model(ALLOWED_MODELS[0])
        n_layers, n_heads = model_heads_layers()
        # name_or_path is expected to hold the identifier given to
        # from_pretrained; fall back to the default entry when it is unknown.
        loaded_name = getattr(getattr(model, "config", None), "name_or_path", None)
        if loaded_name not in ALLOWED_MODELS:
            loaded_name = ALLOWED_MODELS[0]
        return (
            loaded_name,
            gr.update(minimum=0, maximum=n_layers - 1, value=0, interactive=not check_mean_layers.value),
            gr.update(minimum=0, maximum=n_heads - 1, value=0, interactive=not check_mean_heads.value),
            # Also ensure radio is clean at start
            gr.update(choices=[], value=None),
            # Show the model that is actually loaded in the dropdown.
            gr.update(value=loaded_name),
        )

    demo.load(
        _init_model,
        inputs=None,
        outputs=[state_model_name, slider_layer, slider_head, radio_word_selector, dd_model],
    )

    # Generate
    btn_generate.click(
        fn=run_generation,
        inputs=[txt_prompt, slider_max_tokens, slider_temp, slider_top_p],
        outputs=[
            state_attentions,
            state_all_token_ids,
            state_prompt_len,
            state_words,
            state_word2tok,
            radio_word_selector,
            html_visualization,
        ],
    )

    # Update viz on any control
    for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
        control.change(
            fn=update_visualization,
            inputs=[
                radio_word_selector,
                state_attentions,
                state_all_token_ids,
                state_prompt_len,
                slider_layer,
                slider_head,
                check_mean_layers,
                check_mean_heads,
                state_words,
                state_word2tok,
            ],
            outputs=html_visualization,
        )

    # Toggle slider interactivity
    check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
    check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)
# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # Report the compute device selected at import time.
    print(f"Device: {device}")
    # Ensure a default model is ready
    load_model(ALLOWED_MODELS[0])
    # Start the Gradio app with debug mode enabled.
    demo.launch(debug=True)