Spaces:
Runtime error
Runtime error
Commit
·
1934207
1
Parent(s):
622e054
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,9 @@ def number_field(label, **args):
|
|
| 6 |
|
| 7 |
return c2.number_input('', **args)
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
st.header("Transformer parameters")
|
| 10 |
col1, col2 = st.columns([2, 4])
|
| 11 |
|
|
@@ -19,6 +22,7 @@ st.header('Query, Key, Value projection')
|
|
| 19 |
mha_flop = 2*bs*n*d*3*d
|
| 20 |
mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d
|
| 21 |
mha_int = mha_flop/mha_bytes
|
|
|
|
| 22 |
|
| 23 |
mha_flop = round(mha_flop/1e9, 2)
|
| 24 |
mha_bytes = round(mha_bytes/1e6, 2)
|
|
@@ -32,18 +36,24 @@ c1.write("MB: ")
|
|
| 32 |
c2.write(str(mha_bytes))
|
| 33 |
c1.write("Arithm. intensity:")
|
| 34 |
c2.write(str(mha_int))
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
mqa_flop = 2*bs*n*d*(1+2/h)*d
|
| 38 |
mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
|
| 39 |
|
|
|
|
|
|
|
|
|
|
| 40 |
st.subheader("Multi-query Attention")
|
| 41 |
c1, c2 = st.columns([2, 3])
|
| 42 |
-
c1.write("
|
| 43 |
c2.write(str(mqa_flop))
|
| 44 |
-
c1.write("
|
| 45 |
c2.write(str(mqa_bytes))
|
| 46 |
c1.write("Arithm. intensity:")
|
| 47 |
c2.write(str(mqa_flop/mqa_bytes))
|
| 48 |
|
| 49 |
st.header('Attention')
|
|
|
|
|
|
| 6 |
|
| 7 |
return c2.number_input('', **args)
|
| 8 |
|
| 9 |
+
TFLOPS = 312e12
|
| 10 |
+
GB_S = 1935e9
|
| 11 |
+
|
| 12 |
st.header("Transformer parameters")
|
| 13 |
col1, col2 = st.columns([2, 4])
|
| 14 |
|
|
|
|
| 22 |
mha_flop = 2*bs*n*d*3*d
|
| 23 |
mha_bytes = 2*bs*n*d + 2*3*d*d + 2*bs*n*3*d
|
| 24 |
mha_int = mha_flop/mha_bytes
|
| 25 |
+
mha_time = (mha_flop/TFLOPS + mha_bytes/GB_S)*1000
|
| 26 |
|
| 27 |
mha_flop = round(mha_flop/1e9, 2)
|
| 28 |
mha_bytes = round(mha_bytes/1e6, 2)
|
|
|
|
| 36 |
c2.write(str(mha_bytes))
|
| 37 |
c1.write("Arithm. intensity:")
|
| 38 |
c2.write(str(mha_int))
|
| 39 |
+
c1.write("Time (ms):")
|
| 40 |
+
c2.write(str(mha_time))
|
| 41 |
|
| 42 |
|
| 43 |
mqa_flop = 2*bs*n*d*(1+2/h)*d
|
| 44 |
mqa_bytes = 2*bs*n*d + 2*(2/h)*d*d + 2*bs*n*(2/h)*d
|
| 45 |
|
| 46 |
+
mqa_flop = round(mqa_flop/1e9, 2)
|
| 47 |
+
mqa_bytes = round(mqa_bytes/1e6, 2)
|
| 48 |
+
|
| 49 |
st.subheader("Multi-query Attention")
|
| 50 |
c1, c2 = st.columns([2, 3])
|
| 51 |
+
c1.write("GFLOP:")
|
| 52 |
c2.write(str(mqa_flop))
|
| 53 |
+
c1.write("MB:")
|
| 54 |
c2.write(str(mqa_bytes))
|
| 55 |
c1.write("Arithm. intensity:")
|
| 56 |
c2.write(str(mqa_flop/mqa_bytes))
|
| 57 |
|
| 58 |
st.header('Attention')
|
| 59 |
+
|