Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import gradio as gr | |
| import pandas as pd | |
| from functools import lru_cache | |
# ----------------------------------------------------------------------
# IMPORTANT: This version uses the PatchscopesRetriever implementation
# from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words)
# ----------------------------------------------------------------------
try:
    # Works when this file is imported as a module of a package.
    from .word_retriever import PatchscopesRetriever
except ImportError:
    try:
        # When run as a script (`python <this file>`) there is no parent package,
        # so the relative import above always raises ImportError. Fall back to a
        # top-level `word_retriever` module (e.g. one sitting next to this file).
        from word_retriever import PatchscopesRetriever
    except ImportError:
        # Sentinel: analyse_word() checks for None and returns an install hint.
        PatchscopesRetriever = None
# Default selection in the model dropdown. Note: an 8B-parameter checkpoint,
# not a lightweight one.
DEFAULT_MODEL = "meta-llama/Llama-3.1-8B"

# Pick the best available accelerator instead of hard-coding "mps", which makes
# `.to(DEVICE)` fail on any machine without Apple's MPS backend.
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
@lru_cache(maxsize=1)
def get_model_and_tokenizer(model_name: str):
    """Load the tokenizer and causal LM for ``model_name`` and move the model to ``DEVICE``.

    The result is memoised, so repeated "Analyse" clicks with the same model
    reuse the already-loaded weights instead of re-loading them every time.
    ``maxsize=1`` keeps only the most recently requested model cached, so
    switching models drops the cache's reference to the previous one rather
    than accumulating several large models in memory.

    NOTE(review): the cached model object is shared across requests; confirm
    that PatchscopesRetriever does not leave hooks or other state on it.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        ``(model, tokenizer)``. The model is loaded in bfloat16 with
        ``output_hidden_states=True`` and put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        output_hidden_states=True,
    ).to(DEVICE)
    model.eval()  # inference only: disables dropout etc.
    return model, tokenizer
def find_last_token_index(full_ids, word_ids):
    """Locate the end position of ``word_ids`` inside ``full_ids`` (first match).

    Args:
        full_ids: Sequence of token ids to search in (e.g. a tokenised prompt).
        word_ids: Sequence of token ids to look for.

    Returns:
        The index in ``full_ids`` of the last token of the first occurrence of
        ``word_ids``, or ``None`` if ``word_ids`` is empty or does not occur.
    """
    # Normalise to lists so e.g. a tuple and a list with equal items compare equal.
    haystack = list(full_ids)
    needle = list(word_ids)
    # An empty needle "matches" at position 0 and the end-index arithmetic below
    # would yield -1, which silently indexes the *last* token. Report no match.
    if not needle:
        return None
    width = len(needle)
    for start in range(len(haystack) - width + 1):
        if haystack[start : start + width] == needle:
            return start + width - 1
    return None
def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
    """Run Patchscopes word retrieval for ``word`` and render per-layer results as HTML.

    Args:
        model_name: Model id passed to ``get_model_and_tokenizer``.
        extraction_template: Value of the UI's "Extraction prompt" box.
            NOTE(review): currently unused — the extraction prompt is
            hard-coded to ``"X"`` below; confirm whether this is intended.
        word: The word to analyse.
        patchscopes_template: Patchscopes prompt passed to
            ``PatchscopesRetriever``; ``"X"`` is its target placeholder.

    Returns:
        An HTML string. It is either an error paragraph, or a summary
        (sub-word tokens and matched-layer count) followed by a table with one
        row per retrieved entry. Rows whose retrieved text equals ``word``
        (ignoring surrounding spaces) are highlighted.
    """
    from html import escape as html_escape

    if PatchscopesRetriever is None:
        return (
            "<p style='color:red'>❌ Patchscopes library not found. Run:<br/>"
            "<code>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</code></p>"
        )
    # With an empty word, num_tokens_to_generate would be 0. Fail early with a message.
    if not word or not word.strip():
        return "<p style='color:red'>❌ Please enter a word to analyse.</p>"

    model, tokenizer = get_model_and_tokenizer(model_name)

    # Extraction (representation) prompt: just the placeholder, so presumably
    # the word is encoded without surrounding context. `extraction_template`
    # is deliberately not used here (see NOTE in the docstring).
    extraction_prompt = "X"

    # Tokenise the word on its own (no special tokens); used only for the
    # sub-word token summary at the end.
    word_token_ids = tokenizer.encode(word, add_special_tokens=False)

    patch_retriever = PatchscopesRetriever(
        model,
        tokenizer,
        extraction_prompt,
        patchscopes_template,
        prompt_target_placeholder="X",
    )
    # Retrieve the word across all layers in one pass; generate as many tokens
    # as the word itself has.
    retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
        word,
        num_tokens_to_generate=len(tokenizer.tokenize(word)),
    )[0]

    # Build a table summarising which layers reproduce the word.
    target = word.strip(" ")
    records = []
    matches = 0
    for layer_idx, ret_word in enumerate(retrieved_words):
        match = ret_word.strip(" ") == target
        if match:
            matches += 1
        records.append({"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if match else ""})
    df = pd.DataFrame(records)

    def _style(row):
        # Highlight the whole row when that layer's retrieval matches.
        color = "background-color: lightgreen" if row["Match?"] else ""
        return [color] * len(row)

    # Retrieved strings are model output and may contain '<', '>' or '&'.
    # Styler.to_html() has no escaping option (extra kwargs only reach the
    # Jinja template), so escape cell text via Styler.format instead.
    html_table = (
        df.style.format(escape="html")
        .apply(_style, axis=1)
        .hide(axis="index")
        .to_html()
    )

    sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
    # Escape so tokens that look like tags (e.g. "<s>", "<0x0A>") stay visible.
    tokens_html = html_escape(" , ".join(sub_tokens))
    top = (
        f"<p><b>Sub‑word tokens:</b> {tokens_html}</p>"
        f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
    )
    return top + html_table
# ----------------------------- GRADIO UI -------------------------------
with gr.Blocks(theme="soft") as demo:
    gr.Markdown(
        """# Tokens→Words Viewer\nInteractively inspect how hidden‑state patching (Patchscopes) reveals a word's detokenised representation across model layers."""
    )
    with gr.Row():
        model_name = gr.Dropdown(
            label="🤖 Model",
            # Hugging Face model ids. For Llama-2 the Transformers-format weights
            # live in the "-hf" repo; "meta-llama/Llama-2-7b" holds Meta's
            # original checkpoint, which AutoModelForCausalLM cannot load.
            choices=[DEFAULT_MODEL, "mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b-hf", "Qwen/Qwen2-7B"],
            value=DEFAULT_MODEL,
        )
        # NOTE(review): analyse_word currently hard-codes its extraction prompt
        # to "X" and ignores this value.
        extraction_template = gr.Textbox(
            label="Extraction prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
        patchscopes_template = gr.Textbox(
            label="Patchscopes prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
    word_box = gr.Textbox(label="Word to test", value="interpretable")
    run_btn = gr.Button("Analyse")
    out_html = gr.HTML()

    # Input order must match analyse_word's positional parameters.
    run_btn.click(
        analyse_word,
        inputs=[model_name, extraction_template, word_box, patchscopes_template],
        outputs=out_html,
    )

if __name__ == "__main__":
    demo.launch()