# Hugging Face Spaces status: Running
import os

# Hub settings are applied before `huggingface_hub` is imported: the library
# resolves its cache location and transfer backend from the environment when
# it is first imported, so defaults set afterwards would be ignored.
# `setdefault` leaves any value already configured for the process untouched.
os.environ.setdefault('HF_HOME', '/data/hf-cache')
os.environ.setdefault('HF_HUB_ENABLE_HF_TRANSFER', '1')

from huggingface_hub import login

# app.py
# Authenticate against the Hub once at startup, and only when a non-empty
# token is configured (an unset or empty HF_TOKEN skips the login).
hf_token = os.getenv("HF_TOKEN", "")
if hf_token:
    login(token=hf_token)

try:
    from spaces import GPU
except ImportError:
    # For local testing, create a no-op decorator
    def GPU(f=None, **_options):
        """Stand-in used when the `spaces` package is not installed.

        Works both bare (``@GPU``) and with keyword options
        (``@GPU(duration=...)``); in either form the decorated function is
        returned unchanged.
        """
        if f is None:
            return lambda func: func
        return f
| import torch | |
| from exceptiongroup import catch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import gradio as gr | |
| import pandas as pd | |
| from functools import lru_cache | |
| # ---------------------------------------------------------------------- | |
| # IMPORTANT: This version uses the PatchscopesRetriever implementation | |
| # from the Tokens2Words paper (https://github.com/schwartz-lab-NLP/Tokens2Words) | |
| # ---------------------------------------------------------------------- | |
| import torch | |
| from tqdm import tqdm | |
| from abc import ABC, abstractmethod | |
| from enums import MultiTokenKind, RetrievalTechniques | |
| from processor import RetrievalProcessor | |
| from logit_lens import ReverseLogitLens | |
| from model_utils import extract_token_i_hidden_states | |
class WordRetrieverBase(ABC):
    """Common interface for strategies that turn hidden states back into text.

    Stores the model and tokenizer that concrete retrievers operate on.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @abstractmethod
    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
        """Return the text recovered from ``hidden_states``.

        Subclasses must override this method; the base class cannot be
        instantiated on its own.

        Args:
            hidden_states: Hidden state(s) to decode.
            layer_idx: Optional index of the layer the states came from.
            num_tokens_to_generate: Upper bound on tokens to produce, for
                retrievers that generate text.
        """
class PatchscopesRetriever(WordRetrieverBase):
    """Decode hidden states by patching them into a fixed "Patchscopes" prompt.

    The prompt is tokenised once at construction time with a dummy token at
    every placeholder position. At retrieval time the embeddings at those
    positions are overwritten with the supplied hidden states and the model
    generates a continuation, which is returned as text.
    """

    def __init__(
            self,
            model,
            tokenizer,
            representation_prompt: str = "{word}",
            patchscopes_prompt: str = "Next is the same word twice: 1) {word} 2)",
            prompt_target_placeholder: str = "{word}",
            representation_token_idx_to_extract: int = -1,
            num_tokens_to_generate: int = 10,
    ):
        """
        Args:
            model: Causal LM used for both extraction and generation.
            tokenizer: Tokenizer matching ``model``.
            representation_prompt: Template the word is substituted into
                before its hidden states are extracted.
            patchscopes_prompt: Template whose placeholder positions receive
                the patched hidden states.
            prompt_target_placeholder: Placeholder string used in both
                templates.
            representation_token_idx_to_extract: Token position handed to
                ``extract_token_i_hidden_states``.
            num_tokens_to_generate: Default generation length.
        """
        super().__init__(model, tokenizer)
        self.prompt_input_ids, self.prompt_target_idx = \
            self._build_prompt_input_ids_template(patchscopes_prompt, prompt_target_placeholder)
        self._prepare_representation_prompt = \
            self._build_representation_prompt_func(representation_prompt, prompt_target_placeholder)
        self.representation_token_idx = representation_token_idx_to_extract
        self.num_tokens_to_generate = num_tokens_to_generate

    def _build_prompt_input_ids_template(self, prompt, target_placeholder):
        """Tokenise ``prompt`` and record where the placeholders sit.

        Returns:
            ``(prompt_input_ids, target_idx)`` as ``torch.long`` tensors.
            Placeholder positions hold the dummy token id 0; their embeddings
            are replaced in :meth:`retrieve_word`.
        """
        bos_id = self.tokenizer.bos_token_id
        prompt_input_ids = [bos_id] if bos_id is not None else []
        target_idx = []

        if prompt:
            assert target_placeholder is not None, \
                "Trying to set a prompt for Patchscopes without defining the prompt's target placeholder string, e.g., [MASK]"

            prompt_parts = prompt.split(target_placeholder)
            last_part_i = len(prompt_parts) - 1
            for part_i, prompt_part in enumerate(prompt_parts):
                prompt_input_ids += self.tokenizer.encode(prompt_part, add_special_tokens=False)
                # A placeholder follows every part except the final one.
                if part_i < last_part_i:
                    target_idx.append(len(prompt_input_ids))
                    prompt_input_ids.append(0)
        else:
            # No prompt: the sequence is just the placeholder. Record its
            # index *before* appending it; recording it afterwards pointed one
            # past the end of the sequence.
            target_idx = [len(prompt_input_ids)]
            prompt_input_ids.append(0)

        prompt_input_ids = torch.tensor(prompt_input_ids, dtype=torch.long)
        target_idx = torch.tensor(target_idx, dtype=torch.long)
        return prompt_input_ids, target_idx

    def _build_representation_prompt_func(self, prompt, target_placeholder):
        """Return a function that substitutes a word into ``prompt``."""
        return lambda word: prompt.replace(target_placeholder, word)

    def generate_states(self, tokenizer, word='Wakanda', with_prompt=True):
        """Encode ``word`` (optionally inside the representation prompt).

        Returns the ``input_ids`` tensor produced by ``tokenizer.encode``.
        """
        # The original called `self.generate_prompt()`, which is not defined on
        # this class or its base and raised AttributeError.
        # NOTE(review): the representation prompt is assumed to be the prompt
        # that was intended here -- confirm.
        prompt = self._prepare_representation_prompt(word) if with_prompt else word
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        return input_ids

    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=None):
        """Patch ``hidden_states`` into the prompt and decode the continuation.

        Args:
            hidden_states: A single vector or a batch of vectors; a 1-D input
                is treated as a batch of one.
            layer_idx: Unused by this retriever.
            num_tokens_to_generate: Generation length; a falsy value falls
                back to the instance default.

        Returns:
            List of decoded strings, one per row of ``hidden_states``.
        """
        self.model.eval()

        # insert hidden states into patchscopes prompt
        if hidden_states.dim() == 1:
            hidden_states = hidden_states.unsqueeze(0)

        device = self.model.device
        batch_size = len(hidden_states)
        inputs_embeds = self.model.get_input_embeddings()(self.prompt_input_ids.to(device)).unsqueeze(0)
        batched_patchscope_inputs = inputs_embeds.repeat(batch_size, 1, 1).to(hidden_states.dtype)
        # Every placeholder position of a row receives that row's hidden state.
        batched_patchscope_inputs[:, self.prompt_target_idx] = hidden_states.unsqueeze(1).to(device)

        # No attention mask is passed: every row is the same unpadded prompt.
        # (The mask the original built here was never handed to `generate`.)
        num_tokens_to_generate = num_tokens_to_generate or self.num_tokens_to_generate

        with torch.no_grad():
            patchscope_outputs = self.model.generate(
                do_sample=False, num_beams=1, top_p=1.0, temperature=None,
                inputs_embeds=batched_patchscope_inputs,
                max_new_tokens=num_tokens_to_generate, pad_token_id=self.tokenizer.eos_token_id, )

        return self.tokenizer.batch_decode(patchscope_outputs)

    def extract_hidden_states(self, word):
        """Return the hidden states extracted for ``word``.

        The word is substituted into the representation prompt and passed to
        ``extract_token_i_hidden_states`` with the configured token index.
        """
        representation_input = self._prepare_representation_prompt(word)
        return extract_token_i_hidden_states(
            self.model, self.tokenizer, representation_input,
            token_idx_to_extract=self.representation_token_idx, return_dict=False, verbose=False)

    def get_hidden_states_and_retrieve_word(self, word, num_tokens_to_generate=None):
        """Extract hidden states for ``word`` and decode them.

        Returns:
            ``(decoded_strings, hidden_states)``.
            # presumably one entry per model layer -- depends on
            # extract_token_i_hidden_states; verify.
        """
        last_token_hidden_states = self.extract_hidden_states(word)
        patchscopes_description_by_layers = self.retrieve_word(
            last_token_hidden_states, num_tokens_to_generate=num_tokens_to_generate)
        return patchscopes_description_by_layers, last_token_hidden_states
class ReverseLogitLensRetriever(WordRetrieverBase):
    """Retriever that decodes a hidden state with a ``ReverseLogitLens``."""

    def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
        """Build the lens from ``model`` and move it to ``device`` / ``dtype``."""
        super().__init__(model, tokenizer)
        lens = ReverseLogitLens.from_model(model)
        self.reverse_logit_lens = lens.to(device).to(dtype)

    def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
        """Return the decoded token with the highest lens score.

        ``num_tokens_to_generate`` is accepted for interface compatibility and
        is not used.
        """
        scores = self.reverse_logit_lens(hidden_states, layer_idx)
        best_token_id = torch.argmax(scores, dim=-1).item()
        return self.tokenizer.decode(best_token_id)
class AnalysisWordRetriever:
    """Run word retrieval over a dataset and collect per-layer results.

    Chooses a Patchscopes retriever for ``MultiTokenKind.Natural`` and a
    reverse-logit-lens retriever otherwise, then walks each text word by word
    using a ``RetrievalProcessor``.
    """

    def __init__(self, model, tokenizer, multi_token_kind, num_tokens_to_generate=1, add_context=True,
                 model_name='LLaMa-2B', device='cuda', dataset=None):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.multi_token_kind = multi_token_kind
        self.num_tokens_to_generate = num_tokens_to_generate
        self.add_context = add_context
        self.model_name = model_name
        self.device = device
        self.dataset = dataset
        self.retriever = self._initialize_retriever()
        self.RetrievalTechniques = (RetrievalTechniques.Patchscopes if self.multi_token_kind == MultiTokenKind.Natural
                                    else RetrievalTechniques.ReverseLogitLens)
        # Word-boundary marker handed to the processor; selected by model name.
        self.whitespace_token = 'Ġ' if model_name in ['gemma-2-9b', 'pythia-6.9b', 'LLaMA3-8B', 'Yi-6B'] else '▁'
        self.processor = RetrievalProcessor(self.model, self.tokenizer, self.multi_token_kind,
                                            self.num_tokens_to_generate, self.add_context, self.model_name,
                                            self.whitespace_token)

    def _initialize_retriever(self):
        """Create the retriever matching ``self.multi_token_kind``."""
        if self.multi_token_kind == MultiTokenKind.Natural:
            return PatchscopesRetriever(self.model, self.tokenizer)
        # Pass the configured device through; the retriever otherwise falls
        # back to its own 'cuda' default regardless of where the model lives.
        return ReverseLogitLensRetriever(self.model, self.tokenizer, device=self.device)

    def retrieve_words_in_dataset(self, number_of_examples_to_retrieve=2, max_length=1000):
        """Retrieve every multi-token word in the first dataset examples.

        Args:
            number_of_examples_to_retrieve: How many entries of
                ``self.dataset['train']['text']`` to process.
            max_length: Truncation length passed to the tokenizer.

        Returns:
            A list of dicts, one per (word, layer) pair, holding the text, the
            word and its tokens, the layer index and the retrieved string.
        """
        self.model.eval()
        results = []

        # The extraction routine depends only on the configured kind, so it is
        # selected once rather than on every loop iteration.
        if self.multi_token_kind == MultiTokenKind.Natural:
            get_next = self.processor.get_next_word
        elif self.multi_token_kind == MultiTokenKind.Typo:
            get_next = self.processor.get_next_full_word_typo
        else:
            get_next = self.processor.get_next_full_word_separated

        context_label = "With Context" if self.add_context else "Without Context"

        for text in tqdm(self.dataset['train']['text'][:number_of_examples_to_retrieve], self.model_name):
            tokenized_input = self.tokenizer(
                text, return_tensors='pt', truncation=True, max_length=max_length).to(self.device)
            tokens = tokenized_input.input_ids[0]
            print(f'Processing text: {text}')

            # Scanning starts at token index 5, as in the original code.
            i = 5
            while i < len(tokens):
                (j, word_tokens, word, _context, tokenized_combined_text,
                 combined_text, original_word) = get_next(tokens, i, device=self.device)

                if len(word_tokens) > 1:
                    with torch.no_grad():
                        outputs = self.model(**tokenized_combined_text, output_hidden_states=True)

                    for layer_idx, hidden_state in enumerate(outputs.hidden_states):
                        # Hidden state of the final position of the first sequence.
                        postfix_hidden_state = hidden_state[0, -1, :].unsqueeze(0)
                        retrieved_word_str = self.retriever.retrieve_word(
                            postfix_hidden_state, layer_idx=layer_idx,
                            num_tokens_to_generate=len(word_tokens))
                        results.append({
                            'text': combined_text,
                            'original_word': original_word,
                            'word': word,
                            'word_tokens': self.tokenizer.convert_ids_to_tokens(word_tokens),
                            'num_tokens': len(word_tokens),
                            'layer': layer_idx,
                            'retrieved_word_str': retrieved_word_str,
                            'context': context_label,
                        })

                # Always move on to the next word. The original advanced only in
                # an `else` branch, so one path never updated `i` and the loop
                # could not terminate. Forward progress is forced in case the
                # processor returns a position that is not past `i`.
                i = j if j > i else i + 1

        return results
# Checkpoint used as the default model choice.
# NOTE(review): an earlier comment called this a "light default so the demo
# boots everywhere"; going by its name it is an 8B-parameter checkpoint.
DEFAULT_MODEL = "meta-llama/Llama-3.1-8B"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache(maxsize=1)
def get_model_and_tokenizer(model_name: str):
    """Load ``model_name`` and its tokenizer, reusing the last loaded pair.

    The most recently requested checkpoint is kept in memory so repeated
    analyses with the same model do not re-instantiate the weights. The cache
    holds a single entry: requesting a different model loads it and then drops
    the previous one, so both are briefly resident during a switch.

    Args:
        model_name: Hub identifier passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` with the model in bfloat16, moved to ``DEVICE``
        and set to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        output_hidden_states=True,
    ).to(DEVICE)
    model.eval()
    return model, tokenizer
def find_last_token_index(full_ids, word_ids):
    """Locate end position of word_ids inside full_ids (first match).

    Args:
        full_ids: Sequence of token ids to search in.
        word_ids: Sequence of token ids to search for.

    Returns:
        Index in ``full_ids`` of the last token of the first occurrence of
        ``word_ids``, or ``None`` when there is no occurrence or ``word_ids``
        is empty.
    """
    # Normalise to lists so a list/tuple mix still compares equal.
    haystack = list(full_ids)
    needle = list(word_ids)
    width = len(needle)

    # An empty needle would match at position 0 and yield -1, which silently
    # indexes the last element when used as a position.
    if width == 0:
        return None

    for start in range(len(haystack) - width + 1):
        if haystack[start:start + width] == needle:
            return start + width - 1
    return None
# this block runs on a job GPU
def analyse_word(model_name: str, word: str, patchscopes_template: str, context:str = ""):
    """Run Patchscopes on ``word`` and render the per-layer result as HTML.

    Args:
        model_name: Checkpoint to load via ``get_model_and_tokenizer``.
        word: Word whose hidden states are patched into the template.
        patchscopes_template: Patchscopes prompt; must contain the placeholder
            ``X`` at every position that should receive the hidden state.
        context: Currently unused.

    Returns:
        An HTML fragment: the sub-word tokens, the number of matching layers
        and a table with one row per retrieved string. Any failure is
        returned as a red error paragraph instead of being raised.
    """
    import html  # local import: only this function needs it

    placeholder = "X"
    try:
        if not word or not word.strip():
            raise ValueError("Please enter a word to analyse.")
        # Without a placeholder nothing would be patched into the prompt and
        # the table would silently show an unrelated continuation.
        if placeholder not in patchscopes_template:
            raise ValueError("The Patchscopes prompt must contain the placeholder 'X'.")

        # text = context+ " " + word
        model, tokenizer = get_model_and_tokenizer(model_name)

        # Build extraction prompt (where hidden states will be collected)
        extraction_prompt = placeholder

        word_token_ids = tokenizer.encode(word, add_special_tokens=False)

        # Instantiate Patchscopes retriever
        patch_retriever = PatchscopesRetriever(
            model,
            tokenizer,
            extraction_prompt,
            patchscopes_template,
            prompt_target_placeholder=placeholder,
        )

        # Run retrieval for the word across all layers (one pass)
        retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
            word,
            num_tokens_to_generate=len(tokenizer.tokenize(word)),
        )[0]

        # Build a table summarising which layers match
        target = word.strip(" ")
        records = []
        matches = 0
        for layer_idx, ret_word in enumerate(retrieved_words):
            match = ret_word.strip(" ") == target
            if match:
                matches += 1
            records.append({
                "Layer": layer_idx,
                # Escaped for display: decoded text may contain "<...>"
                # sequences that a browser would otherwise treat as tags.
                "Retrieved": html.escape(ret_word),
                "Match?": "✓" if match else "",
            })

        df = pd.DataFrame(records)

        def _style(row):
            color = "background-color: lightgreen" if row["Match?"] else ""
            return [color] * len(row)

        # NOTE(review): `escape` is not a documented Styler.to_html parameter;
        # kept as in the original call. Cell text is escaped above instead.
        html_table = df.style.apply(_style, axis=1).hide(axis="index").to_html(escape=False)

        sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
        escaped_tokens = ' , '.join(html.escape(str(tok)) for tok in sub_tokens)
        top = (
            f"<p><b>Sub‑word tokens:</b> {escaped_tokens}</p>"
            f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
        )
        return top + html_table
    except Exception as e:
        # UI boundary: report the failure in the output panel rather than
        # crashing the request. The message is escaped because it can echo
        # user input.
        return f"<p style='color:red'>❌ Error: {html.escape(str(e))}</p>"
# ----------------------------- GRADIO UI -------------------------------
# NOTE(review): the source's indentation was lost; the three inputs are
# assumed to share the row, with the button and output below it -- confirm.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown(
        """# Tokens→Words Viewer\nInteractively inspect how hidden‑state patching (Patchscopes) reveals a word's detokenised representation across model layers."""
    )
    with gr.Row():
        model_name = gr.Dropdown(
            label="🤖 Model",
            choices=[DEFAULT_MODEL, "mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b-hf", "Qwen/Qwen2-7B"],
            value=DEFAULT_MODEL,
        )
        patchscopes_template = gr.Textbox(
            label="Patchscopes prompt (use X as placeholder)",
            value="repeat the following word X twice: 1)X 2)",
        )
        # context_box = gr.Textbox(label="context", value="")
        word_box = gr.Textbox(label="Word to test", value="interpretable")
    run_btn = gr.Button("Analyse")
    out_html = gr.HTML()
    run_btn.click(
        # Wrap the handler with the `GPU` decorator imported at the top of the
        # file so it is registered as a GPU task; the local fallback returns
        # the function unchanged.
        GPU(analyse_word),
        inputs=[model_name, word_box, patchscopes_template],  # , context_box],
        outputs=out_html,
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)