# app.py
"""
Hugging Face Space: MoE Expert Routing Visualizer (Gradio)
----------------------------------------------------------
This Space lets a user:
- Choose a model (from a dropdown or a free-text box)
- Enter a user prompt, and optionally an assistant prompt
- Call a backend function that returns 4 routing percentages (Language, Logic, Social, World)
- See a bar plot + table of the percentages
🧩 Plug your real routing function in router_backend.py -> get_expert_routing().
By default, a deterministic "mock mode" produces stable pseudo-random percentages from the prompt.
"""
import hashlib
from typing import Dict, List, Tuple, Union
import gradio as gr
import plotly.graph_objects as go  # public Figure type, used in annotations
import plotly.express as px
import pandas as pd
from router_backend import get_expert_routing
# ---- Expected backend adapter ------------------------------------------------
# Implement your real function in router_backend.py with the following signature:
#   def get_expert_routing(
#       model_id: str,
#       hf_token: str,
#       prompt: Union[str, List[Dict[str, str]]],
#   ) -> Tuple[Union[List[float], Dict[str, float]], str]
# It MUST return a pair (routing, generation): `routing` holds 4 values that
# sum to ~100 (percentages) in the fixed order
#   ["Language", "Logic", "Social", "World"]
# (or a mapping with those keys), and `generation` is the generated text.
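#
# A minimal sketch of what router_backend.py could look like, assuming a MiCRo
# checkpoint whose router exposes per-token expert weights. `load_model` and
# `generate_with_routing` are hypothetical names used only for illustration:
#
#   def get_expert_routing(model_id, hf_token, prompt):
#       model, tokenizer = load_model(model_id, hf_token)                  # hypothetical helper
#       generation, expert_weights = model.generate_with_routing(prompt)   # hypothetical API
#       totals = expert_weights.sum(axis=0)                                # aggregate over tokens
#       experts = ["Language", "Logic", "Social", "World"]
#       routing = {k: 100.0 * v / totals.sum() for k, v in zip(experts, totals)}
#       return routing, generation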
EXPERTS = ["Language", "Logic", "Social", "World"]
DEFAULT_MODELS = [
"micro-llama-1b",
"micro-llama-3b",
"micro-llama-1b-dpo",
"micro-moe-llama-1b",
"micro-smollm2-135m",
"micro-smollm2-360m",
"micro-moe-smollm2-135m",
"micro-moe-smollm2-360m",
]
def _mock_routing(model_id: str, prompt: Union[str, List[Dict[str, str]]], seed: int = 0) -> List[float]:
    """
    Deterministic mock routing percentages based on model_id + prompt + seed.
    Returns a list of 4 percentages summing to 100.0.
    """
    # Hash the inputs; f-string formatting handles both plain strings and chat lists.
    h = hashlib.sha256(f"{model_id}||{prompt}||{seed}".encode()).digest()
    # Split the 32-byte digest into four 8-byte chunks -> four positive integers
    vals = [int.from_bytes(h[i*8:(i+1)*8], "little") % 10_000 + 1 for i in range(4)]
s = sum(vals)
return [100.0 * v / s for v in vals]
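# Example (sketch): the mock is pure, so identical inputs give identical output,
# and the four percentages always sum to 100:
#   vals = _mock_routing("micro-llama-1b", "2 + 2 = ?", seed=0)
#   assert len(vals) == 4 and abs(sum(vals) - 100.0) < 1e-9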
def _normalize_output(r: Union[List[float], Tuple[float, float, float, float], Dict[str, float]]) -> List[float]:
"""
Normalize different return types into a 4-length list ordered as EXPERTS.
"""
if isinstance(r, dict):
vals = [float(r.get(k, 0.0)) for k in EXPERTS]
else:
vals = [float(x) for x in list(r)]
if len(vals) != 4:
raise ValueError(f"Expected 4 values, got {len(vals)}.")
# renormalize to 100 if needed
s = sum(vals)
if s <= 0:
raise ValueError("Sum of routing percentages is non-positive.")
vals = [100.0 * v / s for v in vals]
return vals
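# Example (sketch): every accepted shape normalizes to the same ordered list,
# rescaled so the values sum to 100:
#   _normalize_output([1, 1, 1, 1])                  # -> [25.0, 25.0, 25.0, 25.0]
#   _normalize_output({"Language": 2, "Logic": 1,
#                      "Social": 1, "World": 0})     # -> [50.0, 25.0, 25.0, 0.0]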
def _compose_prompt(user_prompt: str, assistant_prompt: str) -> Union[str, List[Dict[str, str]]]:
    """
    Return the bare user prompt, or a chat-format message list when an
    assistant prompt is also provided.
    """
user_prompt = (user_prompt or "").strip()
assistant_prompt = (assistant_prompt or "").strip()
if assistant_prompt:
return [{"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt}]
return user_prompt
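# Example (sketch): with only a user message the raw string passes through;
# adding an assistant message switches to chat format:
#   _compose_prompt("Hi", "")       # -> "Hi"
#   _compose_prompt("Hi", "Hello")  # -> [{"role": "user", "content": "Hi"},
#                                   #     {"role": "assistant", "content": "Hello"}]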
def route_and_plot(model_choice: str, hf_token: str, user_prompt: str, assistant_prompt: str) -> Tuple[str, pd.DataFrame, go.Figure, str]:
"""
Main pipeline:
- Compose prompt (user + optional assistant)
- Call backend (real or mock)
- Return a table + bar plot + status message
"""
    model_id = model_choice.strip()
    if not model_id:
        raise gr.Error("Please select a model.")
prompt = _compose_prompt(user_prompt, assistant_prompt)
if not prompt:
raise gr.Error("Please enter a prompt.")
    seed = 42
    use_mock = False  # set True (or re-enable the commented-out checkbox below) to force mock data
if use_mock:
msg = "Using mock data."
vals = _mock_routing(model_id, prompt, seed=seed)
generation = None
else:
try:
raw, generation = get_expert_routing(model_id, hf_token, prompt) # <-- your real function
vals = _normalize_output(raw)
msg = "Routed with real backend."
except Exception as e:
            # Fall back to mock on error, but surface the message in the status box
            msg = f"Backend error: {e}<br>Falling back to mock data."
vals = _mock_routing(model_id, prompt, seed=seed)
generation = None
df = pd.DataFrame({"Expert": EXPERTS, "Percent": vals})
fig = px.bar(df, x="Expert", y="Percent", title="Token Routing by Expert (%)", text="Percent")
fig.update_traces(texttemplate="%{text:.2f}%", textposition="outside")
fig.update_layout(yaxis_range=[0, max(100, max(vals) * 1.25)], bargap=0.35)
status = f"Model: {model_id}<br>{msg}"
    if generation is None:
        # The mock path produces no text; echo the assistant prompt (possibly empty)
        generation = assistant_prompt
return generation, df, fig, status
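# Example (sketch): the pipeline can be exercised without the UI. With an empty
# token the backend will likely fail, so the mock fallback kicks in:
#   text, df, fig, status = route_and_plot(DEFAULT_MODELS[0], "", "Why is the sky blue?", "")
#   print(df)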
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
gr.Markdown(
"""
# 🧠 Mixture of Cognitive Reasoner (MiCRo) Expert Routing Visualizer
## Enter a prompt (and optionally an assistant reply), pick a model, and visualize how tokens were routed across experts.
Paper: [Mixture of Cognitive Reasoners: Modular Reasoning with Brain-Like Specialization](https://arxiv.org/abs/2506.13331)
----
This demo visualizes how modular language models allocate computation across specialized experts—Language, Logic, Social, and World—when processing a given prompt.
Each expert corresponds to a cognitive domain inspired by human brain networks. Enter a prompt to see how tokens are dynamically routed across modules, revealing the model's internal reasoning structure.
""".strip()
)
with gr.Row():
model_choice = gr.Dropdown(choices=DEFAULT_MODELS, label="Select a model", value=DEFAULT_MODELS[0])
        hf_token = gr.Textbox(label="Hugging Face token for authentication", placeholder="hf_...", lines=1, type="password")
with gr.Row():
user_prompt = gr.Textbox(lines=6, label="User prompt", placeholder="Type the user message here...")
assistant_prompt = gr.Textbox(lines=6, label="Assistant prompt (optional)", placeholder="Type the assistant message here (optional)...")
# with gr.Row():
# use_mock = gr.Checkbox(value=True, label="Use mock data (uncheck to call your backend)")
# seed = gr.Slider(value=0, minimum=0, maximum=10_000, step=1, label="Mock seed")
run = gr.Button("Run Routing", variant="primary")
generation_output = gr.Textbox(lines=4, label="Generated continuation", placeholder="Generated text will appear here...", interactive=False)
with gr.Row():
table = gr.Dataframe(label="Routing Percentages", interactive=False)
plot = gr.Plot(label="Bar Plot")
status = gr.Markdown("")
run.click(
route_and_plot,
inputs=[model_choice, hf_token, user_prompt, assistant_prompt],
outputs=[generation_output, table, plot, status],
)
if __name__ == "__main__":
demo.launch()
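    # For a public demo, Gradio also supports request queuing and share links,
    # e.g. demo.queue().launch(share=True) (a sketch; tune to your deployment).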