Luigi committed on
Commit 426163f · 1 Parent(s): 273acf8

Remove AOT compilation completely and enable use_cache

Files changed (1):
  1. app.py +3 -37
app.py CHANGED
@@ -338,7 +338,7 @@ def load_pipeline(model_name):
         trust_remote_code=True,
         torch_dtype=dtype,
         device_map="auto",
-        use_cache=False,  # disable past-key-value caching
+        use_cache=True,  # Enable past-key-value caching
         token=access_token)
     PIPELINES[model_name] = pipe
     return pipe
@@ -350,7 +350,8 @@ def load_pipeline(model_name):
         model=repo,
         tokenizer=tokenizer,
         trust_remote_code=True,
-        device_map="auto"
+        device_map="auto",
+        use_cache=True
     )
     PIPELINES[model_name] = pipe
     return pipe
@@ -483,41 +484,6 @@ def chat_response(user_msg, chat_history, system_prompt,
 
     pipe = load_pipeline(model_name)
 
-    # Determine actual model size for AOT decision
-    actual_params = sum(p.numel() for p in pipe.model.parameters())
-    model_size_b = actual_params / 1e9  # Convert to billions
-    use_aot = model_size_b >= 2  # Only compile models >= 2B parameters
-
-    if use_aot:
-        try:
-            with spaces.aoti_capture(pipe.model) as call:
-                pipe("Hello world", max_new_tokens=5, do_sample=False, pad_token_id=pipe.tokenizer.eos_token_id)
-
-            # Define dynamic shapes for variable sequence lengths
-            seq_dim = torch.export.Dim('seq', min=1, max=4096)
-            dynamic_shapes = {
-                'input_ids': {1: seq_dim} if 'input_ids' in call.kwargs else None,
-                'attention_mask': {1: seq_dim} if 'attention_mask' in call.kwargs else None,
-                'inputs_embeds': None,
-                'use_cache': None,
-                'cache_position': {1: seq_dim} if 'cache_position' in call.kwargs or 'position_ids' in call.kwargs else None,
-                'kwargs': {k: None for k in call.kwargs if k not in ['input_ids', 'attention_mask', 'inputs_embeds', 'use_cache', 'cache_position', 'position_ids']}
-            }
-
-            exported = torch.export.export(
-                pipe.model,
-                args=call.args,
-                kwargs=call.kwargs,
-                dynamic_shapes=dynamic_shapes
-            )
-            compiled = spaces.aoti_compile(exported)
-            spaces.aoti_apply(compiled, pipe.model)
-            print(f"AOT compilation successful for {model_name} ({model_size_b:.1f}B parameters)")
-        except Exception as e:
-            print(f"AOT compilation failed for {model_name}: {e}")
-    else:
-        print(f"Skipping AOT compilation for small model {model_name} ({model_size_b:.1f}B parameters)")
-
     prompt = format_conversation(history, enriched, pipe.tokenizer)
     prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
     streamer = TextIteratorStreamer(pipe.tokenizer,
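
For context on what the new flag changes: with use_cache=True the model returns past key/value tensors on each forward pass and generate() reuses them, so each decoding step only processes the newly produced token instead of re-encoding the whole prefix. (Presumably the cache had been disabled to keep shapes static for the now-removed torch.export/AOT path.) The sketch below is illustrative only and is not part of app.py; the checkpoint name "gpt2" is a stand-in for the models the Space actually loads.

# Standalone sketch of past-key-value caching; not code from this repository.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2", use_cache=True)

inputs = tok("Hello world", return_tensors="pt")

# First forward pass returns the cached key/value states for every layer.
out = model(**inputs, use_cache=True)
past = out.past_key_values

# Later steps feed only the newly chosen token plus the cache,
# instead of re-running attention over the full prefix.
next_token = out.logits[:, -1:].argmax(-1)
out = model(input_ids=next_token, past_key_values=past, use_cache=True)

# generate() performs the same cache bookkeeping internally when use_cache is enabled.
out_ids = model.generate(**inputs, max_new_tokens=5, use_cache=True)
print(tok.decode(out_ids[0], skip_special_tokens=True))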