# app.py
"""
Hugging Face Space: MoE Expert Routing Visualizer (Gradio)
----------------------------------------------------------
This Space lets a user:
- Choose a model (from a dropdown or a free-text box)
- Enter a user prompt, and optionally an assistant prompt
- Call a backend function that returns 4 routing percentages (Language, Logic, Social, World)
  plus a generated continuation
- See a bar plot + table of the percentages
🧩 Plug your real routing function in router_backend.py -> get_expert_routing().
If the backend call fails, a deterministic "mock mode" produces stable pseudo-random percentages from the prompt.
"""
import hashlib
from typing import Dict, List, Tuple, Union

import gradio as gr
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd

from router_backend import get_expert_routing

# ---- Expected backend adapter ------------------------------------------------
# Implement your real function in router_backend.py with the following signature:
#   def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[dict]])
#       -> Tuple[Union[List[float], Dict[str, float], Tuple[float, float, float, float]], str]
# The first return value MUST hold 4 values that sum to ~100 (percentages) in the fixed order:
#   ["Language", "Logic", "Social", "World"]
# (or a mapping with those keys); the second return value is the generated continuation text.
#
# try:
#     from router_backend import get_expert_routing  # your real backend
#     BACKEND_AVAILABLE = True
# except Exception as e:  # keep error for display if needed
#     BACKEND_AVAILABLE = False
#     _backend_import_error = e
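
# ---- Example backend sketch (assumption, not the shipped implementation) ------
# A minimal sketch of what router_backend.py could look like, kept commented out
# so it does not shadow the real import above. The placeholder values and the
# function body are hypothetical; only the interface mirrors the contract
# described in the comment block above.
#
# # router_backend.py
# from typing import Dict, List, Tuple, Union
#
# def get_expert_routing(
#     model_id: str,
#     hf_token: str,
#     prompt: Union[str, List[dict]],
# ) -> Tuple[Dict[str, float], str]:
#     """Return per-expert routing percentages and a generated continuation."""
#     # A real backend would load the MiCRo checkpoint named by model_id (using
#     # hf_token for gated repos), run the prompt, and aggregate per-token
#     # router decisions into per-expert percentages.
#     routing = {"Language": 25.0, "Logic": 25.0, "Social": 25.0, "World": 25.0}
#     generation = ""  # placeholder continuation
#     return routing, generation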

EXPERTS = ["Language", "Logic", "Social", "World"]

DEFAULT_MODELS = [
    "micro-llama-1b",
    "micro-llama-3b",
    "micro-llama-1b-dpo",
    "micro-moe-llama-1b",
    "micro-smollm2-135m",
    "micro-smollm2-360m",
    "micro-moe-smollm2-135m",
    "micro-moe-smollm2-360m",
]


def _mock_routing(model_id: str, prompt: Union[str, List[dict]], seed: int = 0) -> List[float]:
    """
    Deterministic mock routing percentages based on model_id + prompt + seed.
    Returns a list of 4 percentages summing to 100.0.
    """
    h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest()
    # Split the digest into 4 positive numbers, then normalize to percentages.
    vals = [int.from_bytes(h[i * 8:(i + 1) * 8], "little") % 10_000 + 1 for i in range(4)]
    s = sum(vals)
    return [100.0 * v / s for v in vals]


def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]:
    """
    Normalize different return types into a 4-length list ordered as EXPERTS.
    """
    if isinstance(r, dict):
        vals = [float(r.get(k, 0.0)) for k in EXPERTS]
    else:
        vals = [float(x) for x in list(r)]
    if len(vals) != 4:
        raise ValueError(f"Expected 4 values, got {len(vals)}.")
    # Renormalize to 100 if needed.
    s = sum(vals)
    if s <= 0:
        raise ValueError("Sum of routing percentages is non-positive.")
    vals = [100.0 * v / s for v in vals]
    return vals
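
# Illustrative inputs _normalize_output accepts (assumption: the numbers below
# are made up). All three forms map to the same ordered list and are
# renormalized so the four values sum to 100:
#   _normalize_output([10, 20, 30, 40])
#   _normalize_output((10.0, 20.0, 30.0, 40.0))
#   _normalize_output({"Language": 10, "Logic": 20, "Social": 30, "World": 40})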


def _compose_prompt(user_prompt: str, assistant_prompt: str) -> Union[str, List[Dict[str, str]]]:
    """Return a plain string, or a chat-style message list if an assistant reply is given."""
    user_prompt = (user_prompt or "").strip()
    assistant_prompt = (assistant_prompt or "").strip()
    if assistant_prompt:
        return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
    return user_prompt
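
# How a backend consumes the two prompt shapes above (plain string vs.
# OpenAI-style message list) is up to router_backend.py. A hypothetical helper,
# assuming a transformers tokenizer with a chat template, kept commented out:
#
# from transformers import AutoTokenizer
#
# def _prompt_to_text(prompt, model_id: str, hf_token: str) -> str:
#     if isinstance(prompt, str):
#         return prompt
#     tok = AutoTokenizer.from_pretrained(model_id, token=hf_token)
#     return tok.apply_chat_template(prompt, tokenize=False)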


def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant_prompt: str) -> Tuple[str, pd.DataFrame, go.Figure, str]:
    """
    Main pipeline:
    - Compose the prompt (user + optional assistant)
    - Call the backend (real, with mock fallback on error)
    - Return the generated continuation + a table + bar plot + status message
    """
    model_id = model_choice.strip()
    if not model_id:
        raise gr.Error("Please select a model or enter a custom model id.")
    prompt = _compose_prompt(user_prompt, assistant_prompt)
    if not prompt:
        raise gr.Error("Please enter a prompt.")

    seed = 42
    use_mock = False  # set True to skip the backend and always use mock data
    if use_mock:
        msg = "Using mock data."
        vals = _mock_routing(model_id, prompt, seed=seed)
        generation = None
    else:
        try:
            raw, generation = get_expert_routing(model_id, hf_token, prompt)  # <-- your real function
            vals = _normalize_output(raw)
            msg = "Routed with real backend."
        except Exception as e:
            # Fall back to mock on error, but surface the message.
            msg = f"Backend error: {e}\nFalling back to mock data."
            vals = _mock_routing(model_id, prompt, seed=seed)
            generation = None

    df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
    fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent")
    fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
    fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)

    status = f"Model: {model_id}<br>{msg}"
    if generation is None:
        generation = assistant_prompt
    return generation, df, fig, status
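
# A quick programmatic sanity check (assumption: not part of the Space UI).
# Because route_and_plot falls back to the mock router when the backend call
# fails, it can be exercised without a valid token, e.g.:
#
# gen, table_df, figure, status_md = route_and_plot(
#     "micro-llama-1b", hf_token="", user_prompt="What is 2 + 2?", assistant_prompt=""
# )
# print(table_df)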


with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    gr.Markdown(
        """
# 🧠 Mixture of Cognitive Reasoners (MiCRo) Expert Routing Visualizer
## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.

Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)

----

This demo visualizes how modular language models allocate computation across specialized experts (Language, Logic, Social, and World) when processing a given prompt.
Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
        """.strip()
    )

    with gr.Row():
        model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
        hf_token = gr.Textbox(label="Hugging Face token for authentication", placeholder="hf token", lines=1)
    with gr.Row():
        user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
        assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")
    # with gr.Row():
    #     use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
    #     seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")
    run = gr.Button("Run Routing", variant="primary")
    generation_output = gr.Textbox(lines=4, label="Generated continuation", placeholder="Generated text will appear here...", interactive=False)
    with gr.Row():
        table = gr.Dataframe(label="Routing Percentages", interactive=False)
        plot = gr.Plot(label="Bar Plot")
    status = gr.Markdown("")

    run.click(
        route_and_plot,
        inputs=[model_choice, hf_token, user_prompt, assistant_prompt],
        outputs=[generation_output, table, plot, status],
    )


if __name__ == "__main__":
    demo.launch()