Spaces:
Runtime error
Runtime error
fen entry
Browse files- src/attention_interface.py +30 -16
src/attention_interface.py
CHANGED
|
@@ -13,23 +13,28 @@ from . import constants, state, visualisation
|
|
| 13 |
|
| 14 |
def compute_cache(
|
| 15 |
game_pgn,
|
|
|
|
| 16 |
attention_layer,
|
| 17 |
attention_head,
|
| 18 |
comp_index,
|
| 19 |
state_cache,
|
| 20 |
state_board_index,
|
| 21 |
):
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
|
| 34 |
return (
|
| 35 |
*make_plot(
|
|
@@ -156,10 +161,19 @@ def next_board(
|
|
| 156 |
with gr.Blocks() as interface:
|
| 157 |
with gr.Row():
|
| 158 |
with gr.Column():
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
compute_cache_button = gr.Button("Compute cache")
|
| 164 |
with gr.Group():
|
| 165 |
with gr.Row():
|
|
@@ -228,7 +242,7 @@ with gr.Blocks() as interface:
|
|
| 228 |
state_board_index = gr.State(value=0)
|
| 229 |
compute_cache_button.click(
|
| 230 |
compute_cache,
|
| 231 |
-
inputs=[game_pgn, *static_inputs, state_cache, state_board_index],
|
| 232 |
outputs=[*static_outputs, state_cache],
|
| 233 |
)
|
| 234 |
|
|
|
|
| 13 |
|
| 14 |
def compute_cache(
|
| 15 |
game_pgn,
|
| 16 |
+
board_fen,
|
| 17 |
attention_layer,
|
| 18 |
attention_head,
|
| 19 |
comp_index,
|
| 20 |
state_cache,
|
| 21 |
state_board_index,
|
| 22 |
):
|
| 23 |
+
if game_pgn == "" and board_fen != "":
|
| 24 |
+
board = chess.Board(board_fen)
|
| 25 |
+
fen_list = [board.fen()]
|
| 26 |
+
else:
|
| 27 |
+
board = chess.Board()
|
| 28 |
+
fen_list = [board.fen()]
|
| 29 |
+
for move in game_pgn.split():
|
| 30 |
+
if move.endswith("."):
|
| 31 |
+
continue
|
| 32 |
+
try:
|
| 33 |
+
board.push_san(move)
|
| 34 |
+
fen_list.append(board.fen())
|
| 35 |
+
except ValueError:
|
| 36 |
+
gr.Warning(f"Invalid move {move}, stopping before it.")
|
| 37 |
+
break
|
| 38 |
state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
|
| 39 |
return (
|
| 40 |
*make_plot(
|
|
|
|
| 161 |
with gr.Blocks() as interface:
|
| 162 |
with gr.Row():
|
| 163 |
with gr.Column():
|
| 164 |
+
with gr.Group():
|
| 165 |
+
gr.Markdown(
|
| 166 |
+
"Specify the game PGN of FEN string that you want to analyse (PGN overrides FEN)."
|
| 167 |
+
)
|
| 168 |
+
game_pgn = gr.Textbox(
|
| 169 |
+
label="Game PGN",
|
| 170 |
+
lines=1,
|
| 171 |
+
)
|
| 172 |
+
board_fen = gr.Textbox(
|
| 173 |
+
label="Board FEN",
|
| 174 |
+
lines=1,
|
| 175 |
+
max_lines=1,
|
| 176 |
+
)
|
| 177 |
compute_cache_button = gr.Button("Compute cache")
|
| 178 |
with gr.Group():
|
| 179 |
with gr.Row():
|
|
|
|
| 242 |
state_board_index = gr.State(value=0)
|
| 243 |
compute_cache_button.click(
|
| 244 |
compute_cache,
|
| 245 |
+
inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
|
| 246 |
outputs=[*static_outputs, state_cache],
|
| 247 |
)
|
| 248 |
|