Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from typing import List | |
| import torch | |
| from jaxtyping import Float, Int | |
@dataclass
class ModelInfo:
    """
    General information about a model, as returned by `TransparentLlm.model_info`.
    """

    name: str

    # Not the actual number of parameters, but rather the order of magnitude
    n_params_estimate: int

    n_layers: int
    n_heads: int
    # Width of the residual stream (the `d_model` axis in TransparentLlm annotations).
    d_model: int
    # Size of the output vocabulary (the `d_vocab` axis of the logits).
    d_vocab: int
class TransparentLlm(ABC):
    """
    An abstract stateful interface for a language model. The model is supposed to be
    loaded at the class initialization.

    The internal state is the resulting tensors from the last call of the `run` method.
    Most of the methods could return values based on the state, but some may do cheap
    computations based on them.

    Every method is declared abstract, so a concrete subclass must implement all of
    them before it can be instantiated.
    """

    @abstractmethod
    def model_info(self) -> ModelInfo:
        """
        Gives general info about the model. This method must be available before any
        calls of the `run`.
        """
        pass

    @abstractmethod
    def run(self, sentences: List[str]) -> None:
        """
        Run the inference on the given sentences in a single batch and store all
        necessary info in the internal state.
        """
        pass

    @abstractmethod
    def batch_size(self) -> int:
        """
        The size of the batch that was used for the last call of `run`.
        """
        pass

    @abstractmethod
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        """
        The token ids of the batch processed by the last call of `run`.
        """
        pass

    @abstractmethod
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        """
        Convert a sequence of token ids (for example, one row of `tokens()`) into
        strings, presumably one string per token.
        """
        pass

    @abstractmethod
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        """
        The output logits produced by the last call of `run`.
        """
        pass

    @abstractmethod
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "d_vocab"]:
        """
        Project the given vector (for example, the state of the residual stream for a
        layer and token) into the output vocabulary.

        normalize: whether to apply the final normalization before the unembedding.
        Setting it to True and applying to output of the last layer gives the output of
        the model.
        """
        pass

    # ================= Methods related to the residual stream =================

    @abstractmethod
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream before entering the layer. For example, when
        layer == 0 this must be the embedded tokens (including positional embedding).
        """
        pass

    @abstractmethod
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after attention, but before the FFN in the
        given layer.
        """
        pass

    @abstractmethod
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after the given layer. This is equivalent to the
        next layer's input.
        """
        pass

    # ================ Methods related to the feed-forward layer ===============

    @abstractmethod
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The output of the FFN layer, before it gets merged into the residual stream.
        """
        pass

    @abstractmethod
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn d_model"]:
        """
        A collection of vectors added to the residual stream by each neuron. It should
        be the same as neuron activations multiplied by neuron outputs.
        """
        pass

    @abstractmethod
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn"]:
        """
        The content of the hidden layer right after the activation function was applied.
        """
        pass

    @abstractmethod
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return the value that the given neuron adds to the residual stream. It's a raw
        vector from the model parameters, no activation involved.
        """
        pass

    # ==================== Methods related to the attention ====================

    @abstractmethod
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        """
        Return a lower-triangular attention matrix.
        """
        pass

    @abstractmethod
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return what the given head at the given layer and pos added to the residual
        stream.
        """
        pass

    @abstractmethod
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "source target head d_model"]:
        """
        Here
        - source: index of token from the previous layer
        - target: index of token on the current layer
        The decomposed attention tells what vector from source representation was used
        in order to contribute to the target representation.
        """
        pass