import ast
import math
from typing import Optional

import cutlass.cute as cute
import cutlass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors


def _recursive_size(shape) -> int:
    """Return the element count of a (possibly nested) shape.

    An ``int`` is its own size; a tuple's size is the product of the sizes of
    its members, recursively.
    """
    if isinstance(shape, int):
        return shape
    return math.prod(_recursive_size(s) for s in shape)


def visualize_tv_layout(
    tiler_mn: tuple[int, int],
    tv_layout,  # (((thr_shape),(val_shape)),
                #  ((thr_stride),(val_stride)))
    *,
    font_size: int = 10,
    cell_px: int = 70,
    grid_lw: float = 1.5,
    dpi: int = 100,
    max_rows: Optional[int] = None,
    max_cols: Optional[int] = None,
    color_fn=None,  # optional (tid, vid) -> colour
):
    """Draw a T/V checkerboard for an arbitrary TV layout.

    A CuTe layout is built from ``tv_layout`` and queried for every
    (thread id, value id) pair. Each resulting linear offset is decoded
    column-major over the full ``tiler_mn`` tile into a cell (m, n). That cell
    is coloured per thread and labelled ``T<tid>`` / ``V<vid>``. When several
    (tid, vid) pairs land on the same cell, the first one visited wins. Cells
    that no pair reaches stay black and unlabelled.

    Args:
        tiler_mn: (M, N) extent of the tile being visualised.
        tv_layout: ``(shape, stride)`` pair. ``shape[0]`` is the (possibly
            nested) thread mode and ``shape[1]`` the value mode.
        font_size: Point size of the cell labels; the title uses
            ``font_size + 2``.
        cell_px: Cell edge length in pixels at 100 dpi.
        grid_lw: Line width of the cell grid.
        dpi: Figure resolution.
        max_rows: If positive, draw only the first ``max_rows`` rows.
            ``None`` or a non-positive value means no limit.
        max_cols: Same as ``max_rows``, for columns.
        color_fn: Optional ``(tid, vid) -> colour`` callable. By default,
            threads cycle through matplotlib's Set3, Set2 and Set1 colours.

    Returns:
        The pyplot ``Figure`` containing the drawing.

    Example:
        visualize_tv_layout(
            (32, 16),
            (((4, 8, 2, 2), ((2, 2), (1, 1))),
             ((64, 1, 16, 256), ((32, 8), (0, 0)))),
            dpi=100, max_rows=16, max_cols=16,
        )
    """
    # -----------------------------------------------------------------
    # 1) Sizes of the thread / value modes and the (possibly cropped) view
    # -----------------------------------------------------------------
    shape, stride = tv_layout
    n_thr = _recursive_size(shape[0])
    n_val = _recursive_size(shape[1])

    full_M, full_N = tiler_mn
    M, N = full_M, full_N
    if max_rows is not None and max_rows > 0:
        M = min(M, max_rows)
    if max_cols is not None and max_cols > 0:
        N = min(N, max_cols)

    # -----------------------------------------------------------------
    # 2) Query CuTe for every (tid, vid) -> linear offset
    # -----------------------------------------------------------------
    @cute.jit
    def g():
        tv_layout = cute.make_layout(shape, stride=stride)
        tid_vals = []
        for tid in cutlass.range_constexpr(n_thr):
            vid_vals = []
            for vid in cutlass.range_constexpr(n_val):
                vid_vals.append(tv_layout((tid, vid)))
            tid_vals.append(vid_vals)
        return tid_vals

    vals = g()

    # -1 marks a cell that no (tid, vid) pair maps to.
    thr_ids = np.full((M, N), -1, dtype=int)
    val_ids = np.full((M, N), -1, dtype=int)
    for tid in range(n_thr):
        for vid in range(n_val):
            pos = vals[tid][vid]
            # Column-major decode over the FULL tile, not the cropped view.
            n = pos // full_M
            m = pos % full_M
            # Skip offsets outside the drawn view. The lower bounds matter for
            # negative strides, where numpy's negative indexing would
            # otherwise wrap to the wrong cell.
            if not (0 <= m < M and 0 <= n < N):
                continue
            if thr_ids[m, n] >= 0:  # keep the first pair mapped to this cell
                continue
            thr_ids[m, n] = tid
            val_ids[m, n] = vid

    # -----------------------------------------------------------------
    # 3) Colours (default: qualitative palette per thread)
    # -----------------------------------------------------------------
    if color_fn is None:
        palette = [
            *plt.cm.Set3.colors,
            *plt.cm.Set2.colors,
            *plt.cm.Set1.colors,
        ]

        def color_fn(t, v):
            return palette[t % len(palette)]

    bg_rgb = np.zeros((M, N, 3))  # unmapped cells stay black
    for m in range(M):
        for n in range(N):
            tid = thr_ids[m, n]
            if tid >= 0:
                bg_rgb[m, n] = mcolors.to_rgb(color_fn(tid, val_ids[m, n]))

    # -----------------------------------------------------------------
    # 4) Draw
    # -----------------------------------------------------------------
    # Size in inches uses a fixed 100 (not `dpi`), so cell_px is "pixels at
    # 100 dpi". Raising dpi therefore scales cells and point-sized text
    # together.
    fig_w, fig_h = N * cell_px / 100, M * cell_px / 100
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=dpi)
    ax.imshow(bg_rgb, interpolation="none")

    for m in range(M):
        for n in range(N):
            if thr_ids[m, n] >= 0:
                ax.text(
                    n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}",
                    ha="center", va="center",
                    fontsize=font_size, weight="bold",
                )

    # Major ticks sit on cell boundaries so the grid outlines every cell.
    ax.set_xticks(np.arange(N + 1) - 0.5)
    ax.set_yticks(np.arange(M + 1) - 0.5)
    ax.set_xticklabels([str(i) for i in range(N + 1)])
    ax.set_yticklabels([str(i) for i in range(M + 1)])
    ax.tick_params(axis='both', which='both', length=6, width=1)
    ax.tick_params(axis='x', which='both', top=True, bottom=False,
                   labeltop=True, labelbottom=False)
    ax.tick_params(axis='y', which='both', left=True, right=False)
    ax.grid(which="major", color="black", linewidth=grid_lw)
    ax.set_xlim(-.5, N - .5)
    ax.set_ylim(M - .5, -.5)  # row 0 at the top

    # Title uses CuTe's shape:stride colon notation.
    ax.set_title(f"tv_layout = {shape}:{stride}", fontsize=font_size + 2, pad=12)
    plt.tight_layout()
    return fig


