# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import networkx as nx
import pandas as pd
import plotly.express
import plotly.graph_objects as go
import streamlit as st
import streamlit_extras.row as st_row
import torch
from jaxtyping import Float
from networkx.classes.digraph import DiGraph
from torch.amp import autocast
from transformers import HfArgumentParser

import llm_transparency_tool.components
import llm_transparency_tool.routes.contributions as contributions
import llm_transparency_tool.routes.graph
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm
from llm_transparency_tool.routes.graph_node import NodeType
from llm_transparency_tool.server.graph_selection import (
    GraphSelection,
    UiGraphEdge,
    UiGraphNode,
)
from llm_transparency_tool.server.monitor import SystemMonitor
from llm_transparency_tool.server.styles import (
    RenderSettings,
    logits_color_map,
    margins_css,
    string_to_display,
)
from llm_transparency_tool.server.utils import (
    B0,
    get_contribution_graph,
    load_dataset,
    load_model,
    possible_devices,
    run_model_with_session_caching,
    st_placeholder,
)
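
# Streamlit front-end for the LLM Transparency Tool: it runs a single sentence
# through a TransformerLens-backed model and visualizes the contribution graph,
# per-head attention contributions, FFN neurons, and promoted/suppressed tokens.
# B0 (from server.utils) indexes the single sentence in the batch throughout.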

# The "cached_" name implies Streamlit caching (the decorator was lost in
# extraction); nx.Graph is not hashable, so the graph is hashed by identity.
@st.cache_data(hash_funcs={nx.Graph: id})
def cached_build_paths_to_predictions(
    graph: nx.Graph,
    n_layers: int,
    n_tokens: int,
    starting_tokens: List[int],
    threshold: float,
):
    return llm_transparency_tool.routes.graph.build_paths_to_predictions(
        graph, n_layers, n_tokens, starting_tokens, threshold
    )
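
# The stateless model is copied before running, so the returned stateful copy
# owns the activations for the current sentence while the original stays free
# to be reused for other inputs.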
# As above, the caching decorator is a restoration; the model wrapper is
# hashed by identity.
@st.cache_resource(hash_funcs={TransformerLensTransparentLlm: id})
def cached_run_inference_and_populate_state(
    stateless_model,
    sentences,
):
    stateful_model = stateless_model.copy()
    stateful_model.run(sentences)
    return stateful_model

@dataclass
class LlmViewerConfig:
    debug: bool = field(
        default=False,
        metadata={"help": "Show debugging information, like the time profile."},
    )

    preloaded_dataset_filename: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the text file to load the lines from."},
    )

    demo_mode: bool = field(
        default=False,
        metadata={"help": "Whether the app should be in the demo mode."},
    )

    allow_loading_dataset_files: bool = field(
        default=True,
        metadata={"help": "Whether the app should be able to load the dataset files on the server side."},
    )

    max_user_string_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Limit for the length of user-provided sentences (in characters), or None if there is no limit."
        },
    )

    models: Dict[str, str] = field(
        default_factory=dict,
        metadata={
            "help": "Locations of models which are stored locally. Dictionary: official "
            "HuggingFace name -> path to dir. If None is specified, the model will be "
            "downloaded from HuggingFace."
        },
    )

    default_model: str = field(
        default="",
        metadata={"help": "The model to load once the UI is started."},
    )

class App:
    _stateful_model: Optional[TransparentLlm] = None
    render_settings = RenderSettings()

    _graph: Optional[nx.Graph] = None
    _contribution_threshold: float = 0.0
    _renormalize_after_threshold: bool = False
    _normalize_before_unembedding: bool = True

    # The rest of the class accesses self.stateful_model as an attribute,
    # never as a call, so this must be a property.
    @property
    def stateful_model(self) -> TransparentLlm:
        return self._stateful_model

    def __init__(self, config: LlmViewerConfig):
        self._config = config
        st.set_page_config(layout="wide")
        st.markdown(margins_css, unsafe_allow_html=True)
    def _get_representation(self, node: Optional[UiGraphNode]) -> Optional[Float[torch.Tensor, "d_model"]]:
        if node is None:
            return None
        fn = {
            NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn,
            NodeType.AFTER_FFN: self.stateful_model.residual_out,
            NodeType.FFN: None,
            NodeType.ORIGINAL: self.stateful_model.residual_in,
        }
        return fn[node.type](node.layer)[B0][node.token]
    def draw_model_info(self):
        info = self.stateful_model.model_info().__dict__
        df = pd.DataFrame(
            data=[str(x) for x in info.values()],
            index=info.keys(),
            columns=["Model parameter"],
        )
        st.dataframe(df, use_container_width=False)
    def draw_dataset_selection(self) -> Optional[str]:
        def update_dataset(filename: Optional[str]):
            dataset = load_dataset(filename) if filename is not None else []
            st.session_state["dataset"] = dataset
            st.session_state["dataset_file"] = filename

        if "dataset" not in st.session_state:
            update_dataset(self._config.preloaded_dataset_filename)

        if not self._config.demo_mode:
            if self._config.allow_loading_dataset_files:
                row_f = st_row.row([2, 1], vertical_align="bottom")
                filename = row_f.text_input("Dataset", value=st.session_state.dataset_file or "")
                if row_f.button("Load"):
                    update_dataset(filename)
            row_s = st_row.row([2, 1], vertical_align="bottom")
            new_sentence = row_s.text_input("New sentence")
            new_sentence_added = False
            if row_s.button("Add"):
                max_len = self._config.max_user_string_length
                n = len(new_sentence)
                if max_len is None or n <= max_len:
                    st.session_state.dataset.append(new_sentence)
                    new_sentence_added = True
                    st.session_state.sentence_selector = new_sentence
                else:
                    st.warning(f"Sentence length {n} is larger than the configured limit of {max_len}.")

        sentences = st.session_state.dataset
        selection = st.selectbox(
            "Sentence",
            sentences,
            index=len(sentences) - 1,
            key="sentence_selector",
        )
        return selection
    def _unembed(
        self,
        representation: torch.Tensor,
    ) -> torch.Tensor:
        return self.stateful_model.unembed(representation, normalize=self._normalize_before_unembedding)
    def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]:
        tokens = self.stateful_model.tokens()[B0]
        n_tokens = tokens.shape[0]

        model_info = self.stateful_model.model_info()
        graphs = cached_build_paths_to_predictions(
            self._graph,
            model_info.n_layers,
            n_tokens,
            list(range(n_tokens)),
            contribution_threshold,
        )

        return llm_transparency_tool.components.contribution_graph(
            model_info,
            self.stateful_model.tokens_to_strings(tokens),
            graphs,
            key=f"graph_{hash(self.sentence)}",
        )
    def draw_token_matrix(
        self,
        values: Float[torch.Tensor, "t t"],
        tokens: List[str],
        value_name: str,
        title: str,
    ):
        assert values.shape[0] == len(tokens)
        labels = {
            "x": "<b>src</b>",
            "y": "<b>tgt</b>",
            "color": value_name,
        }

        captions = [f"({i}){t}" for i, t in enumerate(tokens)]
        fig = plotly.express.imshow(
            values.cpu(),
            title=f"<b>{title}</b>",
            labels=labels,
            x=captions,
            y=captions,
            color_continuous_scale=self.render_settings.attention_color_map,
            aspect="equal",
        )
        fig.update_layout(
            autosize=True,
            margin=go.layout.Margin(
                l=50,  # left margin
                r=0,  # right margin
                b=100,  # bottom margin
                t=100,  # top margin
                # pad=10  # padding
            ),
        )
        fig.update_xaxes(tickmode="linear")
        fig.update_yaxes(tickmode="linear")
        fig.update_coloraxes(showscale=False)

        st.plotly_chart(fig, use_container_width=True, theme=None)
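
    # For a selected graph edge, the attention block's contribution is
    # decomposed per head; the per-head values drive the head selector (its
    # "temperatures") and the two matrices drawn below it.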
    def draw_attn_info(self, edge: UiGraphEdge, container_attention_map) -> Optional[int]:
        """
        Returns: the index of the selected head.
        """
        n_heads = self.stateful_model.model_info().n_heads

        layer = edge.target.layer

        head_contrib, _ = contributions.get_attention_contributions(
            resid_pre=self.stateful_model.residual_in(layer)[B0].unsqueeze(0),
            resid_mid=self.stateful_model.residual_after_attn(layer)[B0].unsqueeze(0),
            decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0),
        )

        # [batch pos key_pos head] -> [head]
        flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :]
        assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}"

        selected_head = llm_transparency_tool.components.selector(
            items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)],
            indices=range(-1, n_heads),
            temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(),
            preselected_index=flat_contrib.argmax().item(),
            key=f"head_selector_layer_{layer}",
        )
        # "All" (or no selection) falls back to the head with the largest
        # contribution for this edge.
        if selected_head == -1 or selected_head is None:
            selected_head = flat_contrib.argmax().item()

        # Draw attention matrix and contributions for the selected head.
        if selected_head is not None:
            tokens = [
                string_to_display(s) for s in self.stateful_model.tokens_to_strings(self.stateful_model.tokens()[B0])
            ]

            with container_attention_map:
                attn_container, contrib_container = st.columns([1, 1])
                with attn_container:
                    attn = self.stateful_model.attention_matrix(B0, layer, selected_head)
                    self.draw_token_matrix(
                        attn,
                        tokens,
                        "attention",
                        f"Attention map L{layer} H{selected_head}",
                    )
                with contrib_container:
                    contrib = head_contrib[B0, :, :, selected_head]
                    self.draw_token_matrix(
                        contrib,
                        tokens,
                        "contribution",
                        f"Contribution map L{layer} H{selected_head}",
                    )
        return selected_head
    def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]:
        """
        Returns: the index of the selected neuron.
        """
        resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token]
        resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token]
        decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token)
        c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn)

        top_values, top_i = c_ffn.sort(descending=True)
        n = min(self.render_settings.n_top_neurons, c_ffn.shape[0])
        top_neurons = top_i[0:n].tolist()

        selected_neuron = llm_transparency_tool.components.selector(
            items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)],
            indices=range(-1, n),
            temperatures=[0.0] + top_values[0:n].tolist(),
            preselected_index=-1,
            key="neuron_selector",
        )
        if selected_neuron is None:
            selected_neuron = -1
        selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron]

        return selected_neuron
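
    # Renders a table of the n_top highest- and n_bottom lowest-scoring
    # vocabulary tokens for a representation. When a predecessor state is
    # given, each row also shows how the token's rank moved across the block.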
    def _draw_token_table(
        self,
        n_top: int,
        n_bottom: int,
        representation: torch.Tensor,
        predecessor: Optional[torch.Tensor] = None,
    ):
        n_total = n_top + n_bottom

        logits = self._unembed(representation)
        n_vocab = logits.shape[0]
        scores, indices = torch.topk(logits, n_top, largest=True)
        positions = list(range(n_top))
        if n_bottom > 0:
            low_scores, low_indices = torch.topk(logits, n_bottom, largest=False)
            indices = torch.cat((indices, low_indices.flip(0)))
            scores = torch.cat((scores, low_scores.flip(0)))
            positions += range(n_vocab - n_bottom, n_vocab)

        tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)]

        if predecessor is not None:
            pre_logits = self._unembed(predecessor)
            _, sorted_pre_indices = pre_logits.sort(descending=True)
            pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())}
            old_positions = [pre_indices_dict[i] for i in indices.tolist()]

            def pos_gain_string(pos, old_pos):
                if pos == old_pos:
                    return ""
                # A larger position number means the token was demoted.
                sign = "↓" if pos > old_pos else "↑"
                return f"({sign}{abs(pos - old_pos)})"

            position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)]
        else:
            position_strings = [str(pos) for pos in positions]

        def pos_gain_color(s):
            color = "black"
            if isinstance(s, str):
                if "↓" in s:
                    color = "red"
                if "↑" in s:
                    color = "green"
            return f"color: {color}"

        top_df = pd.DataFrame(
            data=zip(position_strings, tokens, scores.tolist()),
            columns=["Pos", "Token", "Score"],
        )

        st.dataframe(
            top_df.style.map(pos_gain_color)
            .background_gradient(
                axis=0,
                cmap=logits_color_map(positive_and_negative=n_bottom > 0),
            )
            .format(precision=3),
            hide_index=True,
            height=self.render_settings.table_cell_height * (n_total + 1),
            use_container_width=True,
        )
    def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None:
        st.caption(block_name)
        self._draw_token_table(
            self.render_settings.n_promoted_tokens,
            self.render_settings.n_suppressed_tokens,
            representation,
            None,
        )

    def draw_top_tokens(
        self,
        node: UiGraphNode,
        container_top_tokens,
        container_token_dynamics,
    ) -> None:
        pre_node = node.get_residual_predecessor()
        if pre_node is None:
            return

        representation = self._get_representation(node)
        predecessor = self._get_representation(pre_node)

        with container_top_tokens:
            st.caption(node.get_name())
            self._draw_token_table(
                self.render_settings.n_top_tokens,
                0,
                representation,
                predecessor,
            )

        if container_token_dynamics is not None:
            with container_token_dynamics:
                self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name())

    def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]):
        block_name = node.get_head_name(head)
        block_output = (
            self.stateful_model.attention_output_per_head(B0, node.layer, node.token, head)
            if head is not None
            else self.stateful_model.attention_output(B0, node.layer, node.token)
        )
        self.draw_token_dynamics(block_output, block_name)

    def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]):
        block_name = node.get_neuron_name(neuron)
        block_output = (
            self.stateful_model.neuron_output(node.layer, neuron)
            if neuron is not None
            else self.stateful_model.ffn_out(node.layer)[B0][node.token]
        )
        self.draw_token_dynamics(block_output, block_name)
    def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]:
        """
        Draw the precision switch and AMP control.

        Returns: the selected dtype and whether AMP should be enabled.
        """
        if device == "cpu":
            dtype = torch.float32
        else:
            dtype = st.selectbox(
                "Precision",
                [torch.float16, torch.bfloat16, torch.float32],
                index=0,
            )
        amp_enabled = dtype != torch.float32
        return dtype, amp_enabled
    def draw_controls(self):
        with st.sidebar.expander("Model", expanded=True):
            list_of_devices = possible_devices()
            if len(list_of_devices) > 1:
                self.device = st.selectbox(
                    "Device",
                    list_of_devices,
                    index=0,
                )
            else:
                self.device = list_of_devices[0]

            self.dtype, self.amp_enabled = self.draw_precision_controls(self.device)

            model_list = list(self._config.models)
            default_choice = model_list.index(self._config.default_model)

            self.model_name = st.selectbox(
                "Model",
                model_list,
                index=default_choice,
            )

            if self.model_name:
                self._stateful_model = load_model(
                    model_name=self.model_name,
                    _model_path=self._config.models[self.model_name],
                    _device=self.device,
                    _dtype=self.dtype,
                )
                self.model_key = self.model_name  # TODO maybe something else?
                self.draw_model_info()

        self.sentence = self.draw_dataset_selection()

        with st.sidebar.expander("Graph", expanded=True):
            self._contribution_threshold = st.slider(
                min_value=0.01,
                max_value=0.1,
                step=0.01,
                value=0.04,
                format=r"%.3f",
                label="Contribution threshold",
            )
            self._renormalize_after_threshold = st.checkbox("Renormalize after threshold", value=True)
            self._normalize_before_unembedding = st.checkbox("Normalize before unembedding", value=True)
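
    # Both the forward pass and graph construction run under CUDA autocast, so
    # the precision chosen in the sidebar applies to the heavy tensor work.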
    def run_inference(self):
        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            self._stateful_model = cached_run_inference_and_populate_state(self.stateful_model, [self.sentence])

        with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
            # If renormalizing after thresholding, the threshold is applied
            # while building the graph; otherwise the full graph is kept and
            # thresholding happens at render time (see draw_graph_and_selection).
            self._graph = get_contribution_graph(
                self.stateful_model,
                self.model_key,
                self.stateful_model.tokens()[B0].tolist(),
                (self._contribution_threshold if self._renormalize_after_threshold else 0.0),
            )
    def draw_graph_and_selection(
        self,
    ) -> None:
        (
            container_graph,
            container_tokens,
        ) = st.columns(self.render_settings.column_proportions)

        container_graph_left, container_graph_right = container_graph.columns([5, 1])
        container_graph_left.write('##### Graph')
        heads_placeholder = container_graph_right.empty()
        heads_placeholder.write('##### Blocks')
        container_graph_right_used = False

        container_top_tokens, container_token_dynamics = container_tokens.columns([1, 1])
        container_top_tokens.write('##### Top Tokens')
        container_top_tokens_used = False
        container_token_dynamics.write('##### Promoted Tokens')
        container_token_dynamics_used = False

        try:
            if self.sentence is None:
                return

            with container_graph_left:
                selection = self.draw_graph(
                    self._contribution_threshold if not self._renormalize_after_threshold else 0.0
                )
            if selection is None:
                return

            node = selection.node
            edge = selection.edge

            if edge is not None and edge.target.type == NodeType.AFTER_ATTN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write('##### Heads')
                    head = self.draw_attn_info(edge, container_graph)
                with container_token_dynamics:
                    self.draw_attention_dynamics(edge.target, head)
                    container_token_dynamics_used = True
            elif node is not None and node.type == NodeType.FFN:
                with container_graph_right:
                    container_graph_right_used = True
                    heads_placeholder.write('##### Neurons')
                    neuron = self.draw_ffn_info(node)
                with container_token_dynamics:
                    self.draw_ffn_dynamics(node, neuron)
                    container_token_dynamics_used = True

            if node is not None and node.is_in_residual_stream():
                self.draw_top_tokens(
                    node,
                    container_top_tokens,
                    container_token_dynamics if not container_token_dynamics_used else None,
                )
                container_top_tokens_used = True
                container_token_dynamics_used = True
        finally:
            if not container_graph_right_used:
                st_placeholder(
                    'Click on an edge to see head contributions. \n\n'
                    'Or click on FFN to see individual neuron contributions.',
                    container_graph_right,
                    height=1100,
                )
            if not container_top_tokens_used:
                st_placeholder('Select a node from residual stream to see its top tokens.', container_top_tokens, height=1100)
            if not container_token_dynamics_used:
                st_placeholder('Select a node to see its promoted tokens.', container_token_dynamics, height=1100)
    def run(self):
        with st.sidebar.expander("About", expanded=True):
            if self._config.demo_mode:
                st.caption("""
                    The app is deployed in Demo Mode, thus only predefined models and inputs are available.\n
                    You can still install the app locally and use your own models and inputs.\n
                    See https://github.com/facebookresearch/llm-transparency-tool for more information.
                """)

        self.draw_controls()

        if not self.model_name:
            st.warning("No model selected")
            st.stop()

        if self.sentence is None:
            st.warning("No sentence selected")
        else:
            with torch.inference_mode():
                self.run_inference()
            self.draw_graph_and_selection()

if __name__ == "__main__":
    top_parser = argparse.ArgumentParser()
    top_parser.add_argument("config_file")
    args = top_parser.parse_args()

    parser = HfArgumentParser([LlmViewerConfig])
    config = parser.parse_json_file(args.config_file)[0]

    with SystemMonitor(config.debug) as prof:
        app = App(config)
        app.run()