# app.py """ Hugging Face Space: MoE Expert Routing Visualizer (Gradio) ---------------------------------------------------------- This Space lets a user: - Choose a model (from a dropdown or a free-text box) - Enter a user prompt, and optionally an assistant prompt - Call a backend function that returns 4 routing percentages (Language, Logic, Social, World) - See a bar plot + table of the percentages đź§© Plug your real routing function in router_backend.py -> get_expert_routing(). By default, a deterministic "mock mode" produces stable pseudo-random percentages from the prompt. """ import os import hashlib from typing import Dict, List, Tuple, Union import gradio as gr import plotly import plotly.express as px import pandas as pd from router_backend import get_expert_routing EXPERTS = ["Language", "Logic", "Social", "World"] DEFAULT_MODELS = [ "micro-smollm2-135m", "micro-smollm2-360m", "micro-llama-1b", "micro-llama-3b", "micro-llama-1b-dpo", "micro-moe-smollm2-135m", "micro-moe-smollm2-360m", "micro-moe-llama-1b", ] def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]: """ Deterministic mock routing percentages based on model_id + prompt + seed. Returns a list of 4 percentages summing to 100.0 """ h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest() # split into 4 positive numbers vals = [int.from_bytes(h[i*8:(i+1)*8], "little") % 10_000 + 1 for i in range(4)] s = sum(vals) return [100.0 * v / s for v in vals] def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]: """ Normalize different return types into a 4-length list ordered as EXPERTS. """ if isinstance(r, dict): vals = [float(r.get(k, 0.0)) for k in EXPERTS] else: vals = [float(x) for x in list(r)] if len(vals) != 4: raise ValueError(f"Expected 4 values, got {len(vals)}.") # renormalize to 100 if needed s = sum(vals) if s <= 0: raise ValueError("Sum of routing percentages is non-positive.") vals = [100.0 * v / s for v in vals] return vals def _compose_prompt(user_prompt: str, assistant_prompt: str) -> str: user_prompt = (user_prompt or "").strip() assistant_prompt = (assistant_prompt or "").strip() if assistant_prompt: return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}] return user_prompt def plot_lines(arrays): names = EXPERTS LINE_COLORS = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"] LINE_COLORS = { name: color for name, color in zip(names, LINE_COLORS) } # Build a tidy DataFrame: columns = index, value, series records = [] for i, array in enumerate(arrays): for name, v in zip(names, array): records.append({"index": i+1, "value": v, "series": name}) df = pd.DataFrame.from_records(records) fig = px.line( df, x="index", y="value", color="series", color_discrete_map=LINE_COLORS, title="", markers=True, ) fig.update_layout( xaxis_title="Layer Index", yaxis_title="Percentage (%)", legend_title="Layer-wise Expert Routing", ) return fig def route_and_plot( model_choice: str, user_prompt: str, assistant_prompt: str, ablate_language: bool, ablate_logic: bool, ablate_social: bool, ablate_world: bool, ) -> Tuple[pd.DataFrame, "plotly.graph_objs._figure.Figure", str]: """ Main pipeline: - Compose prompt (user + optional assistant) - Call backend (real or mock) - Return a table + bar plot + status message """ hf_token = os.getenv("HF_TOKEN") ablations = [] if ablate_language: ablations.append("language") if ablate_logic: ablations.append("logic") if ablate_social: ablations.append("social") if ablate_world: ablations.append("world") seed = 42 use_mock = False if len(ablations) == 4: msg = "Error message: you can't ablate all experts.
Falling back to mock data." generation = None vals = _mock_routing(model_id, prompt, seed=seed) else: model_id = model_choice.strip() if not model_id: raise gr.Error("Please select a model or enter a custom model id.") prompt = _compose_prompt(user_prompt, assistant_prompt) if not prompt: raise gr.Error("Please enter a prompt.") if use_mock: msg = "Using mock data." vals = _mock_routing(model_id, prompt, seed=seed) generation = None layer_routing = None else: try: raw, layer_routing, generation = get_expert_routing(model_id, hf_token, prompt, ablations) # <-- your real function vals = _normalize_output(raw) msg = "Routed with real backend." except Exception as e: # fallback to mock on error, but surface message print(f"Backend error: {e}") msg = f"Backend error: {e}\nFalling back to mock data." vals = _mock_routing(model_id, prompt, seed=seed) generation = None layer_routing = None df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals}) colors = ["#97D077", "#4285F4", "#FFAB40", "#A64D79"] fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent") fig.update_traces(marker_color=colors) fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside") fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35) line_fig = plot_lines(layer_routing) if layer_routing is not None else None status = f"Model: {model_id}
{msg}" if generation is None: generation = assistant_prompt return generation, df, fig, line_fig, status with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo: gr.Markdown( """ # 🧠 Mixture of Cognitive Reasoner (MiCRo) Expert Routing Visualizer ## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts. ### Project Website: [https://cognitive-reasoners.epfl.ch](https://cognitive-reasoners.epfl.ch) | Paper: [https://arxiv.org/abs/2506.13331](https://arxiv.org/abs/2506.13331) ---- This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt. Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure. """.strip() ) with gr.Row(): model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0]) # hf_token = gr.Textbox(label="Huggingface token for authentication", placeholder="Required for Llama-based models", lines=1) with gr.Column(): with gr.Row(): gr.Markdown( """ #### Ablate Experts (Check to disable an expert; the routing percentages will be redistributed among the remaining experts) """, label="Ablate Experts" ) with gr.Row(): ablate_language = gr.Checkbox(value=False, label="Language Expert") ablate_logic = gr.Checkbox(value=False, label="Logic Expert") ablate_social = gr.Checkbox(value=False, label="Social Expert") ablate_world = gr.Checkbox(value=False, label="World Expert") with gr.Row(): user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...") assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...") run = gr.Button("Run Routing", variant="primary") generation_output = gr.Textbox(lines=4, label="Generated Response", placeholder="Generated text will appear here...", interactive=False) with gr.Row(): table = gr.Dataframe(label="Routing Percentages", interactive=False) plot = gr.Plot(label="Bar Plot") line_plot = gr.Plot(label="Layer-wise Routing Percentages") status = gr.Markdown("", label="System Message") run.click( route_and_plot, inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world], outputs=[generation_output, table, plot, line_plot, status], ) # example prompts examples = [ [ "micro-llama-1b", # dropdown model "Correct the grammar: \"She go to the park every morning.\"", # user prompt "She goes to the park every morning.", # assistant prompt (empty) False, False, False, False # no ablations ], [ "micro-llama-1b", # dropdown model "What is 27 multiplied by 14?", # user prompt "First, break it down: 27 * 10 = 270. Then 27 * 4 = 108. Add them together: 270 + 108 = 378. So the answer is 378.", # assistant prompt (empty) False, False, False, False # no ablations ], [ "micro-llama-1b", # dropdown model "Why did Sarah look away when John asked if she was okay?", # user prompt "Because she didn't want him to see that she was upset.", # assistant prompt (empty) False, False, False, False # no ablations ], [ "micro-llama-1b", # dropdown model "Why do people usually eat breakfast in the morning?", # user prompt "Because after sleeping, the body needs energy to start the day.", # assistant prompt (empty) False, False, False, False # no ablations ], ] gr.Examples( examples=examples, inputs=[model_choice, user_prompt, assistant_prompt, ablate_language, ablate_logic, ablate_social, ablate_world], label="Try these examples:", cache_examples=True, fn=route_and_plot, outputs=[generation_output, table, plot, line_plot, status], ) if __name__ == "__main__": demo.launch()