Luigi committed
Commit 4500f92 · Parent: a7866ff

Make AOT compilation conditional for models >= 2B parameters to optimize free-tier usage

Files changed (1):
  1. app.py (+99 -31)
app.py CHANGED
@@ -330,10 +330,44 @@ def format_conversation(history, system_prompt, tokenizer):
     return prompt
 
 def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
-    base_duration = 120  # Increased for AOT compilation
-    token_duration = max_tokens * 0.1  # Estimate 0.1 seconds per token
+    # Estimate model size (rough approximation based on model name)
+    model_size = 0
+    if '30B' in model_name or '32B' in model_name:
+        model_size = 30
+    elif '20B' in model_name:
+        model_size = 20
+    elif '15B' in model_name or '14B' in model_name:
+        model_size = 15
+    elif '4B' in model_name or '3B' in model_name:
+        model_size = 4
+    elif '2B' in model_name or '1.7B' in model_name:
+        model_size = 2
+    elif '1.5B' in model_name or '1.2B' in model_name or '1.1B' in model_name:
+        model_size = 1.5
+    elif '1B' in model_name:
+        model_size = 1
+    elif '700M' in model_name or '600M' in model_name:
+        model_size = 0.7
+    elif '500M' in model_name:
+        model_size = 0.5
+    elif '360M' in model_name or '350M' in model_name:
+        model_size = 0.35
+    elif '270M' in model_name:
+        model_size = 0.27
+    elif '135M' in model_name:
+        model_size = 0.135
+    else:
+        model_size = 4  # default when no size tag matches
+
+    # Only use AOT for models >= 2B parameters
+    use_aot = model_size >= 2
+
+    base_duration = 120 if use_aot else 60  # Shorter base without AOT compilation
+    token_duration = max_tokens * 0.1  # Estimate 0.1 seconds per token
     search_duration = 30 if enable_search else 0
-    return base_duration + token_duration + search_duration
+    aot_compilation_buffer = 60 if use_aot else 0  # Extra time for compilation
+
+    return base_duration + token_duration + search_duration + aot_compilation_buffer
 
 @spaces.GPU(duration=get_duration)
 def chat_response(user_msg, chat_history, system_prompt,
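
A quick sanity check of the revised heuristic (illustrative only; the model names and sampling arguments below are hypothetical, not taken from the app's model list):

```python
# Worked example of the new get_duration(): a name containing '4B' maps to
# model_size = 4, so the AOT path applies (120s base plus a 60s compile
# buffer); a '135M' name gets the 60s base and no buffer.
get_duration("hi", [], "", False, 3, 200, "Qwen3-4B", 1024, 0.7, 40, 0.9, 1.1, 5.0)
# -> 120 + 1024 * 0.1 + 0 + 60 = 282.4 seconds
get_duration("hi", [], "", False, 3, 200, "SmolLM2-135M", 1024, 0.7, 40, 0.9, 1.1, 5.0)
# -> 60 + 1024 * 0.1 + 0 + 0 = 162.4 seconds
```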
@@ -419,34 +453,68 @@ def chat_response(user_msg, chat_history, system_prompt,
 
     pipe = load_pipeline(model_name)
 
-    # AOT compilation for performance optimization
-    try:
-        with spaces.aoti_capture(pipe.model) as call:
-            pipe("Hello world", max_new_tokens=5, do_sample=False, pad_token_id=pipe.tokenizer.eos_token_id)
-
-        # Define dynamic shapes for variable sequence lengths
-        seq_dim = torch.export.Dim('seq', min=1, max=4096)
-        dynamic_shapes = tree_map(lambda v: None, call.kwargs)
-
-        # Set dynamic dimensions for common inputs
-        if 'input_ids' in call.kwargs:
-            dynamic_shapes['input_ids'] = {1: seq_dim}
-        if 'attention_mask' in call.kwargs:
-            dynamic_shapes['attention_mask'] = {1: seq_dim}
-        if 'position_ids' in call.kwargs:
-            dynamic_shapes['position_ids'] = {1: seq_dim}
-
-        exported = torch.export.export(
-            pipe.model,
-            args=call.args,
-            kwargs=call.kwargs,
-            dynamic_shapes=dynamic_shapes
-        )
-        compiled = spaces.aoti_compile(exported)
-        spaces.aoti_apply(compiled, pipe.model)
-        print(f"AOT compilation successful for {model_name}")
-    except Exception as e:
-        print(f"AOT compilation failed for {model_name}: {e}")
+    # AOT compilation for performance optimization (only for larger models)
+    # Estimate model size
+    model_size = 0
+    if '30B' in model_name or '32B' in model_name:
+        model_size = 30
+    elif '20B' in model_name:
+        model_size = 20
+    elif '15B' in model_name or '14B' in model_name:
+        model_size = 15
+    elif '4B' in model_name or '3B' in model_name:
+        model_size = 4
+    elif '2B' in model_name or '1.7B' in model_name:
+        model_size = 2
+    elif '1.5B' in model_name or '1.2B' in model_name or '1.1B' in model_name:
+        model_size = 1.5
+    elif '1B' in model_name:
+        model_size = 1
+    elif '700M' in model_name or '600M' in model_name:
+        model_size = 0.7
+    elif '500M' in model_name:
+        model_size = 0.5
+    elif '360M' in model_name or '350M' in model_name:
+        model_size = 0.35
+    elif '270M' in model_name:
+        model_size = 0.27
+    elif '135M' in model_name:
+        model_size = 0.135
+    else:
+        model_size = 4  # default when no size tag matches
+
+    use_aot = model_size >= 2  # Only compile models >= 2B parameters
+
+    if use_aot:
+        try:
+            with spaces.aoti_capture(pipe.model) as call:
+                pipe("Hello world", max_new_tokens=5, do_sample=False, pad_token_id=pipe.tokenizer.eos_token_id)
+
+            # Define dynamic shapes for variable sequence lengths
+            seq_dim = torch.export.Dim('seq', min=1, max=4096)
+            dynamic_shapes = tree_map(lambda v: None, call.kwargs)
+
+            # Set dynamic dimensions for common inputs
+            if 'input_ids' in call.kwargs:
+                dynamic_shapes['input_ids'] = {1: seq_dim}
+            if 'attention_mask' in call.kwargs:
+                dynamic_shapes['attention_mask'] = {1: seq_dim}
+            if 'position_ids' in call.kwargs:
+                dynamic_shapes['position_ids'] = {1: seq_dim}
+
+            exported = torch.export.export(
+                pipe.model,
+                args=call.args,
+                kwargs=call.kwargs,
+                dynamic_shapes=dynamic_shapes
+            )
+            compiled = spaces.aoti_compile(exported)
+            spaces.aoti_apply(compiled, pipe.model)
+            print(f"AOT compilation successful for {model_name}")
+        except Exception as e:
+            print(f"AOT compilation failed for {model_name}: {e}")
+    else:
+        print(f"Skipping AOT compilation for small model {model_name}")
 
     prompt = format_conversation(history, enriched, pipe.tokenizer)
     prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
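
One step in the compile path that is easy to misread: `tree_map(lambda v: None, call.kwargs)` builds a skeleton with the same structure as the captured kwargs in which every input defaults to static (None for torch.export), and only the sequence-bearing tensors are then overridden with a dynamic dim. A minimal sketch of what that produces, assuming `tree_map` comes from `torch.utils._pytree` and the captured kwargs look roughly like this:

```python
import torch
from torch.utils._pytree import tree_map  # assumed import; app.py may differ

# Hypothetical captured kwargs; the real contents depend on the pipeline.
call_kwargs = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
    "use_cache": True,
}

# Every leaf becomes None, i.e. static as far as torch.export is concerned.
dynamic_shapes = tree_map(lambda v: None, call_kwargs)
# {'input_ids': None, 'attention_mask': None, 'use_cache': None}

# Dim 1 is the sequence axis of the (batch, seq_len) tensors; batch stays static.
seq_dim = torch.export.Dim("seq", min=1, max=4096)
dynamic_shapes["input_ids"] = {1: seq_dim}
dynamic_shapes["attention_mask"] = {1: seq_dim}
```

The `{1: seq_dim}` mapping assumes the captured tensors are laid out `(batch, seq_len)` with batch fixed at 1, which matches the single warm-up call used for capture.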
 
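
The size ladder is now duplicated verbatim in `get_duration()` and `chat_response()`, so the two copies can drift apart. A possible follow-up refactor (not part of this commit) is to extract it into a single helper; note that the branch order is load-bearing, since plain substring checks mean '4B' also matches '14B' and '2B' also matches '32B':

```python
def estimate_model_size(model_name: str) -> float:
    """Rough parameter count in billions, guessed from the model name."""
    # Ordered largest-first so that e.g. '14B' is matched before its
    # substring '4B', and '32B' before '2B'.
    ladder = [
        (("30B", "32B"), 30), (("20B",), 20), (("15B", "14B"), 15),
        (("4B", "3B"), 4), (("2B", "1.7B"), 2),
        (("1.5B", "1.2B", "1.1B"), 1.5), (("1B",), 1),
        (("700M", "600M"), 0.7), (("500M",), 0.5),
        (("360M", "350M"), 0.35), (("270M",), 0.27), (("135M",), 0.135),
    ]
    for tags, size in ladder:
        if any(tag in model_name for tag in tags):
            return size
    return 4  # default, matching the commit's fallback
```

Both call sites would then reduce to `model_size = estimate_model_size(model_name)` followed by `use_aot = model_size >= 2`.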