# router_backend.py
"""
Plug your real model routing function here.
Implement the function:
get_expert_routing(model_id: str, prompt: str) -> list[float] | dict[str, float] | tuple[float, float, float, float]
It must return 4 values (percentages) corresponding to the experts:
["Language", "Logic", "Social", "World"]
Example return formats:
- [12.5, 45.0, 22.5, 20.0]
- {"Language": 12.5, "Logic": 45.0, "Social": 22.5, "World": 20.0}
- (12.5, 45.0, 22.5, 20.0)
"""
import torch
import pathlib
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from typing import Union, Dict, List, Tuple, Optional
from models.micro_olmo import MiCRoOLMo
from models.micro_llama import MiCRoLlama
from models.micro_moe_llama import MiCRoLlamaMoE
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_expert_routing(
    model_id: str,
    hf_token: str,
    prompt: Union[str, List[Dict[str, str]]],
    ablations: Optional[List[str]] = None,
) -> Tuple[Dict[str, float], Optional[str]]:
model, tokenizer = build_model(model_id, hf_token, ablations=ablations)
    if isinstance(prompt, str):
        # Plain prompt: generate a continuation and collect routing weights for it.
        generation, routing_weights = generate_continuation(model, tokenizer, prompt)
        generation = generation[0] if isinstance(generation, list) else generation
    elif isinstance(prompt, list):
        # Full chat transcript: nothing is generated; score the existing response.
        generation = None
        routing_weights = get_routing_weights(model, tokenizer, [prompt])
    else:
        raise TypeError(f"prompt must be a string or a list of chat messages, got {type(prompt)}")
    # Per-expert counts of tokens (over the middle layers) whose strongest routing
    # weight selected that expert; the caller can normalize these into percentages.
    expert_token_counts = aggregate_routing_weights(routing_weights)[0]
    print(expert_token_counts)
if generation is not None:
print(f"Generation:\n{generation}")
    return {
        "Language": float(expert_token_counts[3]),
        "Logic": float(expert_token_counts[0]),
        "Social": float(expert_token_counts[1]),
        "World": float(expert_token_counts[2]),
    }, generation
def get_model_path(model_name: str) -> Tuple[str, str, type]:
    # Map a short model id to (checkpoint repo, base model for tokenizer/config, model class).
    return {
# MiCRo-Llama
"micro-llama-1b": ("bkhmsi/micro-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
"micro-llama-3b": ("bkhmsi/micro-llama-3b", "meta-llama/Llama-3.2-3B-Instruct", MiCRoLlama),
"micro-llama-1b-dpo": ("bkhmsi/micro-llama-1b-dpo", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
# MiCRo-MoE-Llama
"micro-moe-llama-1b": ("bkhmsi/micro-moe-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlamaMoE),
# MiCRo-OLMo
"micro-olmo": ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo),
# MiCRo-SmolLM2
"micro-smollm2-135m": ("bkhmsi/micro-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlama),
"micro-smollm2-360m": ("bkhmsi/micro-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlama),
# MiCRo-MoE-SmolLM2
"micro-moe-smollm2-135m": ("bkhmsi/micro-moe-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlamaMoE),
"micro-moe-smollm2-360m": ("bkhmsi/micro-moe-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlamaMoE),
    }.get(model_name, (model_name, model_name, AutoModelForCausalLM))  # unknown ids: load model_name directly as a vanilla HF model
def aggregate_routing_weights(routing_weights):
    """
    Count, for each expert, how many tokens routed to it most strongly.

    routing_weights has shape [num_layers, num_tokens, num_experts], with experts
    ordered as [Logic, Social, World, Language]. Returns a pair: per-expert counts
    aggregated over the middle layers (the first two and last two layers are
    excluded), and per-layer, per-expert counts over all layers.
    """
    experts = ["Logic", "Social", "World", "Language"]
    expert_token_model = np.zeros(len(experts), dtype=int)
    expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)
    num_layers = routing_weights.shape[0]
    for layer_idx in range(num_layers):
        for token_idx in range(len(routing_weights[layer_idx])):
            expert_idx = routing_weights[layer_idx][token_idx].argmax()
            if 2 <= layer_idx < num_layers - 2:
                expert_token_model[expert_idx] += 1
            expert_layer_token[layer_idx][expert_idx] += 1
    return expert_token_model, expert_layer_token
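
# The first array returned above holds raw argmax counts; a caller that wants
# percentages (as described in the module docstring) can normalize them, e.g.
# (a sketch):
#
#     counts, _ = aggregate_routing_weights(routing_weights)
#     percentages = 100.0 * counts / max(counts.sum(), 1)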
def generate_continuation(
    model,
    tokenizer,
    prompts,
    max_tokens=128,
    use_cache=True,
    return_routing_weights=True,
):
    """
    Greedily generate a continuation for the prompt(s) and, optionally, return
    the per-layer routing weights for the newly generated tokens.
    """
if isinstance(prompts, str):
prompts = [{"role": "user", "content": prompts}]
tokenizer.padding_side = "left"
inputs = tokenizer.apply_chat_template([
prompt for prompt in prompts
], return_tensors="pt", padding=True, add_generation_prompt=True).to(DEVICE)
attention_mask = torch.ones_like(inputs)
attention_mask[inputs == tokenizer.pad_token_id] = 0
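    # Greedy decoding: with do_sample=False, the temperature and top_p settings have no effect.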
outputs = model.generate(
input_ids=inputs,
attention_mask=attention_mask,
max_new_tokens=max_tokens,
use_cache=use_cache,
stop_strings=["</s>","<|eot_id|>", "<|im_start|>user", "user"],
tokenizer=tokenizer,
pad_token_id=tokenizer.pad_token_id,
temperature=0,
top_p=1.0,
do_sample=False,
)
if return_routing_weights:
attention_mask = torch.ones_like(outputs)
attention_mask[outputs == tokenizer.pad_token_id] = 0
model_output = model(input_ids=outputs, attention_mask=attention_mask)
torch.cuda.empty_cache()
routing_weights = model_output.routing_weights
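        # Softmax each layer's routing logits and keep only the newly generated
        # positions (everything past the prompt length).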
routing_weights = np.concatenate([
F.softmax(rw, dim=-1)[:, inputs.shape[1]:].detach().float().cpu().numpy()
for rw in routing_weights
])
else:
routing_weights = None
inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=False)
generations = []
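    # Strip the echoed prompt and any padding / EOS / chat-turn markers from each output.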
for i, output in enumerate(outputs):
decoded_output = tokenizer.decode(output, skip_special_tokens=False)
decoded_output = decoded_output.replace(inputs_text[i], "")
decoded_output = decoded_output.replace(tokenizer.pad_token, "").strip()
decoded_output = decoded_output.replace("<|end_of_text|>", "").strip()
decoded_output = decoded_output.replace("<|endoftext|>", "").strip()
decoded_output = decoded_output.replace("<|eot_id|>", "").strip()
decoded_output = decoded_output.replace("\n<|im_start|>user", "").strip()
generations.append(decoded_output)
return (generations, routing_weights) if return_routing_weights else generations
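
# Illustrative usage (a sketch; the model id is one of the keys in get_model_path
# and the token value is a placeholder):
#
#     model, tokenizer = build_model("micro-llama-1b", hf_token="hf_...", ablations=None)
#     texts, weights = generate_continuation(model, tokenizer, "Tell me a short story.")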
def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
"""
Get routing weights for the given prompts using the model.
Args:
model: The MiCRoLlama or MiCRoOLMo model.
tokenizer: The tokenizer for the model.
prompts: A string or list of dictionaries containing the prompts.
Returns:
routing_weights: A list of routing weights for each layer.
"""
tokenizer.padding_side = "left"
    if apply_chat_template:
        if isinstance(prompts, str):
            # Wrap a bare string as a list containing a single one-message conversation.
            prompts = [[{"role": "user", "content": prompts}]]
        inputs = tokenizer.apply_chat_template(
            [prompt for prompt in prompts],
            return_tensors="pt", padding=True,
        ).to(DEVICE)
        input_without_response = tokenizer.apply_chat_template(
            [prompt[:-1] for prompt in prompts],
            return_tensors="pt", padding=True,
        ).to(DEVICE)
else:
inputs = tokenizer(prompts[0] + prompts[1], return_tensors="pt", padding=True).input_ids.to(DEVICE)
input_without_response = tokenizer(prompts[0], return_tensors="pt", padding=True).input_ids.to(DEVICE)
attention_mask = torch.ones_like(inputs)
attention_mask[inputs == tokenizer.pad_token_id] = 0
model_output = model(input_ids=inputs, attention_mask=attention_mask)
routing_weights = model_output.routing_weights
    routing_weights = np.stack(
        [F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights],
        axis=0,
    ).squeeze()
    # Keep only the positions corresponding to the response tokens.
    offset = len(input_without_response[0]) - 1
    routing_weights = routing_weights[:, offset:-1]
return routing_weights
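
# Illustrative call without the chat template (a sketch): prompts is treated as a
# (context, response) pair, and the returned weights cover the response tokens:
#
#     weights = get_routing_weights(model, tokenizer,
#                                   ["The capital of France is", " Paris."],
#                                   apply_chat_template=False)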
def build_model(model_id: str, hf_token: str, ablations: Optional[List[str]] = None, use_cache: bool = True):
    """Load a MiCRo checkpoint plus its tokenizer, configured for inference."""
    model_path, base_model, model_class = get_model_path(model_id)
    model_config = AutoConfig.from_pretrained(base_model, token=hf_token)
parent_path = pathlib.Path(__file__).parent
model_config.config_path = f"{parent_path}/configs/{model_id.replace('-', '_')}.yml"
model_config.torch_dtype = torch.bfloat16
model_config.use_bfloat16 = True
model_config._attn_implementation = "eager" # {sdpa, flash_attention_2, eager}
model_config.use_cache = use_cache
model_config.ablate = ablations
    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
tokenizer.padding_side = "left"
if "llama" in model_id:
tokenizer.pad_token_id = 128004
if "olmo" in model_id:
tokenizer.pad_token_id = 100277
tokenizer.add_special_tokens({'additional_special_tokens': ['<|assistant|>']})
elif "smollm2" in model_id:
tokenizer.pad_token_id = 2
else:
tokenizer.pad_token_id = 128004
if "olmo" in model_id:
model_config.vocab_size = len(tokenizer)
model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)
model.to(DEVICE)
model = model.bfloat16()
model.eval()
return model, tokenizer
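
if __name__ == "__main__":
    # Minimal local smoke test (a sketch): the model id and the HF_TOKEN environment
    # variable are assumptions; running this downloads the checkpoint.
    import os

    routing, generation = get_expert_routing(
        "micro-llama-1b",
        hf_token=os.environ.get("HF_TOKEN"),
        prompt="What is the capital of France?",
    )
    print(routing)
    if generation is not None:
        print(generation)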