Spaces:
Runtime error
Runtime error
Commit
·
8cbefab
1
Parent(s):
54cd0e6
Update app.py
Browse files
app.py
CHANGED
|
@@ -33,12 +33,12 @@ col1, col2 = st.columns([2, 4])
|
|
| 33 |
bs = number_field('Batch size', value=10)
|
| 34 |
h = number_field('Num heads', value=16)
|
| 35 |
d = number_field('Dimension', value=768)
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
|
| 38 |
st.header('Query, Key, Value projection')
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
st.subheader("Multi-Head Attention")
|
| 43 |
mha_flop = 2*bs*1*d*3*d
|
| 44 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
|
@@ -51,16 +51,28 @@ mqa_bytes = 2*bs*1*d + 2*(2/h)*d*d + 2*bs*1*(2/h)*d
|
|
| 51 |
c1, c2 = st.columns([2, 3])
|
| 52 |
qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 53 |
|
| 54 |
-
st.header('Attention')
|
|
|
|
| 55 |
|
| 56 |
st.subheader("Multi-Head Attention")
|
| 57 |
mha_flop = 2*bs*h*(d/h)*n
|
| 58 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
| 59 |
c1, c2 = st.columns([2, 3])
|
| 60 |
-
|
| 61 |
-
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
st.subheader("Multi-Query Attention")
|
| 66 |
mqa_flop = 2*bs*h*(d/h)*n
|
|
@@ -68,7 +80,6 @@ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
|
| 68 |
c1, c2 = st.columns([2, 3])
|
| 69 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 70 |
|
| 71 |
-
|
| 72 |
st.header('MLP')
|
| 73 |
st.subheader('First Linear')
|
| 74 |
mlp1_flop = 2*bs*1*d*4*d
|
|
|
|
| 33 |
bs = number_field('Batch size', value=10)
|
| 34 |
h = number_field('Num heads', value=16)
|
| 35 |
d = number_field('Dimension', value=768)
|
| 36 |
+
n_start = number_field('Start seq', value=1)
|
| 37 |
+
n = number_field('End seq', value=1024)
|
| 38 |
+
l = number_field('Num layers', value=2)
|
| 39 |
|
| 40 |
st.header('Query, Key, Value projection')
|
| 41 |
|
|
|
|
|
|
|
| 42 |
st.subheader("Multi-Head Attention")
|
| 43 |
mha_flop = 2*bs*1*d*3*d
|
| 44 |
mha_bytes = 2*bs*1*d + 2*3*d*d + 2*bs*1*3*d
|
|
|
|
| 51 |
c1, c2 = st.columns([2, 3])
|
| 52 |
qkv_mha_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 53 |
|
| 54 |
+
st.header('Attention scores: query-key gemm')
|
| 55 |
+
st.write("Calculation depends on sequence length (n). Take end of sequence.")
|
| 56 |
|
| 57 |
st.subheader("Multi-Head Attention")
|
| 58 |
mha_flop = 2*bs*h*(d/h)*n
|
| 59 |
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
| 60 |
c1, c2 = st.columns([2, 3])
|
| 61 |
+
att1_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
|
|
|
| 62 |
|
| 63 |
+
st.subheader("Multi-Query Attention")
|
| 64 |
+
mqa_flop = 2*bs*h*(d/h)*n
|
| 65 |
+
mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
| 66 |
+
c1, c2 = st.columns([2, 3])
|
| 67 |
+
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 68 |
|
| 69 |
+
st.header('Attention scores: ')
|
| 70 |
+
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
| 71 |
+
st.subheader("Multi-Head Attention")
|
| 72 |
+
mha_flop = 2*bs*h*(d/h)*n
|
| 73 |
+
mha_bytes = 2*bs*h*(d/h) + 2*bs*h*n*(d/h) + 2*bs*h*n
|
| 74 |
+
c1, c2 = st.columns([2, 3])
|
| 75 |
+
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 76 |
|
| 77 |
st.subheader("Multi-Query Attention")
|
| 78 |
mqa_flop = 2*bs*h*(d/h)*n
|
|
|
|
| 80 |
c1, c2 = st.columns([2, 3])
|
| 81 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 82 |
|
|
|
|
| 83 |
st.header('MLP')
|
| 84 |
st.subheader('First Linear')
|
| 85 |
mlp1_flop = 2*bs*1*d*4*d
|