# router_backend.py
"""
Plug your real model routing function here.

Implement the function:
    get_expert_routing(model_id: str, prompt: str) -> list[float] | dict[str, float] | tuple[float, float, float, float]

It must return 4 values (percentages) corresponding to the experts:
["Language", "Logic", "Social", "World"]

Example return formats:
- [12.5, 45.0, 22.5, 20.0]
- {"Language": 12.5, "Logic": 45.0, "Social": 22.5, "World": 20.0}
- (12.5, 45.0, 22.5, 20.0)
"""
import torch
import pathlib
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from typing import Union, Dict, List, Tuple, Optional

from models.micro_olmo import MiCRoOLMo
from models.micro_llama import MiCRoLlama
from models.micro_moe_llama import MiCRoLlamaMoE

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_expert_routing(model_id: str, hf_token: str, prompt: Union[str, List[Dict[str, str]]], ablations: Optional[List[str]] = None) -> Tuple[Dict[str, float], Optional[str]]:

    model, tokenizer = build_model(model_id, hf_token, ablations=ablations)

    if isinstance(prompt, str):
        generation, routing_weights = generate_continuation(model, tokenizer, prompt)
        generation = generation[0] if type(generation) is list else generation
    elif isinstance(prompt, list):
        generation = None
        routing_weights = get_routing_weights(model, tokenizer, [prompt])

    # aggregate_routing_weights returns raw per-token expert-selection counts
    # (expert order: Logic, Social, World, Language); normalize them into
    # percentages, as documented in the module docstring.
    expert_counts = aggregate_routing_weights(routing_weights)[0]
    model_routing_percentages = 100.0 * expert_counts / max(expert_counts.sum(), 1)
    print(model_routing_percentages)

    if generation is not None:
        print(f"Generation:\n{generation}")
    
    return {
        "Language": float(model_routing_percentages[3]),
        "Logic": float(model_routing_percentages[0]),
        "Social": float(model_routing_percentages[1]),
        "World": float(model_routing_percentages[2]),
    }, generation

def get_model_path(model_name: str) -> Tuple[str, str, type]:
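    """Map a short model id to (HF checkpoint path, base model id, model class)."""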
    return {
        # MiCRo-Llama
        "micro-llama-1b": ("bkhmsi/micro-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),
        "micro-llama-3b": ("bkhmsi/micro-llama-3b", "meta-llama/Llama-3.2-3B-Instruct", MiCRoLlama),
        "micro-llama-1b-dpo": ("bkhmsi/micro-llama-1b-dpo", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlama),

        # MiCRo-MoE-Llama
        "micro-moe-llama-1b": ("bkhmsi/micro-moe-llama-1b", "meta-llama/Llama-3.2-1B-Instruct", MiCRoLlamaMoE),
        
        # MiCRo-OLMo
        "micro-olmo": ("bkhmsi/micro-olmo-1b", "allenai/OLMo-2-0425-1B-Instruct", MiCRoOLMo),

        # MiCRo-SmolLM2
        "micro-smollm2-135m": ("bkhmsi/micro-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlama),
        "micro-smollm2-360m": ("bkhmsi/micro-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlama),

        # MiCRo-MoE-SmolLM2
        "micro-moe-smollm2-135m": ("bkhmsi/micro-moe-smollm2-135m", "HuggingFaceTB/SmolLM2-135M-Instruct", MiCRoLlamaMoE),
        "micro-moe-smollm2-360m": ("bkhmsi/micro-moe-smollm2-360m", "HuggingFaceTB/SmolLM2-360M-Instruct", MiCRoLlamaMoE),
    }.get(model_name, (model_name, model_name, AutoModelForCausalLM))

def aggregate_routing_weights(routing_weights):
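    """
    Count, per expert, how many (layer, token) positions select that expert via argmax.

    Returns:
        expert_token_model: per-expert counts aggregated over the middle layers
            (the first two and the last two layers are excluded),
            expert order [Logic, Social, World, Language].
        expert_layer_token: per-layer, per-expert counts over all layers.
    """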
    experts = ["Logic", "Social", "World", "Language"]
    expert_token_model = np.zeros((len(experts)), dtype=int)
    expert_layer_token = np.zeros((routing_weights.shape[0], len(experts)), dtype=int)
    num_layers = routing_weights.shape[0]

    for layer_idx in range(num_layers):
        for token_idx in range(len(routing_weights[layer_idx])):
            expert_idx = routing_weights[layer_idx][token_idx].argmax()
            if layer_idx >= 2 and layer_idx < num_layers - 2:
                expert_token_model[expert_idx] += 1
            expert_layer_token[layer_idx][expert_idx] += 1
    return expert_token_model, expert_layer_token

def generate_continuation(model, 
    tokenizer, 
    prompts, 
    max_tokens=128,
    use_cache=True, 
    return_routing_weights=True
):
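    """
    Greedily generate a continuation for the given prompt(s); optionally also
    return the softmaxed per-layer routing weights of the generated tokens.
    """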

    if isinstance(prompts, str):
        prompts = [{"role": "user", "content": prompts}]

    tokenizer.padding_side = "left"
    inputs = tokenizer.apply_chat_template([
        prompt for prompt in prompts
    ], return_tensors="pt", padding=True, add_generation_prompt=True).to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0

    outputs = model.generate(
        input_ids=inputs,
        attention_mask=attention_mask, 
        max_new_tokens=max_tokens,
        use_cache=use_cache,
        stop_strings=["</s>","<|eot_id|>", "<|im_start|>user", "user"],
        tokenizer=tokenizer,
        pad_token_id=tokenizer.pad_token_id,
        temperature=0,
        top_p=1.0,
        do_sample=False,
    )
    
    if return_routing_weights:
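        # Re-run the full (prompt + generated) sequence through the model once
        # more to collect per-layer routing weights, then keep only the
        # positions of the generated tokens.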
        attention_mask = torch.ones_like(outputs)
        attention_mask[outputs == tokenizer.pad_token_id] = 0
        model_output = model(input_ids=outputs, attention_mask=attention_mask)
        torch.cuda.empty_cache()

        routing_weights = model_output.routing_weights        
        routing_weights = np.concatenate([
            F.softmax(rw, dim=-1)[:, inputs.shape[1]:].detach().float().cpu().numpy() 
            for rw in routing_weights
        ])
        
    else:
        routing_weights = None

    inputs_text = tokenizer.batch_decode(inputs, skip_special_tokens=False)

    generations = []
    for i, output in enumerate(outputs):
        decoded_output = tokenizer.decode(output, skip_special_tokens=False)
        decoded_output = decoded_output.replace(inputs_text[i], "")
        decoded_output = decoded_output.replace(tokenizer.pad_token, "").strip()
        decoded_output = decoded_output.replace("<|end_of_text|>", "").strip()
        decoded_output = decoded_output.replace("<|endoftext|>", "").strip()
        decoded_output = decoded_output.replace("<|eot_id|>", "").strip()
        decoded_output = decoded_output.replace("\n<|im_start|>user", "").strip()
        generations.append(decoded_output)

    return (generations, routing_weights) if return_routing_weights else generations

def get_routing_weights(model, tokenizer, prompts, apply_chat_template=True):
    """
    Get routing weights for the given prompts using the model.
    Args:
        model: The MiCRoLlama or MiCRoOLMo model.
        tokenizer: The tokenizer for the model.
        prompts: A string or list of dictionaries containing the prompts.
    Returns:
        routing_weights: A list of routing weights for each layer.
    """

    tokenizer.padding_side = "left"
    if apply_chat_template:
        if isinstance(prompts, str):
            prompts = [{"role": "user", "content": prompts}]

        inputs = tokenizer.apply_chat_template([
            prompt for prompt in prompts
        ], return_tensors="pt", padding=True).to(DEVICE)

        input_without_response = tokenizer.apply_chat_template([
                prompt[:-1] for prompt in prompts
            ], return_tensors="pt", padding=True,
        ).to(DEVICE)
    else:
        inputs = tokenizer(prompts[0] + prompts[1], return_tensors="pt", padding=True).input_ids.to(DEVICE)
        input_without_response = tokenizer(prompts[0], return_tensors="pt", padding=True).input_ids.to(DEVICE)

    attention_mask = torch.ones_like(inputs)
    attention_mask[inputs == tokenizer.pad_token_id] = 0

    model_output = model(input_ids=inputs, attention_mask=attention_mask)

    routing_weights = model_output.routing_weights   
    routing_weights = np.stack([F.softmax(rw, dim=-1).detach().float().cpu().numpy() for rw in routing_weights], axis=0).squeeze()

    # Keep only the routing weights of the response tokens: drop the prompt
    # prefix and the final position.
    offset = len(input_without_response[0])-1
    routing_weights = routing_weights[:, offset:-1]

    return routing_weights

def build_model(model_id: str, hf_token: str, ablations: Optional[List[str]] = None, use_cache: bool = True):
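    """
    Load a MiCRo checkpoint plus the tokenizer of its base model, configure
    dtype / attention implementation / optional ablations, and move the model
    to DEVICE in eval mode.
    """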

    model_path, base_model, model_class = get_model_path(model_id)

    model_config = AutoConfig.from_pretrained(base_model, use_auth_token=hf_token)

    parent_path = pathlib.Path(__file__).parent

    model_config.config_path = f"{parent_path}/configs/{model_id.replace('-', '_')}.yml"

    model_config.torch_dtype = torch.bfloat16
    model_config.use_bfloat16 = True
    model_config._attn_implementation = "eager" # {sdpa, flash_attention_2, eager}
    model_config.use_cache = use_cache
    model_config.ablate = ablations

    tokenizer = AutoTokenizer.from_pretrained(base_model, use_auth_token=hf_token)
    tokenizer.padding_side = "left"

    if "llama" in model_id:
        tokenizer.pad_token_id = 128004
    if "olmo" in model_id:
        tokenizer.pad_token_id = 100277
        tokenizer.add_special_tokens({'additional_special_tokens': ['<|assistant|>']})
    elif "smollm2" in model_id:
        tokenizer.pad_token_id = 2
    else:
        tokenizer.pad_token_id = 128004

    if "olmo" in model_id:
        model_config.vocab_size = len(tokenizer)

    model = model_class.from_pretrained(model_path, config=model_config, low_cpu_mem_usage=True)

    model.to(DEVICE)
    model = model.bfloat16()
    model.eval()
    return model, tokenizer
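
# Minimal usage sketch. Assumptions (not part of the original module): the
# "micro-llama-1b" checkpoint is reachable, an HF_TOKEN environment variable
# holds a token with access to the gated Llama base model, and the prompt
# below is purely illustrative.
if __name__ == "__main__":
    import os

    routing, generation = get_expert_routing(
        model_id="micro-llama-1b",
        hf_token=os.environ.get("HF_TOKEN", ""),
        prompt="If Alice has 3 apples and gives one to Bob, how many are left?",
    )
    print(routing)  # e.g. {"Language": ..., "Logic": ..., "Social": ..., "World": ...}
    if generation is not None:
        print(generation)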