Horace He commited on
Commit
c17b112
·
0 Parent(s):

added visualizer gradio space

Browse files
Files changed (1) hide show
  1. layout_tv_viz.py +218 -0
layout_tv_viz.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import cutlass.cute as cute
3
+ import cutlass
4
+ import gradio as gr
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.colors as mcolors
8
+ import ast
9
+
10
+ def visualize_tv_layout(
11
+ tiler_mn: tuple[int, int],
12
+ tv_layout, # (((thr_shape),(val_shape)),
13
+ # ((thr_stride),(val_stride)))
14
+ *,
15
+ font_size: int = 10,
16
+ cell_px: int = 70,
17
+ grid_lw: float = 1.5,
18
+ dpi: int = 100,
19
+ max_rows: int = None,
20
+ max_cols: int = None,
21
+ color_fn=None, # optional (tid,vid) -> colour
22
+ ):
23
+ """Draw a T/V checkerboard for an arbitrary TV layout."""
24
+
25
+ # -----------------------------------------------------------------
26
+ # 1) Build a real CuTe layout from the tuple the user passed
27
+ # -----------------------------------------------------------------
28
+ shape, stride = tv_layout
29
+
30
+ def compute_recursive_size(shape):
31
+ if isinstance(shape, int):
32
+ return shape
33
+ else:
34
+ return math.prod(compute_recursive_size(i) for i in shape)
35
+
36
+ n_thr = compute_recursive_size(shape[0])
37
+ n_val = compute_recursive_size(shape[1])
38
+ M, N = tiler_mn
39
+
40
+ # Apply max rows/cols limits if specified
41
+ if max_rows is not None and max_rows > 0:
42
+ M = min(M, max_rows)
43
+ if max_cols is not None and max_cols > 0:
44
+ N = min(N, max_cols)
45
+
46
+ thr_ids = np.full((M, N), -1, dtype=int)
47
+ val_ids = np.full((M, N), -1, dtype=int)
48
+ filled = np.zeros((M, N), dtype=bool)
49
+
50
+ # -----------------------------------------------------------------
51
+ # 2) Query CuTe for every (tid, vid) → (m,n)
52
+ # -----------------------------------------------------------------
53
+
54
+ @cute.jit
55
+ def g():
56
+ tv_layout = cute.make_layout(shape, stride=stride)
57
+ tid_vals = []
58
+ for tid in cutlass.range_constexpr(n_thr):
59
+ vid_vals = []
60
+ for vid in cutlass.range_constexpr(n_val):
61
+ vid_vals.append(tv_layout((tid, vid)))
62
+ tid_vals.append(vid_vals)
63
+ return tid_vals
64
+ vals = g()
65
+ full_M, full_N = tiler_mn
66
+ for tid in range(n_thr):
67
+ for vid in range(n_val):
68
+ pos = vals[tid][vid]
69
+ n = pos // full_M
70
+ m = pos % full_M
71
+ # Skip if outside the display limits
72
+ if m >= M or n >= N:
73
+ continue
74
+ if filled[m, n]:
75
+ continue
76
+ thr_ids[m, n] = tid
77
+ val_ids[m, n] = vid
78
+ filled[m, n] = True
79
+
80
+ # -----------------------------------------------------------------
81
+ # 3) Colours (default: pastel per-thread)
82
+ # -----------------------------------------------------------------
83
+ if color_fn is None:
84
+ # pastel = list(plt.cm.Set3.colors) # + plt.cm.Set2.colors + plt.cm.Set1.colors)
85
+ color_palettes = [
86
+ plt.cm.Set3.colors,
87
+ plt.cm.Set2.colors,
88
+ plt.cm.Set1.colors,
89
+ # plt.cm.Pastel1.colors,
90
+ # plt.cm.Pastel2.colors,
91
+ ]
92
+ color_palettes = [j for i in color_palettes for j in i]
93
+ # breakpoint()
94
+ # cmap = []
95
+ # for i in range(n_thr):
96
+ # cmap += [[k * ((n_thr) - i)/ n_thr for k in j] for j in color_palettes]
97
+ cmap = (color_palettes * n_thr)[:n_thr]
98
+ # cmap = (pastel * n_thr)[:n_thr]
99
+ color_fn = lambda t, v: cmap[t % len(cmap)]
100
+
101
+ bg_rgb = np.zeros((M, N, 3))
102
+ for m in range(M):
103
+ for n in range(N):
104
+ tid = thr_ids[m, n]
105
+ if tid >= 0:
106
+ bg_rgb[m, n] = mcolors.to_rgb(color_fn(tid, val_ids[m, n]))
107
+
108
+ # -----------------------------------------------------------------
109
+ # 4) Draw
110
+ # -----------------------------------------------------------------
111
+ fig_w, fig_h = N * cell_px / 100, M * cell_px / 100
112
+ fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=dpi)
113
+ ax.imshow(bg_rgb, interpolation="none")
114
+
115
+ for m in range(M):
116
+ for n in range(N):
117
+ if thr_ids[m, n] >= 0:
118
+ ax.text(
119
+ n, m, f"T{thr_ids[m,n]}\nV{val_ids[m,n]}",
120
+ ha="center", va="center",
121
+ fontsize=font_size, weight="bold"
122
+ )
123
+
124
+ ax.set_xticks(np.arange(N + 1) - 0.5)
125
+ ax.set_yticks(np.arange(M + 1) - 0.5)
126
+ ax.set_xticklabels([str(i) for i in range(N + 1)])
127
+ ax.set_yticklabels([str(i) for i in range(M + 1)])
128
+ ax.tick_params(axis='both', which='both', length=6, width=1)
129
+ ax.tick_params(axis='x', which='both', top=True, bottom=False, labeltop=True, labelbottom=False)
130
+ ax.tick_params(axis='y', which='both', left=True, right=False)
131
+ ax.grid(which="major", color="black", linewidth=grid_lw)
132
+ ax.set_xlim(-.5, N -.5); ax.set_ylim(M -.5, -.5)
133
+
134
+ # Format title with colon notation
135
+ ax.set_title(f"tv_layout = {shape}:{stride}", fontsize=font_size + 2, pad=12)
136
+ plt.tight_layout()
137
+
138
+ return fig
139
+
140
+ # visualize_tv_layout((32, 16), (((4,8,2,2),((2,2),(1,1))),((64,1,16,256),((32,8),(0,0)))), dpi=100, max_rows=16, max_cols=16)
141
+ # exit(0)
142
+
143
+ def gradio_visualize(tiler_mn_str, tv_layout_str, dpi, max_rows, max_cols):
144
+ """Gradio wrapper for visualize_tv_layout."""
145
+ try:
146
+ # Parse input strings
147
+ tiler_mn = ast.literal_eval(tiler_mn_str)
148
+
149
+ # Support colon notation: (128,64):(1,128) or comma notation
150
+ if ':' in tv_layout_str:
151
+ # Split by colon to get shape and stride parts
152
+ parts = tv_layout_str.split(':')
153
+ if len(parts) != 2:
154
+ raise ValueError("Colon format must be shape:stride")
155
+ shape = ast.literal_eval(parts[0])
156
+ stride = ast.literal_eval(parts[1])
157
+ tv_layout = (shape, stride)
158
+ else:
159
+ # Traditional nested tuple format
160
+ tv_layout = ast.literal_eval(tv_layout_str)
161
+
162
+ fig = visualize_tv_layout(tiler_mn, tv_layout, dpi=dpi, max_rows=max_rows, max_cols=max_cols)
163
+ return fig
164
+ except Exception as e:
165
+ # Return error message
166
+ fig, ax = plt.subplots(figsize=(8, 4))
167
+ ax.text(0.5, 0.5, f"Error: {str(e)}",
168
+ ha='center', va='center', fontsize=12, color='red')
169
+ ax.axis('off')
170
+ return fig
171
+
172
+
173
+ # Create Gradio interface
174
+ with gr.Blocks(title="CuTe TV Layout Visualizer") as demo:
175
+ gr.Markdown("# CuTe TV Layout Visualizer")
176
+ gr.Markdown("Visualize thread/value (T/V) layouts for CuTe tensor operations.")
177
+
178
+ with gr.Row():
179
+ with gr.Column():
180
+ gr.Markdown("### Layout Parameters")
181
+ tiler_mn = gr.Textbox(
182
+ label="Tiler Dimensions (M, N)",
183
+ value="(8, 8)",
184
+ placeholder="(8, 8)"
185
+ )
186
+ tv_layout = gr.Textbox(
187
+ label="TV Layout",
188
+ value="((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))",
189
+ lines=2
190
+ )
191
+
192
+ dpi = gr.Number(label="DPI", value=200, precision=0)
193
+ max_rows = gr.Number(label="Max Rows (leave empty for no limit)", value=None, precision=0)
194
+ max_cols = gr.Number(label="Max Cols (leave empty for no limit)", value=None, precision=0)
195
+
196
+ visualize_btn = gr.Button("Visualize", variant="primary")
197
+
198
+ with gr.Column():
199
+ output_plot = gr.Plot(label="TV Layout Visualization")
200
+
201
+ visualize_btn.click(
202
+ fn=gradio_visualize,
203
+ inputs=[tiler_mn, tv_layout, dpi, max_rows, max_cols],
204
+ outputs=output_plot
205
+ )
206
+
207
+ # Add examples
208
+ gr.Examples(
209
+ examples=[
210
+ ["(8, 8)", "((2, 2, 2), (2, 2, 2)):((1, 16, 4), (8, 2, 32))", 200, None, None],
211
+ ["(4, 8)", "((4, 2), 4):((1, 16), 4)", 200, None, None],
212
+ ["(8, 4)", "((4, 2), 4):((8, 4), 1)", 200, None, None],
213
+ ],
214
+ inputs=[tiler_mn, tv_layout, dpi, max_rows, max_cols],
215
+ )
216
+
217
+ if __name__ == "__main__":
218
+ demo.launch()