Spaces:
Runtime error
Runtime error
Commit
·
2a3864d
1
Parent(s):
729a063
Update app.py
Browse files
app.py
CHANGED
|
@@ -53,7 +53,7 @@ st.caption("Multi-Query Attention")
|
|
| 53 |
mqa_flop = 2*bs*1*d*(1+2/h)*d
|
| 54 |
mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
|
| 55 |
c1, c2 = st.columns([2, 3])
|
| 56 |
-
|
| 57 |
|
| 58 |
st.subheader('QK gemm')
|
| 59 |
st.write("Note that calculation depends on sequence length (n)")
|
|
@@ -101,7 +101,7 @@ softmax_time = print_kernel_execution(c1, c2, 0, softmax_bytes)
|
|
| 101 |
st.caption("Layer norm/residual connection")
|
| 102 |
ln_bytes = 2*bs*1*d
|
| 103 |
ln_flop = 0
|
| 104 |
-
ln_time = print_kernel_execution(c1, c2, 0,
|
| 105 |
|
| 106 |
st.header('MLP')
|
| 107 |
st.subheader('First Linear')
|
|
@@ -120,7 +120,7 @@ st.subheader('Element-wise ops')
|
|
| 120 |
st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
|
| 121 |
ln_bytes = 2*bs*1*d
|
| 122 |
ln_flop = 0
|
| 123 |
-
ln_time = print_kernel_execution(c1, c2, 0,
|
| 124 |
|
| 125 |
st.header("Adding it all up")
|
| 126 |
|
|
|
|
| 53 |
mqa_flop = 2*bs*1*d*(1+2/h)*d
|
| 54 |
mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
|
| 55 |
c1, c2 = st.columns([2, 3])
|
| 56 |
+
qkv_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 57 |
|
| 58 |
st.subheader('QK gemm')
|
| 59 |
st.write("Note that calculation depends on sequence length (n)")
|
|
|
|
| 101 |
st.caption("Layer norm/residual connection")
|
| 102 |
ln_bytes = 2*bs*1*d
|
| 103 |
ln_flop = 0
|
| 104 |
+
ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
|
| 105 |
|
| 106 |
st.header('MLP')
|
| 107 |
st.subheader('First Linear')
|
|
|
|
| 120 |
st.write("We also need to take into the GeLU, layer norm, and residual connection. We assume that these operations are memory bound. ")
|
| 121 |
ln_bytes = 2*bs*1*d
|
| 122 |
ln_flop = 0
|
| 123 |
+
ln_time = print_kernel_execution(c1, c2, 0, ln_bytes)
|
| 124 |
|
| 125 |
st.header("Adding it all up")
|
| 126 |
|