Commit
·
8381f2e
1
Parent(s):
8999b96
reverting jax cache changes
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import
|
| 2 |
|
| 3 |
# ---- Space mode gating (place above any JAX import!) ----
|
| 4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
|
@@ -28,33 +28,21 @@ else:
|
|
| 28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
| 29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
CACHE_DIR = os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
| 35 |
-
|
| 36 |
-
# Prefer new API (JAX ≥ 0.4.26 / 0.5+), fall back to older initialize_cache
|
| 37 |
try:
|
| 38 |
-
|
| 39 |
-
if hasattr(cc, "set_cache_dir"):
|
| 40 |
-
cc.set_cache_dir(CACHE_DIR)
|
| 41 |
-
logging.info("JAX persistent cache (set_cache_dir) -> %s", CACHE_DIR)
|
| 42 |
-
else:
|
| 43 |
-
raise ImportError
|
| 44 |
except Exception:
|
| 45 |
-
|
| 46 |
-
from jax.experimental.compilation_cache import compilation_cache as cc_old # old-style
|
| 47 |
-
cc_old.initialize_cache(CACHE_DIR)
|
| 48 |
-
logging.info("JAX persistent cache (initialize_cache) -> %s", CACHE_DIR)
|
| 49 |
-
except Exception as e:
|
| 50 |
-
logging.warning("JAX persistent cache init skipped: %s", e)
|
| 51 |
|
| 52 |
-
#
|
| 53 |
try:
|
| 54 |
-
import
|
| 55 |
-
|
| 56 |
-
except Exception
|
| 57 |
-
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
|
|
@@ -79,6 +67,8 @@ from one_shot_generation import generate_loop_continuation_with_mrt, generate_st
|
|
| 79 |
|
| 80 |
import uuid, threading
|
| 81 |
|
|
|
|
|
|
|
| 82 |
import gradio as gr
|
| 83 |
from typing import Optional, Union, Literal
|
| 84 |
|
|
@@ -367,21 +357,15 @@ _WARMUP_LOCK = threading.Lock()
|
|
| 367 |
|
| 368 |
def _mrt_warmup():
|
| 369 |
"""
|
| 370 |
-
Build a minimal, bar-aligned silent context and run
|
| 371 |
-
|
| 372 |
"""
|
| 373 |
global _WARMED
|
| 374 |
with _WARMUP_LOCK:
|
| 375 |
if _WARMED:
|
| 376 |
return
|
| 377 |
try:
|
| 378 |
-
|
| 379 |
-
try:
|
| 380 |
-
import jax; _ = jax.devices()
|
| 381 |
-
except Exception:
|
| 382 |
-
pass
|
| 383 |
-
|
| 384 |
-
mrt = get_mrt() # will build model and (with our earlier changes) ensure assets if envs are set
|
| 385 |
|
| 386 |
# --- derive timing from model config ---
|
| 387 |
codec_fps = float(mrt.codec.frame_rate)
|
|
@@ -422,18 +406,10 @@ def _mrt_warmup():
|
|
| 422 |
state.context_tokens = context_tokens
|
| 423 |
style_vec = mrt.embed_style("warmup")
|
| 424 |
|
| 425 |
-
# ---
|
| 426 |
-
|
| 427 |
-
wav2, _ = mrt.generate_chunk(state=state, style=style_vec) # hit cached executables
|
| 428 |
-
|
| 429 |
-
# Optional sanity: ensure we didn't return all zeros
|
| 430 |
-
try:
|
| 431 |
-
if np.abs(wav2.samples).mean() <= 1e-7:
|
| 432 |
-
logging.warning("Warmup produced near-silence; continuing.")
|
| 433 |
-
except Exception:
|
| 434 |
-
pass
|
| 435 |
|
| 436 |
-
logging.info("MagentaRT warmup complete
|
| 437 |
finally:
|
| 438 |
try:
|
| 439 |
os.unlink(tmp_path)
|
|
|
|
| 1 |
+
import os
|
| 2 |
|
| 3 |
# ---- Space mode gating (place above any JAX import!) ----
|
| 4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
|
|
|
| 28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
| 29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
| 30 |
|
| 31 |
+
import jax
|
| 32 |
+
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
| 33 |
+
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
|
|
|
|
|
|
|
|
|
| 34 |
try:
|
| 35 |
+
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
except Exception:
|
| 37 |
+
jax.config.update("jax_default_matmul_precision", "high") # older alias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
# Initialize the on-disk compilation cache (best-effort)
|
| 40 |
try:
|
| 41 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 42 |
+
cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
|
| 43 |
+
except Exception:
|
| 44 |
+
pass
|
| 45 |
+
# --------------------------------------------------------------------
|
| 46 |
|
| 47 |
|
| 48 |
|
|
|
|
| 67 |
|
| 68 |
import uuid, threading
|
| 69 |
|
| 70 |
+
import logging
|
| 71 |
+
|
| 72 |
import gradio as gr
|
| 73 |
from typing import Optional, Union, Literal
|
| 74 |
|
|
|
|
| 357 |
|
| 358 |
def _mrt_warmup():
|
| 359 |
"""
|
| 360 |
+
Build a minimal, bar-aligned silent context and run one 2s generate_chunk
|
| 361 |
+
to trigger XLA JIT & autotune so first real request is fast.
|
| 362 |
"""
|
| 363 |
global _WARMED
|
| 364 |
with _WARMUP_LOCK:
|
| 365 |
if _WARMED:
|
| 366 |
return
|
| 367 |
try:
|
| 368 |
+
mrt = get_mrt()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
# --- derive timing from model config ---
|
| 371 |
codec_fps = float(mrt.codec.frame_rate)
|
|
|
|
| 406 |
state.context_tokens = context_tokens
|
| 407 |
style_vec = mrt.embed_style("warmup")
|
| 408 |
|
| 409 |
+
# --- one throwaway chunk (~2s) ---
|
| 410 |
+
_wav, _state = mrt.generate_chunk(state=state, style=style_vec)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
+
logging.info("MagentaRT warmup complete.")
|
| 413 |
finally:
|
| 414 |
try:
|
| 415 |
os.unlink(tmp_path)
|