Spaces:
Running
Running
File size: 7,911 Bytes
c17b112 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
import math
import cutlass.cute as cute
import cutlass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ast
def _recursive_size(shape) -> int:
    """Return the element count of a (possibly nested) CuTe shape.

    An int is its own size; a tuple's size is the product of its members'
    sizes, e.g. ``((2, 2), 4)`` -> 16.
    """
    if isinstance(shape, int):
        return shape
    return math.prod(_recursive_size(s) for s in shape)


def _query_tv_positions(shape, stride, n_thr, n_val):
    """Evaluate the CuTe layout ``shape:stride`` at every ``(tid, vid)``.

    Returns a nested list ``table[tid][vid]`` holding the linear offset the
    layout maps each thread/value coordinate to.
    """

    @cute.jit
    def evaluate():
        layout = cute.make_layout(shape, stride=stride)
        table = []
        # range_constexpr: the loop bounds are Python ints captured from the
        # enclosing scope.
        for tid in cutlass.range_constexpr(n_thr):
            row = []
            for vid in cutlass.range_constexpr(n_val):
                row.append(layout((tid, vid)))
            table.append(row)
        return table

    return evaluate()


def _default_color_fn():
    """Return a ``(tid, vid) -> colour`` function that cycles a palette by thread id."""
    palette = [*plt.cm.Set3.colors, *plt.cm.Set2.colors, *plt.cm.Set1.colors]
    return lambda tid, vid: palette[tid % len(palette)]


