Paula Leonova
commited on
Commit
·
d4be6e6
1
Parent(s):
2b16dfe
Add evaluation metrics
Browse files- app.py +20 -0
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -5,6 +5,8 @@ import pandas as pd
|
|
| 5 |
import base64
|
| 6 |
from typing import Sequence
|
| 7 |
import streamlit as st
|
|
|
|
|
|
|
| 8 |
|
| 9 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
| 10 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
|
@@ -102,7 +104,16 @@ if submit_button:
|
|
| 102 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
| 103 |
|
| 104 |
data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
|
|
|
|
| 105 |
data2 = pd.merge(data, data_ex_text, on = ['label'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
st.markdown("### Data Table")
|
| 107 |
|
| 108 |
with st.spinner('Generating a table of results and a download link...'):
|
|
@@ -112,5 +123,14 @@ if submit_button:
|
|
| 112 |
unsafe_allow_html = True
|
| 113 |
)
|
| 114 |
st.dataframe(data2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
st.success('All done!')
|
| 116 |
st.balloons()
|
|
|
|
| 5 |
import base64
|
| 6 |
from typing import Sequence
|
| 7 |
import streamlit as st
|
| 8 |
+
from sklearn.metrics import classification_report
|
| 9 |
+
|
| 10 |
|
| 11 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
| 12 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
|
|
|
| 104 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
| 105 |
|
| 106 |
data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
|
| 107 |
+
|
| 108 |
data2 = pd.merge(data, data_ex_text, on = ['label'])
|
| 109 |
+
|
| 110 |
+
if len(glabels) > 0:
|
| 111 |
+
gdata = pd.DataFrame({'label': glabels})
|
| 112 |
+
gdata['is_true_label'] = 1
|
| 113 |
+
|
| 114 |
+
data2 = pd.merge(data2, gdata, how = 'left', on = ['label'])
|
| 115 |
+
data2['is_true_label'].fillna(0, inplace = True)
|
| 116 |
+
|
| 117 |
st.markdown("### Data Table")
|
| 118 |
|
| 119 |
with st.spinner('Generating a table of results and a download link...'):
|
|
|
|
| 123 |
unsafe_allow_html = True
|
| 124 |
)
|
| 125 |
st.dataframe(data2)
|
| 126 |
+
|
| 127 |
+
if len(glabels) > 0:
|
| 128 |
+
with st.spinner('Evaluating output against ground truth...'):
|
| 129 |
+
report = classification_report(y_true = data2[['is_true_label']],
|
| 130 |
+
y_pred = (data2[['scores_from_full_text']] >= threshold_value) * 1.0,
|
| 131 |
+
output_dict=True)
|
| 132 |
+
df_report = pd.DataFrame(report).transpose()
|
| 133 |
+
st.dataframe(df_report)
|
| 134 |
+
|
| 135 |
st.success('All done!')
|
| 136 |
st.balloons()
|
requirements.txt
CHANGED
|
@@ -3,5 +3,6 @@ pandas
|
|
| 3 |
streamlit
|
| 4 |
plotly
|
| 5 |
torch
|
|
|
|
| 6 |
spacy>=2.2.0,<3.0.0
|
| 7 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm
|
|
|
|
| 3 |
streamlit
|
| 4 |
plotly
|
| 5 |
torch
|
| 6 |
+
sklearn
|
| 7 |
spacy>=2.2.0,<3.0.0
|
| 8 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm
|