# Source: Hugging Face Space `app.py` by Molbap (HF Staff), commit e3461d1 ("Update app.py", verified), 10.7 kB.
import json
import os
import re
import subprocess
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path

import gradio as gr
import pandas as pd
import torch
import spaces
# ---------------------------
# Markdown rendering (Option A)
# ---------------------------
def _make_md_markdownit():
# Prefer markdown-it-py + mdit-py-plugins if available
from importlib import import_module
from markdown_it import MarkdownIt
md = MarkdownIt("gfm-like")
# Version-agnostic plugin shims
foot_mod = import_module("mdit_py_plugins.footnote")
foot = getattr(foot_mod, "footnote", None) or getattr(foot_mod, "footnote_plugin")
md.use(foot)
tl_mod = import_module("mdit_py_plugins.tasklists")
tasklists = getattr(tl_mod, "tasklists", None) or getattr(tl_mod, "tasklists_plugin")
md.use(tasklists)
cont_mod = import_module("mdit_py_plugins.container")
container = getattr(cont_mod, "container", None) or getattr(cont_mod, "container_plugin")
try:
md.use(container, "details")
except TypeError:
md.use(lambda m: container(m, name="details"))
return md
def _make_md_pythonmarkdown():
# Fallback: Python-Markdown + PyMdown
import markdown as md
exts = [
"extra", # tables + fenced code
"footnotes",
"admonition",
"toc",
"pymdownx.details",
"pymdownx.superfences",
"pymdownx.tasklist",
]
ext_cfg = {"pymdownx.tasklist": {"custom_checkbox": True}, "toc": {"permalink": True}}
return ("python-markdown", exts, ext_cfg, md)
# Pick the Markdown engine once at import time: markdown-it-py if it (and its
# plugins/optional deps) work, otherwise Python-Markdown.
try:
    _md_engine = ("markdown-it", _make_md_markdownit())
except Exception as _md_err:  # missing package / plugin / optional dependency
    # State why the preferred engine was skipped instead of falling back silently.
    print(
        f"[markdown] markdown-it engine unavailable ({_md_err!r}); falling back to Python-Markdown.",
        file=sys.stderr,
    )
    _md_engine = _make_md_pythonmarkdown()
def _obsidian_rewrites(text: str) -> str:
# Obsidian image/file embeds and wiki links
text = re.sub(r'!\[\[([^\]|]+)\]\]', r'![](static/\1)', text) # ![[file.png]]
text = re.sub(r'\[\[([^\]|]+)\|([^\]]+)\]\]', r'[\2](\1)', text) # [[file|label]]
text = re.sub(r'\[\[([^\]]+)\]\]', r'[\1](\1)', text) # [[file]]
return text
def md_to_html(text: str) -> str:
    """Convert Obsidian-flavoured Markdown *text* into an HTML string.

    Obsidian embeds and wiki links are rewritten first. The result is then
    rendered by whichever engine ``_md_engine`` selected at import time.
    """
    source = _obsidian_rewrites(text)
    engine_name = _md_engine[0]
    if engine_name != "markdown-it":
        # Python-Markdown fallback: (tag, extensions, configs, module).
        _, extensions, configs, module = _md_engine
        return module.markdown(
            source,
            extensions=extensions,
            extension_configs=configs,
            output_format="html5",
        )
    renderer = _md_engine[1]
    return renderer.render(source)
# `{{TOKEN}}` insert markers, e.g. `{{ALLOC_PLOT}}`; inner whitespace and digits are tolerated.
_INSERT_TOKEN_RE = re.compile(r"\{\{\s*([A-Z_][A-Z0-9_]*)\s*\}\}")


def render_article(md_path: str, inserts: dict[str, Callable[[], object]]):
    """Render a Markdown article into the current Gradio layout.

    The file is split on ``{{TOKEN}}`` markers. Text between markers becomes
    ``gr.HTML`` components via ``md_to_html``. Each marker is replaced by
    calling ``inserts[TOKEN]()``, a builder that creates widgets in the current
    Gradio context. Unknown tokens render a short placeholder.

    Args:
        md_path: path of the Markdown file. A "missing article" notice is
            rendered if it is not a regular file.
        inserts: mapping of token name -> zero-argument widget builder.
    """
    path = Path(md_path)
    if path.is_file():  # is_file(): a directory would make read_text() raise
        raw = path.read_text(encoding="utf-8")
    else:
        raw = f"**Missing article**: `{md_path}` not found.\n\nCreate it in your Space repo."
    # re.split with one capture group alternates: text (even index) / token (odd index).
    parts = _INSERT_TOKEN_RE.split(raw)
    with gr.Column():
        for i, part in enumerate(parts):
            if i % 2 == 0:
                # Skip empty text, e.g. before a leading marker or between adjacent
                # markers, so no blank HTML components are created.
                if part.strip():
                    gr.HTML(md_to_html(part))
                continue
            build = inserts.get(part)
            if build is None:
                gr.HTML(f"<p><em>Unknown insert: {part}</em></p>")
            else:
                build()
