# router_backend.py
"""
Plug your real model routing function here.

Implement the function:

    get_expert_routing(
        model_id: str,
        hf_token: str,
        prompt: str | list[dict[str, str]],
        ablations: list[str] | None = None,
    ) -> tuple[dict[str, float], str | None]

It must return the routing percentages for the four experts
["Language", "Logic", "Social", "World"] together with the generated
continuation (None when the prompt is a full chat conversation), e.g.:

    ({"Language": 12.5, "Logic": 45.0, "Social": 22.5, "World": 20.0}, "...")
"""

import pathlib
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from models.micro_olmo import MiCRoOLMo
from models.micro_llama import MiCRoLlama
from models.micro_moe_llama import MiCRoLlamaMoE

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def get_expert_routing(
    model_id: str,
    hf_token: str,
    prompt: Union[str, List[Dict[str, str]]],
    ablations: Optional[List[str]] = None,
) -> Tuple[Dict[str, float], Optional[str]]:
    model, tokenizer = build_model(model_id, hf_token, ablations=ablations)

    if isinstance(prompt, str):
        # Free-form prompt: generate a continuation and collect routing weights
        # for the generated tokens.
        generation, routing_weights = generate_continuation(model, tokenizer, prompt)
        generation = generation[0] if isinstance(generation, list) else generation
    elif isinstance(prompt, list):
        # Chat-format prompt: score the final message without generating new text.
        generation = None
        routing_weights = get_routing_weights(model, tokenizer, [prompt])
    else:
        raise TypeError(f"Unsupported prompt type: {type(prompt)}")

    # Per-expert token counts over the middle layers (the first and last two
    # layers are excluded inside aggregate_routing_weights), as percentages.
    counts = aggregate_routing_weights(routing_weights)[0]
    model_routing_percentages = 100.0 * counts / max(int(counts.sum()), 1)
    print(model_routing_percentages)

    if generation is not None:
        print(f"Generation:\n{generation}")

    # aggregate_routing_weights orders experts as [Logic, Social, World, Language].
    return {
        "Language": float(model_routing_percentages[3]),
        "Logic": float(model_routing_percentages[0]),
        "Social": float(model_routing_percentages[1]),
        "World": float(model_routing_percentages[2]),
    }, generation
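
# Example of the chat-format branch above (a hypothetical conversation; the last
# message is the response whose routing is scored, and generation is None):
#
#   conversation = [
#       {"role": "user", "content": "What is 17 * 24?"},
#       {"role": "assistant", "content": "17 * 24 = 408."},
#   ]
#   routing, generation = get_expert_routing("micro-llama-1b", hf_token, conversation)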

def get_model_path(model_name: str) -> Tuple[str, str, type]:
    """Map a short model id to (checkpoint path, base model id, model class)."""
    return {
        # MiCRo-Llama
        "micro-llama-1b": ("bkhmsi/micro-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
        "micro-llama-3b": ("bkhmsi/micro-llama-3b", "meta-llama/Llama-3.2-3B-Instruct", MiCRoLlama),
        "micro-llama-1b-dpo": ("bkhmsi/micro-llama-1b-dpo", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
        # MiCRo-MoE-Llama
        "micro-moe-llama-1b": ("bkhmsi/micro-moe-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlamaMoE),
        # MiCRo-OLMo
        "micro-olmo": ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo),
        # MiCRo-SmolLM2
        "micro-smollm2-135m": ("bkhmsi/micro-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlama),
        "micro-smollm2-360m": ("bkhmsi/micro-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlama),
        # MiCRo-MoE-SmolLM2
        "micro-moe-smollm2-135m": ("bkhmsi/micro-moe-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlamaMoE),
        "micro-moe-smollm2-360m": ("bkhmsi/micro-moe-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlamaMoE),
        # Unknown ids fall back to loading model_name directly with AutoModelForCausalLM.
    }.get(model_name, (model_name, model_name, AutoModelForCausalLM))

def aggregate_routing_weights(routing_weights):
    """Count, per expert, how often it receives the highest routing weight.

    routing_weights has shape (num_layers, num_tokens, num_experts), with the
    expert axis ordered as [Logic, Social, World, Language].
    """
    experts = ["Logic", "Social", "World", "Language"]
    expert_token_model = np.zeros(len(experts), dtype=int)
    expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)
    num_layers = routing_weights.shape[0]
    for layer_idx in range(num_layers):
        for token_idx in range(len(routing_weights[layer_idx])):
            expert_idx = routing_weights[layer_idx][token_idx].argmax()
            # Model-level counts skip the first and last two layers;
            # the per-layer table keeps every layer.
            if 2 <= layer_idx < num_layers - 2:
                expert_token_model[expert_idx] += 1
            expert_layer_token[layer_idx][expert_idx] += 1
    return expert_token_model, expert_layer_token

def generate_continuation(model,
                          tokenizer,
                          prompts,
                          max_tokens=128,
                          use_cache=True,
                          return_routing_weights=True
                          ):
    if isinstance(prompts, str):
        prompts = [{"role": "user", "content": prompts}]

    tokenizer.padding_side = "left"
    inputs = tokenizer.apply_chat_template(
        prompts, return_tensors="pt", padding=True, add_generation_prompt=True
    ).to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0

    outputs = model.generate(
        input_ids=inputs,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        use_cache=use_cache,
        stop_strings=["</s>", "<|eot_id|>", "<|im_start|>user", "user"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.pad_token_id,
        temperature=0,
        top_p=1.0,
        do_sample=False,  # greedy decoding
    )

    if return_routing_weights:
        # Re-run the full sequence (prompt + generation) through the model to
        # recover routing weights, then keep only the generated positions.
        attention_mask = torch.ones_like(outputs)
        attention_mask[outputs == tokenizer.pad_token_id] = 0
        model_output = model(input_ids=outputs, attention_mask=attention_mask)
        torch.cuda.empty_cache()

        routing_weights = model_output.routing_weights
        routing_weights = np.concatenate([
            F.softmax(rw, dim=-1)[:, inputs.shape[1]:].detach().float().cpu().numpy()
            for rw in routing_weights
        ])
    else:
        routing_weights = None

    inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=False)

    generations = []
    for i, output in enumerate(outputs):
        # Decode, then strip the prompt text and any special/stop tokens.
        decoded_output = tokenizer.decode(output, skip_special_tokens=False)
        decoded_output = decoded_output.replace(inputs_text[i], "")
        decoded_output = decoded_output.replace(tokenizer.pad_token, "").strip()
        decoded_output = decoded_output.replace("<|end_of_text|>", "").strip()
        decoded_output = decoded_output.replace("<|endoftext|>", "").strip()
        decoded_output = decoded_output.replace("<|eot_id|>", "").strip()
        decoded_output = decoded_output.replace("\n<|im_start|>user", "").strip()
        generations.append(decoded_output)

    return (generations, routing_weights) if return_routing_weights else generations

def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
    """
    Get routing weights for the given prompts using the model.

    Args:
        model: The MiCRoLlama or MiCRoOLMo model.
        tokenizer: The tokenizer for the model.
        prompts: When apply_chat_template is True, a string or a list of
            conversations (each a list of {"role", "content"} dicts) whose last
            message is the response to score. Otherwise a (prompt, response)
            pair of raw strings.

    Returns:
        routing_weights: Routing weights for each layer, restricted to the
            response tokens.
    """
    tokenizer.padding_side = "left"

    if apply_chat_template:
        if isinstance(prompts, str):
            # Wrap a bare string as a batch containing one single-turn conversation.
            prompts = [[{"role": "user", "content": prompts}]]
        inputs = tokenizer.apply_chat_template(
            prompts, return_tensors="pt", padding=True
        ).to(DEVICE)
        # Re-tokenize without the final message to locate where the response starts.
        input_without_response = tokenizer.apply_chat_template(
            [prompt[:-1] for prompt in prompts], return_tensors="pt", padding=True
        ).to(DEVICE)
    else:
        inputs = tokenizer(prompts[0] + prompts[1], return_tensors="pt", padding=True).input_ids.to(DEVICE)
        input_without_response = tokenizer(prompts[0], return_tensors="pt", padding=True).input_ids.to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0
    model_output = model(input_ids=inputs, attention_mask=attention_mask)

    routing_weights = model_output.routing_weights
    routing_weights = np.stack(
        [F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights], axis=0
    ).squeeze()

    # Keep only the response tokens (drop the prompt prefix and the final position).
    offset = len(input_without_response[0]) - 1
    routing_weights = routing_weights[:, offset:-1]
    return routing_weights
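
# Sketch of the apply_chat_template=False path (assumes `model` and `tokenizer`
# come from build_model; the strings are illustrative, not from the original):
#
#   rw = get_routing_weights(
#       model, tokenizer,
#       ("The capital of France is", " Paris."),
#       apply_chat_template=False,
#   )
#   # rw[layer, token, expert]: softmaxed routing weights over the response tokens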

def build_model(model_id: str, hf_token: str, ablations: Optional[List[str]], use_cache: bool = True):
    model_path, base_model, model_class = get_model_path(model_id)

    model_config = AutoConfig.from_pretrained(base_model, token=hf_token)
    parent_path = pathlib.Path(__file__).parent
    model_config.config_path = f"{parent_path}/configs/{model_id.replace('-', '_')}.yml"
    model_config.torch_dtype = torch.bfloat16
    model_config.use_bfloat16 = True
    model_config._attn_implementation = "eager"  # {sdpa, flash_attention_2, eager}
    model_config.use_cache = use_cache
    model_config.ablate = ablations

    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
    tokenizer.padding_side = "left"

    # Padding token ids differ per base family (Llama, OLMo, SmolLM2).
    if "olmo" in model_id:
        tokenizer.pad_token_id = 100277
        tokenizer.add_special_tokens({'additional_special_tokens': ['<|assistant|>']})
        # The added special token changes the vocabulary size.
        model_config.vocab_size = len(tokenizer)
    elif "smollm2" in model_id:
        tokenizer.pad_token_id = 2
    else:
        # Llama family and any unrecognized id.
        tokenizer.pad_token_id = 128004

    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)
    model.to(DEVICE)
    model = model.bfloat16()
    model.eval()
    return model, tokenizer
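

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original backend): assumes a
    # Hugging Face token is available in the HF_TOKEN environment variable and
    # that the "micro-llama-1b" checkpoint is accessible.
    import os

    routing, generation = get_expert_routing(
        model_id="micro-llama-1b",
        hf_token=os.environ.get("HF_TOKEN", ""),
        prompt="Briefly explain why the sky is blue.",
    )
    print("Routing percentages:", routing)
    if generation is not None:
        print("Generation:", generation)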