Spaces:
Runtime error
Runtime error
Commit
·
28b4830
1
Parent(s):
8cbefab
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
-
def number_field(label, **
|
| 4 |
c1, c2 = st.columns([2, 4])
|
| 5 |
c1.write(label)
|
| 6 |
|
| 7 |
-
return c2.number_input('', **
|
| 8 |
|
| 9 |
def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
|
| 10 |
arith_int = comp_flop/mem_bytes
|
| 11 |
exec_time = (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
|
| 12 |
|
| 13 |
-
comp_flop = round(
|
| 14 |
-
mem_bytes = round(
|
| 15 |
|
| 16 |
c1.write("GFLOP:")
|
| 17 |
c2.write(str(comp_flop))
|
|
@@ -66,17 +66,17 @@ mqa_bytes = 2*bs*h*(d/h) + 2*bs*n*(d/h) + 2*bs*h*n
|
|
| 66 |
c1, c2 = st.columns([2, 3])
|
| 67 |
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 68 |
|
| 69 |
-
st.header('Attention scores: ')
|
| 70 |
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
| 71 |
st.subheader("Multi-Head Attention")
|
| 72 |
-
mha_flop = 2*bs*h*(d/h)
|
| 73 |
-
mha_bytes = 2*bs*h*
|
| 74 |
c1, c2 = st.columns([2, 3])
|
| 75 |
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 76 |
|
| 77 |
st.subheader("Multi-Query Attention")
|
| 78 |
-
mqa_flop = 2*bs*h*(d/h)
|
| 79 |
-
mqa_bytes = 2*bs*
|
| 80 |
c1, c2 = st.columns([2, 3])
|
| 81 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 82 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
+
def number_field(label, **kwargs):
    """Render a captioned numeric input on a single row.

    The caption is written into the narrower left column and the
    ``st.number_input`` widget is placed in the wider right column.

    Args:
        label: Text written next to the input.
        **kwargs: Forwarded unchanged to ``number_input``.

    Returns:
        Whatever ``number_input`` returns for the widget.
    """
    c1, c2 = st.columns([2, 4])
    c1.write(label)

    # The widget's own label is left empty because the caption is already
    # shown in the left column.
    return c2.number_input('', **kwargs)
|
| 8 |
|
| 9 |
def print_kernel_execution(c1, c2, comp_flop, mem_bytes):
|
| 10 |
arith_int = comp_flop/mem_bytes
|
| 11 |
exec_time = (comp_flop/TFLOPS + mem_bytes/GB_S)*1000
|
| 12 |
|
| 13 |
+
comp_flop = round(comp_flop/1e9, 2)
|
| 14 |
+
mem_bytes = round(mem_bytes/1e6, 2)
|
| 15 |
|
| 16 |
c1.write("GFLOP:")
|
| 17 |
c2.write(str(comp_flop))
|
|
|
|
| 66 |
c1, c2 = st.columns([2, 3])
|
| 67 |
att1_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 68 |
|
| 69 |
+
st.header('Attention scores: attention-value gemm')
|
| 70 |
st.write("Calculation depends on sequence length. We show numbers for maximum sequence length n.")
|
| 71 |
st.subheader("Multi-Head Attention")
|
| 72 |
+
mha_flop = 2*bs*h*n*(d/h)
|
| 73 |
+
mha_bytes = 2*bs*h*n + 2*bs*h*n*(d/h) + 2*bs*h*(d/h)
|
| 74 |
c1, c2 = st.columns([2, 3])
|
| 75 |
att_mha_time = print_kernel_execution(c1, c2, mha_flop, mha_bytes)
|
| 76 |
|
| 77 |
st.subheader("Multi-Query Attention")
|
| 78 |
+
mqa_flop = 2*bs*h*n*(d/h)
|
| 79 |
+
mqa_bytes = 2*bs*n*(d/h) + 2*bs*n*(d/h) + 2*bs*h*(d/h)
|
| 80 |
c1, c2 = st.columns([2, 3])
|
| 81 |
att_mqa_time = print_kernel_execution(c1, c2, mqa_flop, mqa_bytes)
|
| 82 |
|