Commit
·
38f8411
1
Parent(s):
d20c9b9
bug fix
Browse files
models.py
CHANGED
|
@@ -60,7 +60,8 @@ def get_default_layout_map(preset_name, device_mesh):
|
|
| 60 |
patch_key = "decoder_block.*attention.*(query|key|value).kernel"
|
| 61 |
layout_map.pop(patch_key)
|
| 62 |
layout_map[patch_key] = (None, "model", "batch")
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
def log_applied_layout_map(model):
|
|
|
|
| 60 |
patch_key = "decoder_block.*attention.*(query|key|value).kernel"
|
| 61 |
layout_map.pop(patch_key)
|
| 62 |
layout_map[patch_key] = (None, "model", "batch")
|
| 63 |
+
|
| 64 |
+
return layout_map
|
| 65 |
|
| 66 |
|
| 67 |
def log_applied_layout_map(model):
|