# app.py
"""
Hugging Face Space: MoE Expert Routing Visualizer (Gradio)
----------------------------------------------------------
This Space lets a user:
- Choose a model from a dropdown
- Enter a user prompt, and optionally an assistant prompt
- Optionally ablate individual experts (their share is redistributed among the remaining experts)
- Call a backend function that returns 4 routing percentages (Language, Logic, Social, World) plus a generated response
- See a bar plot + table of the percentages

🧩 Plug your real routing function in router_backend.py -> get_expert_routing().
A deterministic "mock mode" produces stable pseudo-random percentages from the prompt whenever the backend raises an error (or when use_mock is flipped on in route_and_plot).
"""

import os
import hashlib
from typing import Dict, List, Tuple, Union
import gradio as gr
import plotly
import plotly.express as px
import pandas as pd
from router_backend import get_expert_routing 

# ---- Expected backend adapter ------------------------------------------------
# Implement your real function in router_backend.py with the signature used below
# in route_and_plot():
#   def get_expert_routing(
#       model_id: str,
#       hf_token: Union[str, None],
#       prompt: Union[str, List[Dict[str, str]]],
#       ablations: List[str],
#   ) -> Tuple[Union[List[float], Dict[str, float]], Union[str, None]]
# It MUST return (routing, generation):
#   - routing: 4 values that sum to ~100 (percentages) in the fixed order
#     ["Language", "Logic", "Social", "World"], or a mapping with those keys
#   - generation: the generated assistant text (or None)
# To make the import optional, replace the unconditional import above with:
# try:
#     from router_backend import get_expert_routing  # your real backend
#     BACKEND_AVAILABLE = True
# except Exception as e:  # keep error for display if needed
#     BACKEND_AVAILABLE = False
#     _backend_import_error = e
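#
# A minimal sketch of what router_backend.get_expert_routing could look like, assuming
# a trivial mock backend: it splits 100% evenly across the non-ablated experts and
# returns a placeholder generation. Illustrative only (hence commented out, so it does
# not shadow the real import above); a real implementation would load the model, run
# the prompt, and read out token-level router statistics.
#
#     def get_expert_routing(model_id, hf_token, prompt, ablations):
#         experts = ["Language", "Logic", "Social", "World"]
#         # ablations arrive lowercased ("language", "logic", ...), see route_and_plot()
#         active = [e for e in experts if e.lower() not in ablations]
#         share = 100.0 / len(active)
#         routing = {e: (share if e in active else 0.0) for e in experts}
#         generation = "(mock) no model loaded"
#         return routing, generation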


EXPERTS = ["Language", "Logic", "Social", "World"]

DEFAULT_MODELS = [
    "micro-smollm2-135m",
    "micro-smollm2-360m",
    "micro-llama-1b", 
    "micro-llama-3b",
    "micro-llama-1b-dpo",
    "micro-moe-smollm2-135m",
    "micro-moe-smollm2-360m",
    "micro-moe-llama-1b",
]

def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
    """
    Deterministic mock routing percentages based on model_id + prompt + seed.
    Returns a list of 4 percentages summing to 100.0
    """
    h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest()
    # split into 4 positive numbers
    vals = [int.from_bytes(h[i*8:(i+1)*8], "little") % 10_000 + 1 for i in range(4)]
    s = sum(vals)
    return [100.0 * v / s for v in vals]

def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]:
    """
    Normalize different return types into a 4-length list ordered as EXPERTS.
    """
    if isinstance(r, dict):
        vals = [float(r.get(k, 0.0)) for k in EXPERTS]
    else:
        vals = [float(x) for x in list(r)]
        if len(vals) != 4:
            raise ValueError(f"Expected 4 values, got {len(vals)}.")
    # renormalize to 100 if needed
    s = sum(vals)
    if s <= 0:
        raise ValueError("Sum of routing percentages is non-positive.")
    vals = [100.0 * v / s for v in vals]
    return vals

def _compose_prompt(user_prompt: str, assistant_prompt: str) -> Union[str, List[Dict[str, str]]]:
    user_prompt = (user_prompt or "").strip()
    assistant_prompt = (assistant_prompt or "").strip()
    if assistant_prompt:
        return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
    return user_prompt

def route_and_plot(
    model_choice: str, 
    user_prompt: str, 
    assistant_prompt: str,
    ablate_language: bool, 
    ablate_logic: bool, 
    ablate_social: bool, 
    ablate_world: bool,
) -> Tuple[str, pd.DataFrame, "plotly.graph_objs._figure.Figure", str]:
    """
    Main pipeline:
    - Compose prompt (user + optional assistant)
    - Call backend (real or mock)
    - Return a table + bar plot + status message
    """
    hf_token = os.getenv("HF_TOKEN")

    ablations = []
    if ablate_language:
        ablations.append("language")
    if ablate_logic:
        ablations.append("logic")
    if ablate_social:
        ablations.append("social")
    if ablate_world:
        ablations.append("world")

    seed = 42
    use_mock = False
    
    # Validate inputs up front; they are also needed for the mock fallback below.
    model_id = model_choice.strip()
    if not model_id:
        raise gr.Error("Please select a model or enter a custom model id.")
    prompt = _compose_prompt(user_prompt, assistant_prompt)
    if not prompt:
        raise gr.Error("Please enter a prompt.")

    if len(ablations) == 4:
        msg = "Error: all four experts cannot be ablated at once.<br>Falling back to mock data."
        generation = None
        vals = _mock_routing(model_id, prompt, seed=seed)
    elif use_mock:
        msg = "Using mock data."
        vals = _mock_routing(model_id, prompt, seed=seed)
        generation = None
    else:
        try:
            raw, generation = get_expert_routing(model_id, hf_token, prompt, ablations)  # <-- your real function
            vals = _normalize_output(raw)
            msg = "Routed with real backend."
        except Exception as e:
            # fall back to mock on error, but surface the error message
            msg = f"Backend error: {e}\nFalling back to mock data."
            vals = _mock_routing(model_id, prompt, seed=seed)
            generation = None

    df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
    colors = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"]
    fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent")
    fig.update_traces(marker_color=colors)
    fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
    fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)

    status = f"Model: {model_id}<br>{msg}"
    if generation is None:
        generation = assistant_prompt

    return generation, df, fig, status

with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    gr.Markdown(
        """
        # 🧠 Mixture of Cognitive Reasoner (MiCRo) Expert Routing Visualizer
        ## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.
        Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)
        ----
        This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
        Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
        """.strip()
    )

    with gr.Row():
        model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
        # hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="Required for Llama-based models", lines=1)
    
    with gr.Column():
        with gr.Row():
            gr.Markdown(
                """
                #### Ablate Experts
                (Check to disable an expert; the routing percentages will be redistributed among the remaining experts)
                """, label="Ablate Experts"
            )
        with gr.Row():
            ablate_language = gr.Checkbox(value=False, label="Language Expert")
            ablate_logic = gr.Checkbox(value=False, label="Logic Expert")
            ablate_social = gr.Checkbox(value=False, label="Social Expert")
            ablate_world = gr.Checkbox(value=False, label="World Expert")

    with gr.Row():
        user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
        assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")

    # with gr.Row():
    #     use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
    #     seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")

    run = gr.Button("Run Routing", variant="primary")

    generation_output = gr.Textbox(lines=4, label="Generated Response", placeholder="Generated text will appear here...", interactive=False)

    with gr.Row():
        table = gr.Dataframe(label="Routing Percentages", interactive=False)
    plot = gr.Plot(label="Bar Plot")


    status = gr.Markdown("", label="System Message")

    run.click(
        route_and_plot,
        inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
        outputs=[generation_output, table, plot, status],
    )

    # example prompts
    examples = [
        [
            "micro-llama-1b",     # dropdown model
            "Correct the grammar: \"She go to the park every morning.\"",  # user prompt
            "She goes to the park every morning.",  # assistant prompt (empty)
            False, False, False, False  # no ablations
        ],
        [
            "micro-llama-1b",     # dropdown model
            "What is 27 multiplied by 14?",  # user prompt
            "First, break it down: 27 * 10 = 270. Then 27 * 4 = 108. Add them together: 270 + 108 = 378. So the answer is 378.",  # assistant prompt (empty)
            False, False, False, False  # no ablations
        ],
        [
            "micro-llama-1b",     # dropdown model
            "Why did Sarah look away when John asked if she was okay?",  # user prompt
            "Because she didn't want him to see that she was upset.",  # assistant prompt (empty)
            False, False, False, False  # no ablations
        ],
        [
            "micro-llama-1b",     # dropdown model
            "Why do people usually eat breakfast in the morning?",  # user prompt
            "Because after sleeping, the body needs energy to start the day.",  # assistant prompt (empty)
            False, False, False, False  # no ablations
        ],
    ]

    gr.Examples(
        examples=examples,
        inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world],
        label="Try these examples:",
        cache_examples=True,
        fn=route_and_plot,
        outputs=[generation_output, table, plot, status],
    )

if __name__ == "__main__":
    demo.launch()