Molbap's picture
Molbap HF Staff
big update
347ff85
raw
history blame
842 Bytes
<pre><code class="language-python"># In the model's config (example: ERNIE 4.5-style decoder blocks).
# Maps module-name patterns ("*" stands for the layer index) to the
# tensor-parallel sharding style applied to that projection.
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

# Runtime
# A tensor-parallel plan only takes effect when the script is started with one
# process per device, e.g.:
#   torchrun --nproc-per-node 4 this_script.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your/model-or-local-checkpoint"

# NOTE(review): the keys above are relative to the base model. When the plan is
# passed to a *ForCausalLM wrapper they may need a "model." prefix -- confirm,
# or use tp_plan="auto" to pick up the plan already stored in the config.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    tp_plan=base_model_tp_plan,  # &lt;-- plan defined above
)
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok("Hello", return_tensors="pt").to(model.device)

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    out = model(**inputs)</code></pre>