Luigi commited on
Commit
0a99dfc
·
1 Parent(s): 4d2c362

Fix AOT compilation dynamic_shapes to match expected arg names for torch.export.export

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -495,15 +495,14 @@ def chat_response(user_msg, chat_history, system_prompt,
495
 
496
  # Define dynamic shapes for variable sequence lengths
497
  seq_dim = torch.export.Dim('seq', min=1, max=4096)
498
- dynamic_shapes = tree_map(lambda v: None, call.kwargs)
499
-
500
- # Set dynamic dimensions for common inputs
501
- if 'input_ids' in call.kwargs:
502
- dynamic_shapes['input_ids'] = {1: seq_dim}
503
- if 'attention_mask' in call.kwargs:
504
- dynamic_shapes['attention_mask'] = {1: seq_dim}
505
- if 'position_ids' in call.kwargs:
506
- dynamic_shapes['position_ids'] = {1: seq_dim}
507
 
508
  exported = torch.export.export(
509
  pipe.model,
 
495
 
496
  # Define dynamic shapes for variable sequence lengths
497
  seq_dim = torch.export.Dim('seq', min=1, max=4096)
498
+ dynamic_shapes = {
499
+ 'input_ids': {1: seq_dim} if 'input_ids' in call.kwargs else None,
500
+ 'attention_mask': {1: seq_dim} if 'attention_mask' in call.kwargs else None,
501
+ 'inputs_embeds': None,
502
+ 'use_cache': None,
503
+ 'cache_position': {1: seq_dim} if 'cache_position' in call.kwargs or 'position_ids' in call.kwargs else None,
504
+ 'kwargs': None
505
+ }
 
506
 
507
  exported = torch.export.export(
508
  pipe.model,