# ---------------------------
# Terminal (safe, simplified)
# ---------------------------
def run_shell(cmd: str) -> str:
    """Run one command typed in the UI and return a terminal-style transcript.

    The transcript is ``"$ " + cmd`` on the first line. It is followed by the
    command's stdout then stderr, or by the ``repr`` of the error that prevented
    it from running. Commands time out after 30 seconds.

    The command is tokenized with POSIX shell quoting rules but executed
    WITHOUT a shell. Operators such as ``;``, ``$(...)``, ``&``, newlines, globs
    and ``$VARS`` are passed through as literal arguments. Input from the public
    UI therefore cannot chain additional commands.

    NOTE(review): any single program the server user can run is still allowed;
    consider an allowlist before exposing this publicly.
    """
    import shlex

    # Kept so users still get a clear message for pipelines/redirections.
    banned = ["|", ">", "<", "&&", "||", "`"]
    if any(b in cmd for b in banned):
        return "$ " + cmd + "\nBlocked characters. Use a single command."
    try:
        argv = shlex.split(cmd)
    except ValueError as e:  # e.g. unbalanced quotes
        return f"$ {cmd}\n{e!r}"
    if not argv:
        # Blank input: nothing to run (matches an empty shell command's output).
        return f"$ {cmd}\n"
    try:
        p = subprocess.run(argv, check=False, capture_output=True, text=True, timeout=30)
        return f"$ {cmd}\n{p.stdout}{p.stderr}"
    except Exception as e:  # timeout, program not found, permission denied, ...
        return f"$ {cmd}\n{e!r}"
def build_terminal():
    """Add the mini terminal: a command box, a Run button and a read-only output pane.

    Clicking Run passes the command text to ``run_shell`` and shows its transcript.
    """
    with gr.Group():
        command_box = gr.Textbox(
            label="Command",
            value="python -c 'import torch; print(torch.__version__)'",
        )
        run_button = gr.Button("Run")
        output_box = gr.Textbox(label="Output", lines=12, interactive=False)
        run_button.click(run_shell, inputs=command_box, outputs=output_box)
# ---------------------------------------
# Attention Mask Visualizer (Transformers)
# ---------------------------------------
def _import_attention_visualizer():
try:
from transformers.utils.attention_visualizer import AttentionMaskVisualizer # type: ignore
except Exception as e:
raise RuntimeError(
"AttentionMaskVisualizer is unavailable in this Transformers version."
) from e
return AttentionMaskVisualizer
@spaces.GPU(duration=120)
def render_attention_mask(model_id: str, prompt: str) -> str:
    """Visualize *model_id*'s attention mask for *prompt* as an HTML snippet.

    Never raises: any failure is returned as an HTML error paragraph. The
    message is HTML-escaped because ``model_id`` can be free text from the UI
    (the dropdown allows custom values).
    """
    import contextlib
    import html
    import io

    try:
        AttentionMaskVisualizer = _import_attention_visualizer()
        vis = AttentionMaskVisualizer(model_id)
        captured = io.StringIO()
        # NOTE(review): in recent Transformers releases the visualizer's
        # __call__ appears to *print* an ANSI-coloured grid and return None,
        # which `str(out)` turned into the literal text "None". Capture stdout
        # so the grid can be shown. redirect_stdout swaps sys.stdout
        # process-wide for the duration of the call.
        with contextlib.redirect_stdout(captured):
            out = vis(prompt)
        if out is not None:
            # Prefer a rich HTML repr when the object offers one.
            repr_html = getattr(out, "_repr_html_", None)
            rendered = repr_html() if callable(repr_html) else None
            return rendered if rendered is not None else str(out)
        # Drop ANSI escape sequences, then show the text grid verbatim.
        text = re.sub(r"\x1b\[[0-?]*[ -/]*[@-~]", "", captured.getvalue())
        return f"<pre>{html.escape(text)}</pre>"
    except Exception as e:
        return f"<p>Attention visualizer error: {html.escape(str(e))}</p>"