def _parse_tv_layout(tv_layout_str: str):
    """Parse a TV layout string into a ``(shape, stride)`` value.

    Accepts CuTe colon notation, e.g. ``(128,64):(1,128)``, or a nested tuple
    literal, e.g. ``((128,64),(1,128))``. Surrounding whitespace is ignored.

    Raises:
        ValueError: If the colon form does not have exactly one ``:``, or a
            part is not a valid Python literal.
        SyntaxError: If a part is not parseable as a literal.
    """
    text = tv_layout_str.strip()
    if ':' not in text:
        # Traditional nested tuple format.
        return ast.literal_eval(text)
    parts = text.split(':')
    if len(parts) != 2:
        raise ValueError("Colon format must be shape:stride")
    shape = ast.literal_eval(parts[0].strip())
    stride = ast.literal_eval(parts[1].strip())
    return (shape, stride)


def gradio_visualize(tiler_mn_str, tv_layout_str, dpi, max_rows, max_cols):
    """Gradio wrapper for visualize_tv_layout.

    Parses the textbox inputs and draws the layout. If parsing or drawing
    fails, returns a figure that shows the error message instead of raising.
    """
    try:
        tiler_mn = ast.literal_eval(tiler_mn_str.strip())
        tv_layout = _parse_tv_layout(tv_layout_str)
        fig = visualize_tv_layout(tiler_mn, tv_layout, dpi=dpi,
                                  max_rows=max_rows, max_cols=max_cols)
    except Exception as e:
        # UI boundary: show the problem to the user rather than failing the
        # request.
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha='center', va='center',
                fontsize=12, color='red')
        ax.axis('off')
    # gr.Plot serialises the returned Figure object itself. Closing only
    # removes it from pyplot's global registry, and the Figure can still be
    # saved afterwards. Without this, every request would leave one more
    # open figure behind in the server process.
    plt.close(fig)
    return fig


# Create Gradio interface
with gr.Blocks(title="CuTe TV Layout Visualizer") as demo:
    gr.Markdown("# CuTe TV Layout Visualizer")
    gr.Markdown("Visualize thread/value (T/V) layouts for CuTe tensor operations.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Layout Parameters")
            tiler_mn = gr.Textbox(
                label="Tiler Dimensions (M, N)",
                value="(8, 8)",
                placeholder="(8, 8)",
            )
            tv_layout = gr.Textbox(
                label="TV Layout",
                value="((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))",
                lines=2,
            )
            dpi = gr.Number(label="DPI", value=200, precision=0)
            max_rows = gr.Number(label="Max Rows (leave empty for no limit)",
                                 value=None, precision=0)
            max_cols = gr.Number(label="Max Cols (leave empty for no limit)",
                                 value=None, precision=0)
            visualize_btn = gr.Button("Visualize", variant="primary")

        with gr.Column():
            output_plot = gr.Plot(label="TV Layout Visualization")

    visualize_btn.click(
        fn=gradio_visualize,
        inputs=[tiler_mn, tv_layout, dpi, max_rows, max_cols],
        outputs=output_plot,
    )

    # Add examples
    gr.Examples(
        examples=[
            ["(8, 8)", "((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))", 200, None, None],
            ["(4, 8)", "((4, 2), 4):((1, 16), 4)", 200, None, None],
            ["(8, 4)", "((4, 2), 4):((8, 4), 1)", 200, None, None],
        ],
        inputs=[tiler_mn, tv_layout, dpi, max_rows, max_cols],
    )


if __name__ == "__main__":
    demo.launch()