Spaces:
Runtime error
Runtime error
Commit
·
a275f69
1
Parent(s):
ea57214
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,51 +35,54 @@ h = number_field('Num heads', value=16)
|
|
| 35 |
d = number_field('Dimension', value=768)
|
| 36 |
n_start = number_field('Start seq', value=1)
|
| 37 |
n = number_field('End seq', value=1024)
|
| 38 |
-
l = number_field('Num layers', value=
|
| 39 |
|
| 40 |
-
st.header('
|
| 41 |
|
| 42 |
-
st.subheader(
|
|
|
|
| 43 |
mha_flop = 2*bs*1*d*3*d
|
| 44 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
| 45 |
c1, c2 = st.columns([2, 3])
|
| 46 |
qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 47 |
|
| 48 |
-
st.
|
| 49 |
mqa_flop = 2*bs*1*d*(1+2/h)*d
|
| 50 |
mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
|
| 51 |
c1, c2 = st.columns([2, 3])
|
| 52 |
qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 53 |
|
| 54 |
-
st.
|
| 55 |
-
st.write("
|
| 56 |
|
| 57 |
-
st.
|
| 58 |
mha_flop = 2*bs*h*(d/h)*n
|
| 59 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
| 60 |
c1, c2 = st.columns([2, 3])
|
| 61 |
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 62 |
|
| 63 |
-
st.
|
| 64 |
mqa_flop = 2*bs*h*(d/h)*n
|
| 65 |
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
| 66 |
c1, c2 = st.columns([2, 3])
|
| 67 |
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 68 |
|
| 69 |
-
st.header('Attention
|
| 70 |
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
| 71 |
-
st.
|
| 72 |
mha_flop = 2*bs*h*n*(d/h)
|
| 73 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
| 74 |
c1, c2 = st.columns([2, 3])
|
| 75 |
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 76 |
|
| 77 |
-
st.
|
| 78 |
mqa_flop = 2*bs*h*n*(d/h)
|
| 79 |
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
| 80 |
c1, c2 = st.columns([2, 3])
|
| 81 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 82 |
|
|
|
|
|
|
|
| 83 |
st.header('MLP')
|
| 84 |
st.subheader('First Linear')
|
| 85 |
mlp1_flop = 2*bs*1*d*4*d
|
|
|
|
| 35 |
d = number_field('Dimension', value=768)
|
| 36 |
n_start = number_field('Start seq', value=1)
|
| 37 |
n = number_field('End seq', value=1024)
|
| 38 |
+
l = number_field('Num layers', value=24)
|
| 39 |
|
| 40 |
+
st.header('Attention layer')
|
| 41 |
|
| 42 |
+
st.subheader('QKV projection')
|
| 43 |
+
st.caption("Multi-Head Attention")
|
| 44 |
mha_flop = 2*bs*1*d*3*d
|
| 45 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
| 46 |
c1, c2 = st.columns([2, 3])
|
| 47 |
qkv_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 48 |
|
| 49 |
+
st.caption("Multi-Query Attention")
|
| 50 |
mqa_flop = 2*bs*1*d*(1+2/h)*d
|
| 51 |
mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
|
| 52 |
c1, c2 = st.columns([2, 3])
|
| 53 |
qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 54 |
|
| 55 |
+
st.subheader('QK gemm')
|
| 56 |
+
st.write("Note that calculation depends on sequence length (n)")
|
| 57 |
|
| 58 |
+
st.caption("Multi-Head Attention")
|
| 59 |
mha_flop = 2*bs*h*(d/h)*n
|
| 60 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
| 61 |
c1, c2 = st.columns([2, 3])
|
| 62 |
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 63 |
|
| 64 |
+
st.caption("Multi-Query Attention")
|
| 65 |
mqa_flop = 2*bs*h*(d/h)*n
|
| 66 |
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
| 67 |
c1, c2 = st.columns([2, 3])
|
| 68 |
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 69 |
|
| 70 |
+
st.header('Attention-value gemm')
|
| 71 |
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
| 72 |
+
st.caption("Multi-Head Attention")
|
| 73 |
mha_flop = 2*bs*h*n*(d/h)
|
| 74 |
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
| 75 |
c1, c2 = st.columns([2, 3])
|
| 76 |
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 77 |
|
| 78 |
+
st.caption("Multi-Query Attention")
|
| 79 |
mqa_flop = 2*bs*h*n*(d/h)
|
| 80 |
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
| 81 |
c1, c2 = st.columns([2, 3])
|
| 82 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 83 |
|
| 84 |
+
st.subheader('Output projection')
|
| 85 |
+
|
| 86 |
st.header('MLP')
|
| 87 |
st.subheader('First Linear')
|
| 88 |
mlp1_flop = 2*bs*1*d*4*d
|