def build_attn_vis():
    """Add the attention-mask widget: model picker + prompt, a Render button, HTML output.

    Clicking Render calls ``render_attention_mask`` with the selected model and
    prompt and displays the returned HTML.
    """
    model_choices = ["openai-community/gpt2", "google/gemma-2-2b"]
    with gr.Group():
        with gr.Row():
            model_dd = gr.Dropdown(
                label="Model",
                choices=model_choices,
                value=model_choices[0],
                allow_custom_value=True,
            )
            prompt_box = gr.Textbox(
                label="Prompt",
                value="You are an assistant. Make sure you print me.",
            )
        render_btn = gr.Button("Render")
        html_out = gr.HTML()
        render_btn.click(render_attention_mask, inputs=[model_dd, prompt_box], outputs=html_out)
# -------------------------------------------------------
# Transformers caching allocator warmup (time vs MiB plot)
# -------------------------------------------------------
from transformers import AutoModelForCausalLM, modeling_utils as MU # noqa: E402
def _measure_load_timeline(model_id: str, disable_warmup: bool):
    """Load *model_id* once while a background thread samples CUDA memory.

    Args:
        model_id: Hub id passed to ``AutoModelForCausalLM.from_pretrained``.
        disable_warmup: if True, ``modeling_utils.caching_allocator_warmup`` is
            replaced by a no-op for this load and restored afterwards.

    Returns:
        list of ``{"t": seconds_since_start, "MiB": memory_allocated_MiB}``
        dicts. Samples are taken roughly every 50 ms, plus one final sample after
        loading when CUDA is available. Without CUDA every sample reports 0 MiB.
    """
    orig = getattr(MU, "caching_allocator_warmup", None)
    # NOTE(review): if this Transformers version has no caching_allocator_warmup,
    # disable_warmup is a no-op and the ON/OFF runs are equivalent.
    try:
        if disable_warmup and orig is not None:
            MU.caching_allocator_warmup = lambda *a, **k: None  # type: ignore[attr-defined]
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tl = []

        def sample(start_t, stop_evt):
            # Poll allocated memory until the main thread signals the load is done.
            while not stop_evt.is_set():
                if device == "cuda":
                    torch.cuda.synchronize()
                    alloc = torch.cuda.memory_allocated()
                else:
                    alloc = 0
                tl.append({"t": time.perf_counter() - start_t, "MiB": alloc / (1024**2)})
                time.sleep(0.05)

        if device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        start = time.perf_counter()
        stop_evt = threading.Event()
        th = threading.Thread(target=sample, args=(start, stop_evt), daemon=True)
        th.start()

        kwargs = {}
        if device == "cuda":
            kwargs.update(dict(torch_dtype=torch.float16, device_map="cuda:0", low_cpu_mem_usage=True))
        try:
            model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        finally:
            # Always stop the sampler, even if loading fails; otherwise the thread
            # would keep polling (and synchronizing) the device indefinitely.
            stop_evt.set()
            th.join()

        if device == "cuda":
            torch.cuda.synchronize()
            tl.append({"t": time.perf_counter() - start, "MiB": torch.cuda.memory_allocated() / (1024**2)})
        del model
        if device == "cuda":
            # Release the weights so the next measurement starts from a clean cache.
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        return tl
    finally:
        if orig is not None:
            MU.caching_allocator_warmup = orig  # restore
@spaces.GPU(duration=240)
def profile_warmup(model_id: str):
    """Profile two loads of *model_id*: allocator warmup enabled, then disabled.

    Returns:
        pandas DataFrame with columns ``t`` (seconds), ``MiB`` and ``mode``
        ("warmup ON" / "warmup OFF"), as consumed by the ``gr.LinePlot``
        (``color="mode"``). The columns are fixed even if no samples were collected.
    """
    # NOTE(review): the ON run goes first, so on a cold cache it presumably also
    # includes downloading the weights from the Hub; pre-fetch the checkpoint if
    # the two curves must be directly comparable.
    on = _measure_load_timeline(model_id, disable_warmup=False)
    off = _measure_load_timeline(model_id, disable_warmup=True)
    rows = [{"t": r["t"], "MiB": r["MiB"], "mode": "warmup ON"} for r in on]
    rows += [{"t": r["t"], "MiB": r["MiB"], "mode": "warmup OFF"} for r in off]
    # Explicit columns keep the schema stable when both timelines are empty.
    return pd.DataFrame(rows, columns=["t", "MiB", "mode"])
