Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| from typing import List, Optional | |
| import networkx as nx | |
| import streamlit.components.v1 as components | |
| from llm_transparency_tool.models.transparent_llm import ModelInfo | |
| from llm_transparency_tool.server.graph_selection import GraphSelection, UiGraphNode | |
# When True, serve the prebuilt frontend bundled next to this module;
# when False, load the component from a local development server instead.
_RELEASE = True

if not _RELEASE:
    # Development mode: the frontend is served live from localhost.
    config = {"url": "http://localhost:3001"}
else:
    # Release mode: point Streamlit at the static build directory.
    parent_dir = os.path.dirname(os.path.abspath(__file__))
    config = {"path": os.path.join(parent_dir, "frontend/build")}

# Single component handle shared by every widget factory in this module;
# the `component=` argument passed at call time selects which widget renders.
_component_func = components.declare_component("contribution_graph", **config)
def is_node_valid(node: UiGraphNode, n_layers: int, n_tokens: int):
    """Tell whether ``node`` falls inside the current layer/token grid.

    Only upper bounds are enforced: the node's ``layer`` must be below
    ``n_layers`` and its ``token`` must be below ``n_tokens``.
    """
    if not (node.layer < n_layers):
        # Short-circuit exactly like `a and b`: token is not inspected here.
        return False
    return node.token < n_tokens
def is_selection_valid(s: GraphSelection, n_layers: int, n_tokens: int):
    """Check that every node referenced by a selection is within bounds.

    An empty/falsy selection counts as valid. Otherwise the selected node
    (if any) is checked first, then both endpoints of the selected edge
    (if any). Validation stops at the first out-of-range node.
    """
    if not s:
        return True

    def referenced_nodes():
        # Lazily yield nodes in the order they must be validated, so that an
        # invalid selected node stops iteration before the edge is touched.
        if s.node:
            yield s.node
        if s.edge:
            yield from (s.edge.source, s.edge.target)

    return all(
        is_node_valid(candidate, n_layers, n_tokens)
        for candidate in referenced_nodes()
    )
def _node_link_edges(graph: nx.Graph) -> list:
    """Return the edge list of ``graph`` in NetworkX node-link form.

    NetworkX 3.4 deprecated the implicit ``"links"`` key in the output of
    ``nx.node_link_data`` and announced a switch of the default to
    ``"edges"``. Accept either key so this keeps working across versions.
    """
    data = nx.node_link_data(graph)
    return data["links"] if "links" in data else data["edges"]


def contribution_graph(
    model_info: ModelInfo,
    tokens: List[str],
    graphs: List[nx.Graph],
    key: str,
) -> Optional[GraphSelection]:
    """Create a new instance of the contribution graph component.

    Args:
        model_info: Model description. Its attribute dict is sent to the
            frontend, and its ``n_layers`` bounds the accepted selection.
        tokens: Token strings, one per entry in ``graphs``.
        graphs: One contribution graph per token. Only each graph's edge
            list (node-link form) is sent to the frontend.
        key: Streamlit widget key for this component instance.

    Returns:
        The selection made in the component (a node and/or an edge), or
        None if nothing was selected or the reported selection refers to a
        node outside the model's layers or the given tokens.
    """
    assert len(tokens) == len(graphs), "expected exactly one graph per token"

    result = _component_func(
        component="graph",
        model_info=model_info.__dict__,
        tokens=tokens,
        edges_per_token=[_node_link_edges(g) for g in graphs],
        default=None,
        key=key,
    )
    selection = GraphSelection.from_json(result)

    n_tokens = len(tokens)
    n_layers = model_info.n_layers
    # We need this extra protection because even though the component has to
    # check the validity of the selection, it sometimes reports invalid output.
    # This appears to be a side effect of how React sets the component's
    # output value.
    if not is_selection_valid(selection, n_layers, n_tokens):
        selection = None
    return selection
def selector(
    items: List[str],
    indices: List[int],
    temperatures: Optional[List[float]],
    preselected_index: Optional[int],
    key: str,
) -> Optional[int]:
    """Create a new instance of the selector component.

    Each text in ``items`` is paired positionally with the value in
    ``indices`` and, when given, with the value in ``temperatures``. All
    of these lists must have the same length.

    Returns:
        The value reported by the component, converted to ``int``, or None
        if the component reports no selection. (Presumably one of
        ``indices``; this depends on the frontend and should be confirmed.)
    """
    count = len(items)
    assert count == len(indices)
    payload = [{"index": idx, "text": text} for text, idx in zip(items, indices)]

    if temperatures is not None:
        assert count == len(temperatures)
        for entry, temperature in zip(payload, temperatures):
            entry["temperature"] = temperature

    result = _component_func(
        component="selector",
        items=payload,
        preselected_index=preselected_index,
        default=None,
        key=key,
    )
    if result is None:
        return None
    return int(result)