# app.py
"""
Hugging Face Space: MoE Expert Routing Visualizer (Gradio)
----------------------------------------------------------
This Space lets a user:
- Choose a model (from a dropdown or a free-text box)
- Enter a user prompt, and optionally an assistant prompt
- Call a backend function that returns 4 routing percentages (Language, Logic, Social, World)
- See a bar plot + table of the percentages
🧩 Plug your real routing function in router_backend.py -> get_expert_routing().
By default, a deterministic "mock mode" produces stable pseudo-random percentages from the prompt.
"""
import hashlib
from typing import Dict, List, Tuple, Union
import gradio as gr
import plotly
import plotly.express as px
import pandas as pd
from router_backend import get_expert_routing

# ---- Expected backend adapter ------------------------------------------------
# Implement your real function in router_backend.py with the following signature:
#   def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[dict]])
#       -> Tuple[Union[List[float], Dict[str, float], Tuple[float, ...]], Optional[str]]
# The first return value MUST contain 4 values that sum to ~100 (percentages) in the
# fixed order ["Language", "Logic", "Social", "World"], or a mapping with those keys.
# The second return value is the generated continuation (or None).
# try:
# from router_backend import get_expert_routing # your real backend
# BACKEND_AVAILABLE = True
# except Exception as e: # keep error for display if needed
# BACKEND_AVAILABLE = False
# _backend_import_error = e
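
# A minimal illustrative sketch of router_backend.py (an assumption, not the real
# implementation -- the actual backend must load the model and aggregate its router
# weights per expert):
#
#   # router_backend.py (hypothetical)
#   from typing import Dict, List, Optional, Tuple, Union
#
#   def get_expert_routing(model_id: str, hf_token: str,
#                          prompt: Union[str, List[dict]]) -> Tuple[Dict[str, float], Optional[str]]:
#       # Placeholder: return uniform routing and no generated continuation.
#       routing = {"Language": 25.0, "Logic": 25.0, "Social": 25.0, "World": 25.0}
#       generation = None
#       return routing, generation
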
EXPERTS = ["Language", "Logic", "Social", "World"]

DEFAULT_MODELS = [
    "micro-llama-1b",
    "micro-llama-3b",
    "micro-llama-1b-dpo",
    "micro-moe-llama-1b",
    "micro-smollm2-135m",
    "micro-smollm2-360m",
    "micro-moe-smollm2-135m",
    "micro-moe-smollm2-360m",
]


def _mock_routing(model_id: str, prompt: str, seed: int = 0) -> List[float]:
    """
    Deterministic mock routing percentages based on model_id + prompt + seed.
    Returns a list of 4 percentages summing to 100.0.
    """
    h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest()
    # Split the digest into 4 positive numbers.
    vals = [int.from_bytes(h[i * 8:(i + 1) * 8], "little") % 10_000 + 1 for i in range(4)]
    s = sum(vals)
    return [100.0 * v / s for v in vals]
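
# Example (values are illustrative; the exact numbers depend on the SHA-256 digest,
# but are deterministic for a fixed (model_id, prompt, seed) triple):
#   _mock_routing("micro-llama-1b", "Hello world")  # -> e.g. [31.7, 22.4, 18.9, 27.0]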


def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]:
    """
    Normalize different return types into a 4-length list ordered as EXPERTS.
    """
    if isinstance(r, dict):
        vals = [float(r.get(k, 0.0)) for k in EXPERTS]
    else:
        vals = [float(x) for x in list(r)]
    if len(vals) != 4:
        raise ValueError(f"Expected 4 values, got {len(vals)}.")
    # Renormalize to 100 if needed.
    s = sum(vals)
    if s <= 0:
        raise ValueError("Sum of routing percentages is non-positive.")
    vals = [100.0 * v / s for v in vals]
    return vals
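
# Example: both calls below normalize to [50.0, 25.0, 12.5, 12.5]:
#   _normalize_output({"Language": 4, "Logic": 2, "Social": 1, "World": 1})
#   _normalize_output([4, 2, 1, 1])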


def _compose_prompt(user_prompt: str, assistant_prompt: str) -> Union[str, List[Dict[str, str]]]:
    """Return the bare user prompt, or a chat-style message list if an assistant prompt is given."""
    user_prompt = (user_prompt or "").strip()
    assistant_prompt = (assistant_prompt or "").strip()
    if assistant_prompt:
        return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
    return user_prompt
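
# Example:
#   _compose_prompt("Hi", "")       -> "Hi"
#   _compose_prompt("Hi", "Hello!") -> [{"role": "user", "content": "Hi"},
#                                       {"role": "assistant", "content": "Hello!"}]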


def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant_prompt: str) -> Tuple[str, pd.DataFrame, "plotly.graph_objects.Figure", str]:
    """
    Main pipeline:
    - Compose prompt (user + optional assistant)
    - Call the backend (real, with mock fallback on error)
    - Return (generation, table, figure, status), matching the Gradio outputs order
    """
    model_id = model_choice.strip()
    if not model_id:
        raise gr.Error("Please select a model.")
    prompt = _compose_prompt(user_prompt, assistant_prompt)
    if not prompt:
        raise gr.Error("Please enter a prompt.")
    seed = 42
    use_mock = False  # flip to True to skip the backend entirely
    if use_mock:
        msg = "Using mock data."
        vals = _mock_routing(model_id, prompt, seed=seed)
        generation = None
    else:
        try:
            raw, generation = get_expert_routing(model_id, hf_token, prompt)  # <-- your real function
            vals = _normalize_output(raw)
            msg = "Routed with real backend."
        except Exception as e:
            # Fall back to mock data on error, but surface the message.
            msg = f"Backend error: {e}\nFalling back to mock data."
            vals = _mock_routing(model_id, prompt, seed=seed)
            generation = None
    df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
    fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent")
    fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
    fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)
    status = f"Model: {model_id}<br>{msg}"
    if generation is None:
        generation = assistant_prompt
    return generation, df, fig, status


with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    gr.Markdown(
        """
# 🧠 Mixture of Cognitive Reasoners (MiCRo) Expert Routing Visualizer
## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.

Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)

----

This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
        """.strip()
    )
    with gr.Row():
        model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
        hf_token = gr.Textbox(label="Hugging Face token for authentication", placeholder="hf token", lines=1)

    with gr.Row():
        user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
        assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")

    # with gr.Row():
    #     use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
    #     seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")

    run = gr.Button("Run Routing", variant="primary")

    generation_output = gr.Textbox(lines=4, label="Generated continuation", placeholder="Generated text will appear here...", interactive=False)

    with gr.Row():
        table = gr.Dataframe(label="Routing Percentages", interactive=False)
        plot = gr.Plot(label="Bar Plot")

    status = gr.Markdown("")

    run.click(
        route_and_plot,
        inputs=[model_choice, hf_token, user_prompt, assistant_prompt],
        outputs=[generation_output, table, plot, status],
    )


if __name__ == "__main__":
    demo.launch()
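    # Note: when running locally (outside a Space), demo.launch(share=True) can also
    # expose a temporary public Gradio link.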