Commit
·
1365804
1
Parent(s):
40912b5
layout_map patch for gemma-2b-it-keras
Browse files
models.py
CHANGED
|
@@ -40,11 +40,27 @@ def get_default_layout_map(preset_name, device_mesh):
|
|
| 40 |
or "vicuna" in preset_name
|
| 41 |
):
|
| 42 |
layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
|
|
|
|
| 43 |
# This line is missing for some Llama models (TODO: fix this in keras_hub)
|
| 44 |
layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
|
| 45 |
return layout_map
|
|
|
|
| 46 |
elif "gemma" in preset_name:
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def log_applied_layout_map(model):
|
|
|
|
| 40 |
or "vicuna" in preset_name
|
| 41 |
):
|
| 42 |
layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
|
| 43 |
+
# Default layout map patch:
|
| 44 |
# This line is missing for some Llama models (TODO: fix this in keras_hub)
|
| 45 |
layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
|
| 46 |
return layout_map
|
| 47 |
+
|
| 48 |
elif "gemma" in preset_name:
|
| 49 |
+
layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
|
| 50 |
+
|
| 51 |
+
if "gemma-2b-" in preset_name:
|
| 52 |
+
# Default layout map patch:
|
| 53 |
+
# Gemma QKV weights are shaped [NB_HEADS, EMBED_DIM, INNER_DIM]
|
| 54 |
+
# Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM]
|
| 55 |
+
# However:
|
| 56 |
+
# The default layout map for QKV weights on Gemma is: (model_dim, data_dim, None)
|
| 57 |
+
# Which means sharding NB_HEADS on the "model" dimension.
|
| 58 |
+
# But gemma-2b-it-keras has only 1 head so this won't work: must patch it
|
| 59 |
+
# TODO: fix this in the Gemma layout map in Keras hub.
|
| 60 |
+
patch_key = "decoder_block.*attention.*(query|key|value).kernel"
|
| 61 |
+
layout_map.pop(patch_key)
|
| 62 |
+
layout_map[patch_key] = (None, "model", "batch")
|
| 63 |
+
return layout_map
|
| 64 |
|
| 65 |
|
| 66 |
def log_applied_layout_map(model):
|