Spaces · Commit c88286f (parent 064a5f0) · Update app.py

app.py CHANGED
@@ -116,7 +116,7 @@ def print_kernel_execution(flop, nbytes):
     c2.write(str(THREAD_OVERHEAD))
 
 st.title("Inference time MHA vs MQA")
-st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention
+st.write("This space approximates the inference time for Multi-Query Attention and Multi-Head Attention transformers. You can change the hyperparameters in the sidebar.")
 
 mqa_total_time = 0.
 mha_total_time = 0.
@@ -187,63 +187,62 @@ st.latex("max(T_{math}, T_{mem})")
 
 st.markdown("We also have a minimum time for executing the operation due to [kernel launch overhead](https://forums.developer.nvidia.com/t/any-way-to-measure-the-latency-of-a-kernel-launch/221413/2)")
 
-st.subheader("
-flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
-print_kernel_execution(flop, nbytes)
[remaining removed lines of the previous revision are rendered blank in the source diff view]
+st.subheader("Inference time for Transformer operations")
+st.text("We can now estimate the execution time for each of the operations in the transformer model. I suggest you inspect the code for details on the calculations.")
+
+st.subheader('Attention layer')
+
+st.markdown('**QKV projection**')
+st.caption("Multi-Head Attention")
+flop, nbytes, exec_time = qkv_mha_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.caption("Multi-Query Attention")
+flop, nbytes, exec_time = qkv_mqa_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**QK gemm**')
+st.write("Showing calculation for the maximum sequence length (n)")
+
+st.caption("Multi-Head Attention")
+flop, nbytes, exec_time = att1_mha_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.caption("Multi-Query Attention")
+flop, nbytes, exec_time = att1_mqa_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Attention-value gemm**')
+st.write("Showing calculation for the maximum sequence length (n)")
+st.caption("Multi-Head Attention")
+flop, nbytes, exec_time = att2_mha_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.caption("Multi-Query Attention")
+flop, nbytes, exec_time = att2_mqa_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Output projection**')
+flop, nbytes, exec_time = out_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Element-wise ops**')
+st.write("We also need to take into account the softmax, layer norm, and residual connection. We assume that these operations are memory bound.")
+
+st.caption("Softmax")
+flop, nbytes, exec_time = softmax_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.caption("Layer norm/residual connection")
+flop, nbytes, exec_time = ln_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.subheader('MLP layer')
+st.markdown('**First and second linear layers**')
+flop, nbytes, exec_time = mlp_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)
+
+st.markdown('**Element-wise ops**')
+st.write("We also need to take into account the GeLU, layer norm, and residual connection. We assume that these operations are memory bound.")
+flop, nbytes, exec_time = ln_exec(bs, h, n, d)
+print_kernel_execution(flop, nbytes)