def build_alloc_plot():
    """Add the allocator-warmup profiler: model picker, Run button and a line plot.

    Clicking Run calls ``profile_warmup`` and plots time vs MiB, with one line per
    ``mode`` value ("warmup ON" / "warmup OFF").
    """
    model_choices = ["openai-community/gpt2", "google/gemma-2-2b"]
    with gr.Group():
        model_dd = gr.Dropdown(
            label="Model",
            choices=model_choices,
            value=model_choices[0],
            allow_custom_value=True,
        )
        run_btn = gr.Button("Run")
        timeline_plot = gr.LinePlot(
            x="t",
            y="MiB",
            color="mode",
            overlay_point=True,
            title="from_pretrained() load: time vs CUDA memory_allocated()",
            tooltip=["t", "MiB", "mode"],
            width=900,
            height=420,
        )
        run_btn.click(profile_warmup, inputs=[model_dd], outputs=timeline_plot)
# ---------------------------
# Optional FastRTC preview
# ---------------------------
try:
from fastrtc import WebRTC, ReplyOnPause # type: ignore
def _echo_video(frame):
yield frame
HAS_FASTRTC = True
except Exception:
HAS_FASTRTC = False
def build_fastrtc():
if not HAS_FASTRTC:
gr.Markdown("Install `fastrtc` to enable this section.")
return
with gr.Group():
gr.Markdown("Camera loopback using FastRTC WebRTC. Extend with streaming handlers later.")
rtc = WebRTC(mode="send-receive", modality="video")
rtc.stream(ReplyOnPause(_echo_video), inputs=[rtc], outputs=[rtc], time_limit=60)
# ---------------------------
# Inserts registry
# ---------------------------
# Token name used in `{{TOKEN}}` article markers -> zero-argument widget builder.
INSERTS = dict(
    TERMINAL=build_terminal,
    ATTN_VIS=build_attn_vis,
    ALLOC_PLOT=build_alloc_plot,
)
# ---------------------------
# Layout / CSS / App
# ---------------------------
# Page CSS passed to gr.Blocks(css=...):
#   #layout  - two-column grid (TOC width from --toc-w, article takes the rest)
#   #toc     - sticky, independently scrollable table of contents
#   .section - scroll-margin so anchor jumps don't hide headings under the top edge
CSS = """
:root { --toc-w: 280px; }
#layout { display: grid; grid-template-columns: var(--toc-w) 1fr; gap: 1.25rem; }
#toc { position: sticky; top: 0.75rem; height: calc(100vh - 1.5rem); overflow: auto; padding-right: .5rem; }
#toc a { text-decoration: none; display: block; padding: .25rem 0; }
.section { scroll-margin-top: 72px; }
.gradio-container { max-width: 1200px !important; margin: 0 auto; }
hr { border: none; border-top: 1px solid var(--neutral-300); margin: 1.25rem 0; }
"""
# Page layout: a title, then a row holding a table-of-contents column and a
# content column (the article with its embedded widgets, then the FastRTC preview).
# Gradio shows components in creation order within each context, so the
# statement order below is the on-page order.
with gr.Blocks(css=CSS, fill_height=True, title="Interactive Blog — Transformers Feature Showcase") as demo:
    gr.HTML("<h1>Transformers Feature Showcase</h1><p>Interactive, scrollable demo.</p>")
    # elem_id="layout" picks up the two-column grid rule from CSS.
    with gr.Row(elem_id="layout"):
        with gr.Column(scale=0):
            # TOC links target the id= attributes of the section headings below.
            gr.HTML(
                """
<nav id="toc">
<h3>Sections</h3>
<a href="#article">Article</a>
<a href="#rtc">FastRTC (preview)</a>
</nav>
"""
            )
        with gr.Column():
            gr.HTML('<h2 id="article" class="section">Article</h2>')
            # Author in Obsidian. Put {{ALLOC_PLOT}}, {{ATTN_VIS}}, {{TERMINAL}} where you want widgets.
            render_article("content/article.md", INSERTS)
            gr.HTML("<hr/>")
            gr.HTML('<h2 id="rtc" class="section">FastRTC (preview)</h2>')
            build_fastrtc()

# Start the Gradio server when the file is run directly.
if __name__ == "__main__":
    demo.launch()