Commit
·
4843704
1
Parent(s):
e140f31
jax cache for faster compilation attempt
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import os
|
| 2 |
|
| 3 |
# ---- Space mode gating (place above any JAX import!) ----
|
| 4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
|
@@ -28,21 +28,33 @@ else:
|
|
| 28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
| 29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
| 34 |
try:
|
| 35 |
-
jax.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
except Exception:
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
#
|
| 40 |
try:
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
except Exception:
|
| 44 |
-
|
| 45 |
-
# --------------------------------------------------------------------
|
| 46 |
|
| 47 |
|
| 48 |
|
|
@@ -67,8 +79,6 @@ from one_shot_generation import generate_loop_continuation_with_mrt, generate_st
|
|
| 67 |
|
| 68 |
import uuid, threading
|
| 69 |
|
| 70 |
-
import logging
|
| 71 |
-
|
| 72 |
import gradio as gr
|
| 73 |
from typing import Optional, Union, Literal
|
| 74 |
|
|
@@ -357,15 +367,21 @@ _WARMUP_LOCK = threading.Lock()
|
|
| 357 |
|
| 358 |
def _mrt_warmup():
|
| 359 |
"""
|
| 360 |
-
Build a minimal, bar-aligned silent context and run
|
| 361 |
-
to trigger
|
| 362 |
"""
|
| 363 |
global _WARMED
|
| 364 |
with _WARMUP_LOCK:
|
| 365 |
if _WARMED:
|
| 366 |
return
|
| 367 |
try:
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
# --- derive timing from model config ---
|
| 371 |
codec_fps = float(mrt.codec.frame_rate)
|
|
@@ -406,10 +422,18 @@ def _mrt_warmup():
|
|
| 406 |
state.context_tokens = context_tokens
|
| 407 |
style_vec = mrt.embed_style("warmup")
|
| 408 |
|
| 409 |
-
# ---
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
-
logging.info("MagentaRT warmup complete.")
|
| 413 |
finally:
|
| 414 |
try:
|
| 415 |
os.unlink(tmp_path)
|
|
|
|
| 1 |
+
import logging, os
|
| 2 |
|
| 3 |
# ---- Space mode gating (place above any JAX import!) ----
|
| 4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
|
|
|
| 28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
| 29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
| 30 |
|
| 31 |
+
# --- JAX persistent compilation cache (new + old APIs), plus extra XLA caches ---
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
CACHE_DIR = os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
| 35 |
+
|
| 36 |
+
# Prefer new API (JAX ≥ 0.4.26 / 0.5+), fall back to older initialize_cache
|
| 37 |
try:
|
| 38 |
+
from jax.experimental import compilation_cache as cc # new-style
|
| 39 |
+
if hasattr(cc, "set_cache_dir"):
|
| 40 |
+
cc.set_cache_dir(CACHE_DIR)
|
| 41 |
+
logging.info("JAX persistent cache (set_cache_dir) -> %s", CACHE_DIR)
|
| 42 |
+
else:
|
| 43 |
+
raise ImportError
|
| 44 |
except Exception:
|
| 45 |
+
try:
|
| 46 |
+
from jax.experimental.compilation_cache import compilation_cache as cc_old # old-style
|
| 47 |
+
cc_old.initialize_cache(CACHE_DIR)
|
| 48 |
+
logging.info("JAX persistent cache (initialize_cache) -> %s", CACHE_DIR)
|
| 49 |
+
except Exception as e:
|
| 50 |
+
logging.warning("JAX persistent cache init skipped: %s", e)
|
| 51 |
|
| 52 |
+
# Extra XLA caches piggyback on the persistent cache (best effort)
|
| 53 |
try:
|
| 54 |
+
import jax
|
| 55 |
+
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")
|
| 56 |
+
except Exception as e:
|
| 57 |
+
logging.info("XLA extra caches not enabled: %s", e)
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
|
|
|
|
| 79 |
|
| 80 |
import uuid, threading
|
| 81 |
|
|
|
|
|
|
|
| 82 |
import gradio as gr
|
| 83 |
from typing import Optional, Union, Literal
|
| 84 |
|
|
|
|
| 367 |
|
| 368 |
def _mrt_warmup():
|
| 369 |
"""
|
| 370 |
+
Build a minimal, bar-aligned silent context and run a couple of ~2s generate_chunk
|
| 371 |
+
passes to trigger JIT, fill persistent caches, and run XLA autotune.
|
| 372 |
"""
|
| 373 |
global _WARMED
|
| 374 |
with _WARMUP_LOCK:
|
| 375 |
if _WARMED:
|
| 376 |
return
|
| 377 |
try:
|
| 378 |
+
# Touch JAX backend early (brings up CUDA context etc.)
|
| 379 |
+
try:
|
| 380 |
+
import jax; _ = jax.devices()
|
| 381 |
+
except Exception:
|
| 382 |
+
pass
|
| 383 |
+
|
| 384 |
+
mrt = get_mrt() # will build model and (with our earlier changes) ensure assets if envs are set
|
| 385 |
|
| 386 |
# --- derive timing from model config ---
|
| 387 |
codec_fps = float(mrt.codec.frame_rate)
|
|
|
|
| 422 |
state.context_tokens = context_tokens
|
| 423 |
style_vec = mrt.embed_style("warmup")
|
| 424 |
|
| 425 |
+
# --- prime compiled paths & autotune: run twice ---
|
| 426 |
+
wav1, state = mrt.generate_chunk(state=state, style=style_vec) # compile + autotune
|
| 427 |
+
wav2, _ = mrt.generate_chunk(state=state, style=style_vec) # hit cached executables
|
| 428 |
+
|
| 429 |
+
# Optional sanity: ensure we didn't return all zeros
|
| 430 |
+
try:
|
| 431 |
+
if np.abs(wav2.samples).mean() <= 1e-7:
|
| 432 |
+
logging.warning("Warmup produced near-silence; continuing.")
|
| 433 |
+
except Exception:
|
| 434 |
+
pass
|
| 435 |
|
| 436 |
+
logging.info("MagentaRT warmup complete (persistent cache primed).")
|
| 437 |
finally:
|
| 438 |
try:
|
| 439 |
os.unlink(tmp_path)
|