# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
from dataclasses import dataclass
from typing import List, Optional

import streamlit as st
import torch
import transformer_lens
import transformers
from fancy_einsum import einsum
from jaxtyping import Float, Int
from typeguard import typechecked

from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm


@dataclass
class _RunInfo:
    tokens: Int[torch.Tensor, "batch pos"]
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    cache: transformer_lens.ActivationCache


# Cache the loaded model across Streamlit reruns so the weights aren't reloaded
# on every interaction. HuggingFace model objects aren't hashable, so hash them
# by identity.
@st.cache_resource(hash_funcs={transformers.PreTrainedModel: id})
def load_hooked_transformer(
    model_name: str,
    hf_model: Optional[transformers.PreTrainedModel] = None,
    tlens_device: str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> transformer_lens.HookedTransformer:
    tlens_model = transformer_lens.HookedTransformer.from_pretrained(
        model_name,
        hf_model=hf_model,
        fold_ln=False,  # Keep layer norm where it is.
        center_writing_weights=False,
        center_unembed=False,
        device=tlens_device,
        dtype=dtype,
    )
    tlens_model.eval()
    return tlens_model
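
# Example usage (an illustrative sketch; "gpt2" stands in for any checkpoint
# that transformer_lens supports):
#
#     model = load_hooked_transformer("gpt2", tlens_device="cpu")
#     tokens = model.to_tokens("Hello world")
#     logits = model(tokens)  # [batch, pos, d_vocab]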


# TODO(igortufanov): If we want to scale the app to multiple users, we need a
# more careful thread-safe implementation. The simplest option could be to wrap
# the existing methods in mutexes.
class TransformerLensTransparentLlm(TransparentLlm):
    """
    Implementation of TransparentLlm based on transformer_lens.

    Args:
    - model_name: The official name of the model from HuggingFace. Even if the
        model was patched or loaded locally, the name should still be the
        official one because that's how transformer_lens treats the model.
    - hf_model: The language model as a HuggingFace class.
    - tokenizer: The tokenizer as a HuggingFace class.
    - device: "gpu" or "cpu".
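
    Example (illustrative; "gpt2" stands in for any supported model name):
        llm = TransformerLensTransparentLlm("gpt2", device="cpu")
        llm.run(["Hello world"])
        logits = llm.logits()  # [batch, pos, d_vocab]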
| """ | |

    def __init__(
        self,
        model_name: str,
        hf_model: Optional[transformers.PreTrainedModel] = None,
        tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
        device: str = "gpu",
        dtype: torch.dtype = torch.float32,
    ):
        if device == "gpu":
            self.device = "cuda"
            if not torch.cuda.is_available():
                raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
        elif device == "cpu":
            self.device = "cpu"
        else:
            raise RuntimeError(f"Specified device {device} is not a valid option")

        self.dtype = dtype
        self.hf_tokenizer = tokenizer
        self.hf_model = hf_model
        self._model_name = model_name
        self._prepend_bos = True
        self._last_run = None
        self._run_exception = RuntimeError(
            "Tried to use the model output before calling the `run` method"
        )

    def copy(self):
        return copy.copy(self)

    # Exposed as a property because the rest of the class accesses it as
    # `self._model`; the heavy loading step is cached in load_hooked_transformer.
    @property
    def _model(self) -> transformer_lens.HookedTransformer:
        tlens_model = load_hooked_transformer(
            self._model_name,
            hf_model=self.hf_model,
            tlens_device=self.device,
            dtype=self.dtype,
        )
        if self.hf_tokenizer is not None:
            tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")
        tlens_model.set_use_attn_result(True)
        tlens_model.set_use_attn_in(False)
        tlens_model.set_use_split_qkv_input(False)
        return tlens_model

    def model_info(self) -> ModelInfo:
        cfg = self._model.cfg
        return ModelInfo(
            name=self._model_name,
            n_params_estimate=cfg.n_params,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            d_model=cfg.d_model,
            d_vocab=cfg.d_vocab,
        )

    def run(self, sentences: List[str]) -> None:
        tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos)
        logits, cache = self._model.run_with_cache(tokens)
        self._last_run = _RunInfo(
            tokens=tokens,
            logits=logits,
            cache=cache,
        )

    def batch_size(self) -> int:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.logits.shape[0]

    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.tokens

    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        return self._model.to_str_tokens(tokens)

    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.logits
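
    # Project a residual stream vector onto the vocabulary, optionally applying
    # the final layer norm first (the "logit lens" view of a hidden state).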
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        # t: [d_model] -> [batch, pos, d_model]
        tdim = t.unsqueeze(0).unsqueeze(0)
        if normalize:
            normalized = self._model.ln_final(tdim)
            result = self._model.unembed(normalized)
        else:
            result = self._model.unembed(tdim)
        return result[0][0]
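
    # Activation cache keys follow the transformer_lens naming convention:
    # "blocks.<layer>.<hook_name>".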
    def _get_block(self, layer: int, block_name: str) -> torch.Tensor:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.cache[f"blocks.{layer}.{block_name}"]

    # ================= Methods related to the residual stream =================

    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_pre")

    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_mid")

    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_post")

    # ================ Methods related to the feed-forward layer ===============

    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_mlp_out")

    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        # Take activations right before they're multiplied by W_out, i.e.
        # non-linearity and layer norm are already applied.
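        # Each row i is activation_i * W_out[i], so summing over the hidden
        # dimension recovers hook_mlp_out up to the output bias b_out.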
        processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
        return torch.mul(processed_activations.unsqueeze(-1), self._model.W_out[layer])

    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden"]:
        # Pre-non-linearity activations of the MLP neurons.
        return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]

    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._model.W_out[layer][neuron]

    # ==================== Methods related to the attention ====================

    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        return self._get_block(layer, "attn.hook_pattern")[batch_i][head]

    def attention_output_per_head(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]

    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._get_block(layer, "hook_attn_out")[batch_i][pos]

    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "pos key_pos head d_model"]:
        if not self._last_run:
            raise self._run_exception
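        # Keep the head outputs decomposed by source (key) position: rather than
        # summing the attention-weighted values over key_pos, retain one value
        # vector per (query_pos, key_pos, head) so each source token's
        # contribution to the head output can be inspected separately.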
        hook_v = self._get_block(layer, "attn.hook_v")[batch_i]
        b_v = self._model.b_V[layer]
        v = hook_v + b_v
        pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
        z = einsum(
            "key_pos head d_head, "
            "head query_pos key_pos -> "
            "query_pos key_pos head d_head",
            v,
            pattern,
        )
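        # Push each per-key-position value through the output projection W_O;
        # summing over key_pos and head reassembles the layer's attention output
        # (up to how the biases are attributed).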
        decomposed_attn = einsum(
            "pos key_pos head d_head, "
            "head d_head d_model -> "
            "pos key_pos head d_model",
            z,
            self._model.W_O[layer],
        )
        return decomposed_attn
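

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original tool): loads the
    # small publicly available "gpt2" checkpoint on CPU, runs one sentence, and
    # prints a couple of shapes to confirm the activation cache is populated.
    llm = TransformerLensTransparentLlm("gpt2", device="cpu")
    llm.run(["The quick brown fox jumps over the lazy dog"])
    print("logits:", tuple(llm.logits().shape))  # (batch, pos, d_vocab)
    print("resid_in L0:", tuple(llm.residual_in(0).shape))  # (batch, pos, d_model)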