Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import List, Optional | |
class NodeType(Enum):
    """
    Kinds of nodes a graph node can represent.

    Each member's value is the lowercase string used when a node name is
    rendered (see ``GraphNode.get_name``).
    """

    # Representation right after the attention block of a layer; treated as
    # part of the residual stream by GraphNode.is_in_residual_stream.
    AFTER_ATTN = "after_attn"

    # Representation right after the feed-forward block of a layer; also
    # treated as part of the residual stream.
    AFTER_FFN = "after_ffn"

    # The feed-forward block itself (not counted as a residual stream state).
    FFN = "ffn"

    # The original tokens.
    ORIGINAL = "original"
| def _format_block_hierachy_string(blocks: List[str]) -> str: | |
| return " βΈ ".join(blocks) | |
@dataclass
class GraphNode:
    """
    A node of the graph, addressed by layer, token and node type.

    Declared as a dataclass so that instances can be built with keyword
    arguments (``GraphNode(layer=..., token=..., type=...)``), which is how
    ``get_residual_predecessor`` constructs its result.

    Attributes:
        layer: Layer index; rendered as ``L{layer}`` in names.
        token: Token index; rendered as ``T{token}`` in names.
            NOTE(review): presumably the token position in the input
            sequence - confirm against callers.
        type: Which kind of representation this node stands for. The name
            shadows the ``type`` builtin inside the class body, but it is
            part of the public interface and is therefore kept.
    """

    layer: int
    token: int
    type: NodeType

    def is_in_residual_stream(self) -> bool:
        """Return True if the node is an AFTER_ATTN or AFTER_FFN state."""
        return self.type in (NodeType.AFTER_ATTN, NodeType.AFTER_FFN)

    def get_residual_predecessor(self) -> Optional["GraphNode"]:
        """
        Get another graph node which points to the state of the residual stream
        before this node.

        Returns:
            The predecessor node, or None if the current representation is the
            first one in the residual stream (an ORIGINAL node) or the
            resulting node would have a negative layer.
        """
        # ORIGINAL nodes have nothing before them. Checking this up front
        # avoids dereferencing a missing predecessor.
        if self.type is NodeType.ORIGINAL:
            return None

        if self.type is NodeType.AFTER_ATTN:
            if self.layer > 0:
                # Attention reads the output of the previous layer's FFN.
                predecessor = GraphNode(
                    layer=self.layer - 1,
                    token=self.token,
                    type=NodeType.AFTER_FFN,
                )
            else:
                # Before the first layer's attention there are only the
                # original tokens; the layer is clamped to 0.
                predecessor = GraphNode(
                    layer=0,
                    token=self.token,
                    type=NodeType.ORIGINAL,
                )
        else:
            # AFTER_FFN and FFN both follow the attention output of the
            # same layer.
            predecessor = GraphNode(
                layer=self.layer,
                token=self.token,
                type=NodeType.AFTER_ATTN,
            )

        # Defensive: a node with a negative layer has no valid predecessor.
        if predecessor.layer < 0:
            return None
        return predecessor

    def get_name(self) -> str:
        """Return the display name built from the layer, token and type value."""
        return _format_block_hierachy_string(
            [f"L{self.layer}", f"T{self.token}", str(self.type.value)]
        )

    def get_predecessor_block_name(self) -> str:
        """
        Return the name of the block standing between current node and its
        predecessor in the residual stream.

        ORIGINAL nodes have no such block and yield the string "Nothing".
        """
        if self.type is NodeType.ORIGINAL:
            return _format_block_hierachy_string(["Nothing"])
        block = "attn" if self.type is NodeType.AFTER_ATTN else "ffn"
        return _format_block_hierachy_string([f"L{self.layer}", block])

    def get_head_name(self, head: Optional[int]) -> str:
        """
        Return the name of the attention block of this node's layer.

        Args:
            head: Head index to append as ``H{head}``; None names the whole
                attention block.
        """
        path = [f"L{self.layer}", "attn"]
        if head is not None:
            path.append(f"H{head}")
        return _format_block_hierachy_string(path)

    def get_neuron_name(self, neuron: Optional[int]) -> str:
        """
        Return the name of the feed-forward block of this node's layer.

        Args:
            neuron: Neuron index to append as ``N{neuron}``; None names the
                whole feed-forward block.
        """
        path = [f"L{self.layer}", "ffn"]
        if neuron is not None:
            path.append(f"N{neuron}")
        return _format_block_hierachy_string(path)