```python
# In the model's config (example: ERNIE 4.5-style decoder blocks)
base_model_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

# Runtime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your/model-or-local-checkpoint"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    tp_plan=base_model_tp_plan,  # <-- plan defined above
)
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok("Hello", return_tensors="pt").to(model.device)
out = model(**inputs)
```