Luigi committed
Commit a7866ff · 1 Parent(s): fea2910

Add AOT compilation optimization for ZeroGPU acceleration
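
This change warms up the pipeline once under spaces.aoti_capture, exports the captured call with torch.export using a dynamic sequence-length dimension (1–4096) for input_ids, attention_mask, and position_ids, compiles the exported program with spaces.aoti_compile, and applies the result back onto the model via spaces.aoti_apply. Failures are caught and logged, so generation falls back to the uncompiled model. The base duration in get_duration is doubled from 60 to 120 seconds to absorb the one-time compilation cost.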

Files changed (1)
app.py +32 -1
app.py CHANGED
@@ -11,6 +11,7 @@ from transformers import pipeline, TextIteratorStreamer
 from transformers import AutoTokenizer
 from ddgs import DDGS
 import spaces # Import spaces early to enable ZeroGPU support
+from torch.utils._pytree import tree_map
 
 access_token=os.environ['HF_TOKEN']
 
@@ -329,7 +330,7 @@ def format_conversation(history, system_prompt, tokenizer):
     return prompt
 
 def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
-    base_duration = 60
+    base_duration = 120 # Increased for AOT compilation
     token_duration = max_tokens * 0.1 # Estimate 0.1 seconds per token
     search_duration = 30 if enable_search else 0
     return base_duration + token_duration + search_duration
@@ -417,6 +418,36 @@ def chat_response(user_msg, chat_history, system_prompt,
     enriched = system_prompt
 
     pipe = load_pipeline(model_name)
+
+    # AOT compilation for performance optimization
+    try:
+        with spaces.aoti_capture(pipe.model) as call:
+            pipe("Hello world", max_new_tokens=5, do_sample=False, pad_token_id=pipe.tokenizer.eos_token_id)
+
+        # Define dynamic shapes for variable sequence lengths
+        seq_dim = torch.export.Dim('seq', min=1, max=4096)
+        dynamic_shapes = tree_map(lambda v: None, call.kwargs)
+
+        # Set dynamic dimensions for common inputs
+        if 'input_ids' in call.kwargs:
+            dynamic_shapes['input_ids'] = {1: seq_dim}
+        if 'attention_mask' in call.kwargs:
+            dynamic_shapes['attention_mask'] = {1: seq_dim}
+        if 'position_ids' in call.kwargs:
+            dynamic_shapes['position_ids'] = {1: seq_dim}
+
+        exported = torch.export.export(
+            pipe.model,
+            args=call.args,
+            kwargs=call.kwargs,
+            dynamic_shapes=dynamic_shapes
+        )
+        compiled = spaces.aoti_compile(exported)
+        spaces.aoti_apply(compiled, pipe.model)
+        print(f"AOT compilation successful for {model_name}")
+    except Exception as e:
+        print(f"AOT compilation failed for {model_name}: {e}")
+
     prompt = format_conversation(history, enriched, pipe.tokenizer)
     prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
     streamer = TextIteratorStreamer(pipe.tokenizer,
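
Not visible in these hunks: get_duration takes exactly the same parameters as chat_response, which is the shape ZeroGPU expects when a callable is passed as the duration of @spaces.GPU. A minimal sketch of that wiring, assuming app.py decorates chat_response this way (the decorator sits outside the changed lines):

import spaces

# Assumed wiring, not part of this diff: ZeroGPU invokes the duration
# callable with the same arguments as the decorated function and reserves
# GPU time for the returned number of seconds, so the 120 s base covers
# the one-time AOT compilation.
@spaces.GPU(duration=get_duration)
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    ...

On the export side, {1: seq_dim} marks dimension 1 (the sequence axis of a [batch, seq] tensor) as symbolic, so one compiled graph serves any prompt length from 1 to 4096; entries left as None stay static. A self-contained toy illustration of the same torch.export call:

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2

# Dimension 1 of input_ids (sequence length) is exported as a symbolic
# range; dimension 0 (batch) stays fixed to the example's size.
seq_dim = torch.export.Dim('seq', min=1, max=4096)
exported = torch.export.export(
    Toy(),
    (torch.zeros(1, 8, dtype=torch.long),),
    dynamic_shapes={'input_ids': {1: seq_dim}},
)
print(exported)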