Spaces:
Running
Running
Fix AOT compilation dynamic_shapes to match expected arg names for torch.export.export
Browse files
app.py
CHANGED
|
@@ -495,15 +495,14 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 495 |
|
| 496 |
# Define dynamic shapes for variable sequence lengths
|
| 497 |
seq_dim = torch.export.Dim('seq', min=1, max=4096)
|
| 498 |
-
dynamic_shapes =
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
dynamic_shapes['position_ids'] = {1: seq_dim}
|
| 507 |
|
| 508 |
exported = torch.export.export(
|
| 509 |
pipe.model,
|
|
|
|
| 495 |
|
| 496 |
# Define dynamic shapes for variable sequence lengths
|
| 497 |
seq_dim = torch.export.Dim('seq', min=1, max=4096)
|
| 498 |
+
dynamic_shapes = {
|
| 499 |
+
'input_ids': {1: seq_dim} if 'input_ids' in call.kwargs else None,
|
| 500 |
+
'attention_mask': {1: seq_dim} if 'attention_mask' in call.kwargs else None,
|
| 501 |
+
'inputs_embeds': None,
|
| 502 |
+
'use_cache': None,
|
| 503 |
+
'cache_position': {1: seq_dim} if 'cache_position' in call.kwargs or 'position_ids' in call.kwargs else None,
|
| 504 |
+
'kwargs': None
|
| 505 |
+
}
|
|
|
|
| 506 |
|
| 507 |
exported = torch.export.export(
|
| 508 |
pipe.model,
|