Remove AOT compilation completely and enable use_cache
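This commit drops the per-request AOT compilation block from chat_response and instead enables past-key-value caching (use_cache=True) in both model-loading paths of load_pipeline.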
app.py CHANGED
@@ -338,7 +338,7 @@ def load_pipeline(model_name):
         trust_remote_code=True,
         torch_dtype=dtype,
         device_map="auto",
-        use_cache=
+        use_cache=True,  # Enable past-key-value caching
         token=access_token)
     PIPELINES[model_name] = pipe
     return pipe
@@ -350,7 +350,8 @@ def load_pipeline(model_name):
         model=repo,
         tokenizer=tokenizer,
         trust_remote_code=True,
-        device_map="auto"
+        device_map="auto",
+        use_cache=True
     )
     PIPELINES[model_name] = pipe
     return pipe
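For context: with use_cache=True, transformers keeps the attention keys/values of already-processed tokens, so each decoding step only runs the newly generated token through the model instead of re-encoding the whole prefix. A minimal sketch of the mechanism being enabled (gpt2 is a stand-in checkpoint for illustration, not one of the models this Space serves):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", use_cache=True)

inputs = tok("Hello world", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)                  # prefill: runs the full prompt
next_token = out.logits[:, -1:].argmax(-1)
with torch.no_grad():
    out = model(next_token,                # decode step: only the new token,
                past_key_values=out.past_key_values)  # reusing cached K/V

In the from_pretrained path the flag acts as a config override; in the pipeline(...) path it should end up among the pipeline's default generation kwargs.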
@@ -483,41 +484,6 @@ def chat_response(user_msg, chat_history, system_prompt,
 
     pipe = load_pipeline(model_name)
 
-    # Determine actual model size for AOT decision
-    actual_params = sum(p.numel() for p in pipe.model.parameters())
-    model_size_b = actual_params / 1e9  # Convert to billions
-    use_aot = model_size_b >= 2  # Only compile models >= 2B parameters
-
-    if use_aot:
-        try:
-            with spaces.aoti_capture(pipe.model) as call:
-                pipe("Hello world", max_new_tokens=5, do_sample=False, pad_token_id=pipe.tokenizer.eos_token_id)
-
-            # Define dynamic shapes for variable sequence lengths
-            seq_dim = torch.export.Dim('seq', min=1, max=4096)
-            dynamic_shapes = {
-                'input_ids': {1: seq_dim} if 'input_ids' in call.kwargs else None,
-                'attention_mask': {1: seq_dim} if 'attention_mask' in call.kwargs else None,
-                'inputs_embeds': None,
-                'use_cache': None,
-                'cache_position': {1: seq_dim} if 'cache_position' in call.kwargs or 'position_ids' in call.kwargs else None,
-                'kwargs': {k: None for k in call.kwargs if k not in ['input_ids', 'attention_mask', 'inputs_embeds', 'use_cache', 'cache_position', 'position_ids']}
-            }
-
-            exported = torch.export.export(
-                pipe.model,
-                args=call.args,
-                kwargs=call.kwargs,
-                dynamic_shapes=dynamic_shapes
-            )
-            compiled = spaces.aoti_compile(exported)
-            spaces.aoti_apply(compiled, pipe.model)
-            print(f"AOT compilation successful for {model_name} ({model_size_b:.1f}B parameters)")
-        except Exception as e:
-            print(f"AOT compilation failed for {model_name}: {e}")
-    else:
-        print(f"Skipping AOT compilation for small model {model_name} ({model_size_b:.1f}B parameters)")
-
     prompt = format_conversation(history, enriched, pipe.tokenizer)
     prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
     streamer = TextIteratorStreamer(pipe.tokenizer,
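The deleted block wrapped torch.export with the ZeroGPU AOT helpers (spaces.aoti_capture, spaces.aoti_compile, spaces.aoti_apply). The fiddly part was the dynamic_shapes spec: torch.export otherwise specializes the graph to the example input's exact shapes, so the sequence dimension had to be declared dynamic. A toy sketch of just that mechanism, independent of this app's model:

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids):          # shape (batch, seq)
        return input_ids * 2

seq = torch.export.Dim("seq", min=1, max=4096)
ep = torch.export.export(
    Toy(),
    args=(torch.zeros(2, 8, dtype=torch.long),),   # example input: seq=8
    dynamic_shapes={"input_ids": {1: seq}},        # dim 1 may vary at run time
)
out = ep.module()(torch.zeros(2, 16, dtype=torch.long))  # seq=16 also accepted
print(out.shape)  # torch.Size([2, 16])

Presumably the two halves of the commit are linked: with use_cache=True the forward signature changes between prefill and decode (past_key_values grows step by step), which a graph exported from a single captured call does not accommodate, so removing AOT entirely is the simpler trade.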