| """ | |
| Gradio interface for plotting attention. | |
| """ | |
| import chess | |
| import gradio as gr | |
| import torch | |
| import uuid | |
| import re | |
| from . import constants, state, visualisation | |


def compute_cache(
    game_pgn,
    board_fen,
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Build the list of boards to analyse and run the model on each of them."""
    if game_pgn == "" and board_fen != "":
        # No PGN given: analyse the single position described by the FEN.
        try:
            board = chess.Board(board_fen)
        except ValueError:
            gr.Warning("Invalid FEN, using the starting position instead.")
            board = chess.Board()
        fen_list = [board.fen()]
    else:
        # Replay the PGN from the starting position, skipping move numbers
        # (tokens ending with '.') and stopping at the first illegal move.
        board = chess.Board()
        fen_list = [board.fen()]
        for move in game_pgn.split():
            if move.endswith("."):
                continue
            try:
                board.push_san(move)
                fen_list.append(board.fen())
            except ValueError:
                gr.Warning(f"Invalid move {move}, stopping before it.")
                break
    # Cache the model output and attention for every position.
    state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_cache,
    )


def make_plot(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Plot the attention of the selected head over the current board."""
    if state_cache is None:
        gr.Warning("Cache not computed!")
        return None, None, None, None, None
    fen, (out, cache) = state_cache[state_board_index]
    # Attention of the selected head in the selected layer, one tensor per step.
    attn_list = [a[0, attention_head - 1] for a in cache[attention_layer - 1]]
    prompt_attn, *comp_attn = attn_list
    # The last row of the prompt attention is the attention used to produce the
    # first completion token, so prepend it to the per-token attention list.
    comp_attn.insert(0, prompt_attn[-1:])
    comp_attn = [a.squeeze(0) for a in comp_attn]
    if len(comp_attn) != 5:
        raise NotImplementedError(
            f"Expected 5 completion tokens, got {len(comp_attn)}."
        )
    # Split the attention mass between the board configuration tokens ('Config'),
    # the remaining board tokens such as colour and castling rights ('Meta'),
    # and everything else ('Dump').
    config_total = meta_total = dump_total = 0
    config_done = False
    heatmap = torch.zeros(64)
    h_index = 0
    for i, t_o in enumerate(out[0]):
        try:
            t_attn = comp_attn[comp_index - 1][i]
            if (i < 3) or (i > len(out[0]) - 10):
                dump_total += t_attn
                continue
            t_str = state.model.tokenizer.decode(t_o)
            if t_str.startswith(" ") and h_index > 0:
                config_done = True
            if not config_done:
                if t_str == "/":
                    dump_total += t_attn
                    continue
                # Expand each FEN digit d (a run of d empty squares) into d '0'
                # characters so that every character maps to one board square.
                t_str = re.sub(r"\d", lambda m: "0" * int(m.group(0)), t_str)
                config_total += t_attn
                # Spread the token attention uniformly over the squares it covers.
                t_str_len = len(t_str.strip())
                pre_t_attn = t_attn / t_str_len
                for j in range(t_str_len):
                    heatmap[h_index + j] = pre_t_attn
                h_index += t_str_len
            else:
                meta_total += t_attn
        except IndexError:
            break
    raw_attention = comp_attn[comp_index - 1]
    highlighted_tokens = [
        (state.model.tokenizer.decode(out[0][i]), raw_attention[i])
        for i in range(len(raw_attention))
    ]
    # The last five tokens are the completion; the first four of them form the UCI move.
    uci_move = state.model.tokenizer.decode(out[0][-5:-1]).strip()
    board = chess.Board(fen)
    # The heatmap was filled in FEN order (rank 8 first); flip the rows so that
    # rank 1 comes first.
    heatmap = heatmap.view(8, 8).flip(0).view(64)
    move = chess.Move.from_uci(uci_move)
    svg_board, fig = visualisation.render_heatmap(
        board, heatmap, arrows=[(move.from_square, move.to_square)]
    )
    info = (
        f"[Completion] Complete: '{state.model.tokenizer.decode(out[0][-5:])}'"
        f" Chosen: '{state.model.tokenizer.decode(out[0][-5:][comp_index - 1])}'"
        f"\n[Distribution] Config: {config_total:.2f} Meta: {meta_total:.2f} Dump: {dump_total:.2f}"
    )
    figure_id = str(uuid.uuid4())  # avoid shadowing the built-in `id`
    with open(f"{constants.FIGURE_DIRECTORY}/board_{figure_id}.svg", "w") as f:
        f.write(svg_board)
    return (
        board.fen(),
        info,
        fig,
        f"{constants.FIGURE_DIRECTORY}/board_{figure_id}.svg",
        highlighted_tokens,
    )


def previous_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    state_board_index -= 1
    if state_board_index < 0:
        gr.Warning("Already at first board.")
        state_board_index = 0
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_board_index,
    )


def next_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    state_board_index += 1
    if state_cache is None:
        # make_plot will warn that the cache has not been computed yet.
        state_board_index = 0
    elif state_board_index >= len(state_cache):
        gr.Warning("Already at last board.")
        state_board_index = len(state_cache) - 1
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_board_index,
    )
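

# Interface layout (as defined below): the left column holds the game PGN / board
# FEN inputs and the attention controls, the right column shows the board heatmap
# and its colorbar.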
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                )
                compute_cache_button = gr.Button("Compute cache")
            with gr.Group():
                with gr.Row():
                    attention_layer = gr.Slider(
                        label="Attention layer",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    attention_head = gr.Slider(
                        label="Attention head",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    comp_index = gr.Slider(
                        label="Completion index",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=1,
                    )
                with gr.Row():
                    previous_board_button = gr.Button("Previous board")
                    next_board_button = gr.Button("Next board")
            current_board_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
            )
            info = gr.Textbox(
                label="Info",
                lines=1,
                info=(
                    "'Config' refers to the board configuration tokens."
                    "\n'Meta' to the additional board tokens (like color or castling)."
                    "\n'Dump' to the rest of the tokens (including '/')."
                ),
            )
            gr.Markdown(
                "Note that only the 'Config' attention is plotted.\n\nSee below for the raw attention."
            )
            raw_attention_html = gr.HighlightedText(
                label="Raw attention",
            )
        with gr.Column():
            image_board = gr.Image(label="Board")
            colorbar = gr.Plot(label="Colorbar")

    static_inputs = [
        attention_layer,
        attention_head,
        comp_index,
    ]
    static_outputs = [
        current_board_fen,
        info,
        colorbar,
        image_board,
        raw_attention_html,
    ]
    state_cache = gr.State(value=None)
    state_board_index = gr.State(value=0)
    compute_cache_button.click(
        compute_cache,
        inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_cache],
    )
    previous_board_button.click(
        previous_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    next_board_button.click(
        next_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    attention_layer.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    attention_head.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
    comp_index.change(
        make_plot,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs],
    )
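

# Usage sketch (assumption: the Space's entry point imports `interface` from this
# module and launches it there). The guard below is only relevant when running the
# file directly, e.g. with `python -m <package>.<module>`, since the relative
# imports above require a package context.
if __name__ == "__main__":
    interface.launch()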