def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout,  # (((thr_shape),(val_shape)),
                #  ((thr_stride),(val_stride)))
    *,
    font_size: int = 10,
    cell_px: int = 70,
    grid_lw: float = 1.5,
    dpi: int = 100,
    max_rows: "int | None" = None,
    max_cols: "int | None" = None,
    color_fn=None,  # optional (tid,vid) -> colour
):
    """Draw a T/V checkerboard for an arbitrary TV layout.

    Each cell ``(m, n)`` of the tile is coloured by the thread that owns it
    and labelled ``T<tid>`` / ``V<vid>``. Cells that no ``(tid, vid)`` maps
    to stay black.

    Args:
        tiler_mn: Full tile extent ``(M, N)``. Layout offsets are decoded
            column-major into this tile: ``m = pos % M``, ``n = pos // M``.
        tv_layout: ``(shape, stride)`` pair. ``shape[0]``/``stride[0]``
            describe the thread mode and ``shape[1]``/``stride[1]`` the
            value mode; either may be nested.
        font_size: Font size (points) of the cell labels; the title uses +2.
        cell_px: Cell edge length in pixels at 100 dpi. The figure size does
            not depend on ``dpi``, so ``dpi`` only changes the resolution.
        grid_lw: Line width of the cell grid.
        dpi: Figure resolution.
        max_rows, max_cols: When positive, only the top-left
            ``max_rows x max_cols`` corner of the tile is drawn.
        color_fn: Optional ``(tid, vid) -> colour`` callable returning
            anything ``matplotlib.colors.to_rgb`` accepts. Defaults to a
            palette cycled by thread id.

    Returns:
        The matplotlib Figure. It is created through ``pyplot``, so the caller
        is responsible for closing it when done.
    """
    shape, stride = tv_layout
    n_thr = _recursive_size(shape[0])
    n_val = _recursive_size(shape[1])

    full_M, full_N = tiler_mn
    M, N = full_M, full_N
    # int(): UI number widgets may deliver floats, which np.full rejects.
    if max_rows is not None and max_rows > 0:
        M = min(M, int(max_rows))
    if max_cols is not None and max_cols > 0:
        N = min(N, int(max_cols))

    # -1 marks a cell that no (tid, vid) maps to.
    thr_ids = np.full((M, N), -1, dtype=int)
    val_ids = np.full((M, N), -1, dtype=int)

    positions = _query_tv_positions(shape, stride, n_thr, n_val)
    for tid in range(n_thr):
        for vid in range(n_val):
            pos = positions[tid][vid]
            n = pos // full_M
            m = pos % full_M
            # Skip cells outside the display window; when several (tid, vid)
            # land on the same cell (e.g. stride 0) keep the first one.
            if m >= M or n >= N or thr_ids[m, n] >= 0:
                continue
            thr_ids[m, n] = tid
            val_ids[m, n] = vid

    if color_fn is None:
        color_fn = _default_color_fn()

    owned = list(zip(*np.nonzero(thr_ids >= 0)))  # row-major (m, n) pairs
    bg_rgb = np.zeros((M, N, 3))
    for m, n in owned:
        bg_rgb[m, n] = mcolors.to_rgb(color_fn(thr_ids[m, n], val_ids[m, n]))

    # Fixed 100 px/inch keeps fonts (in points) proportional to the cells at any dpi.
    fig, ax = plt.subplots(figsize=(N * cell_px / 100, M * cell_px / 100), dpi=dpi)
    ax.imshow(bg_rgb, interpolation="none")
    for m, n in owned:
        ax.text(
            n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}",
            ha="center", va="center",
            fontsize=font_size, weight="bold",
        )

    # Row/column indices label the cell centres (major ticks). The grid is
    # drawn on the cell borders (minor ticks), so label i sits on cell i.
    ax.set_xticks(np.arange(N))
    ax.set_yticks(np.arange(M))
    ax.set_xticklabels([str(i) for i in range(N)])
    ax.set_yticklabels([str(i) for i in range(M)])
    ax.set_xticks(np.arange(N + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(M + 1) - 0.5, minor=True)
    ax.tick_params(axis="both", which="major", length=6, width=1)
    ax.tick_params(axis="both", which="minor", length=0)
    ax.tick_params(axis="x", which="both", top=True, bottom=False, labeltop=True, labelbottom=False)
    ax.tick_params(axis="y", which="both", left=True, right=False)
    ax.grid(which="minor", color="black", linewidth=grid_lw)
    ax.set_xlim(-0.5, N - 0.5)
    ax.set_ylim(M - 0.5, -0.5)

    # Title uses CuTe's colon notation: shape:stride.
    ax.set_title(f"tv_layout = {shape}:{stride}", fontsize=font_size + 2, pad=12)
    # Lay out this figure explicitly, not pyplot's global "current" figure.
    fig.tight_layout()
    return fig
# Example (standalone use, without the Gradio UI):
#   visualize_tv_layout((32, 16), (((4,8,2,2),((2,2),(1,1))),((64,1,16,256),((32,8),(0,0)))), dpi=100, max_rows=16, max_cols=16)
def gradio_visualize(tiler_mn_str, tv_layout_str, dpi, max_rows, max_cols):
    """Gradio callback: parse the text inputs and render the TV layout.

    ``tiler_mn_str`` is a Python literal such as ``"(8, 8)"``.
    ``tv_layout_str`` accepts either colon notation ``shape:stride`` (e.g.
    ``"((4, 2), 4):((1, 16), 4)"``) or one nested tuple ``"(shape, stride)"``.
    Both are parsed with ``ast.literal_eval``, so no code is executed.

    Any failure is drawn as a red error message rather than raised, so the UI
    always receives a figure. Figures are detached from pyplot's global
    registry before being returned. Gradio only needs the Figure object, and
    this keeps a long-running server from piling up open pyplot figures.
    """
    import textwrap
    from matplotlib.figure import Figure

    try:
        tiler_mn = ast.literal_eval(tiler_mn_str)
        if ':' in tv_layout_str:
            # Colon notation: exactly one ':' separating shape and stride.
            parts = tv_layout_str.split(':')
            if len(parts) != 2:
                raise ValueError("Colon format must be shape:stride")
            tv_layout = (ast.literal_eval(parts[0]), ast.literal_eval(parts[1]))
        else:
            # Traditional nested tuple format: (shape, stride).
            tv_layout = ast.literal_eval(tv_layout_str)
        fig = visualize_tv_layout(tiler_mn, tv_layout, dpi=dpi, max_rows=max_rows, max_cols=max_cols)
    except Exception as e:  # UI boundary: show any failure in the plot area
        # Built without pyplot, so nothing is left registered globally.
        fig = Figure(figsize=(8, 4))
        ax = fig.subplots()
        # Wrap long messages so they are not clipped at the figure edge.
        ax.text(0.5, 0.5, textwrap.fill(f"Error: {str(e)}", width=80),
                ha='center', va='center', fontsize=12, color='red')
        ax.axis('off')
        return fig
    # Unregister from pyplot; the Figure object stays valid for saving.
    plt.close(fig)
    return fig
# Create Gradio interface.
# The creation order and nesting of the components below define the on-screen
# layout: a row with an input column on the left and the plot on the right.
with gr.Blocks(title="CuTe TV Layout Visualizer") as demo:
    gr.Markdown("# CuTe TV Layout Visualizer")
    gr.Markdown("Visualize thread/value (T/V) layouts for CuTe tensor operations.")
    with gr.Row():
        # Left column: layout inputs and the trigger button.
        with gr.Column():
            gr.Markdown("### Layout Parameters")
            # Parsed by gradio_visualize with ast.literal_eval as an (M, N) tuple.
            tiler_mn = gr.Textbox(
                label="Tiler Dimensions (M, N)",
                value="(8, 8)",
                placeholder="(8, 8)"
            )
            # gradio_visualize accepts "shape:stride" or a nested "(shape, stride)" tuple.
            tv_layout = gr.Textbox(
                label="TV Layout",
                value="((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))",
                lines=2
            )
            dpi = gr.Number(label="DPI", value=200, precision=0)
            # visualize_tv_layout applies no limit when these are None or <= 0.
            max_rows = gr.Number(label="Max Rows (leave empty for no limit)", value=None, precision=0)
            max_cols = gr.Number(label="Max Cols (leave empty for no limit)", value=None, precision=0)
            visualize_btn = gr.Button("Visualize", variant="primary")
        # Right column: the rendered matplotlib figure.
        with gr.Column():
            output_plot = gr.Plot(label="TV Layout Visualization")
    # The inputs list is passed positionally; its order must match
    # gradio_visualize(tiler_mn_str, tv_layout_str, dpi, max_rows, max_cols).
    visualize_btn.click(
        fn=gradio_visualize,
        inputs=[tiler_mn, tv_layout, dpi, max_rows, max_cols],
        outputs=output_plot
    )
    # Add examples: each row gives values for the same five inputs, in the same order.
    gr.Examples(
        examples=[
            ["(8, 8)", "((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))", 200, None, None],
            ["(4, 8)", "((4, 2), 4):((1, 16), 4)", 200, None, None],
            ["(8, 4)", "((4, 2), 4):((8, 4), 1)", 200, None, None],
        ],
        inputs=[tiler_mn, tv_layout, dpi, max_rows, max_cols],
    )

if __name__ == "__main__":
    # Launch the server only when run as a script, not when imported.
    demo.launch()