Update app.py
app.py CHANGED

@@ -113,6 +113,7 @@ def create_analysis_html(sentence_results, global_label, global_confidence):
     html += "</table>"
     return html

+
 def process_input(text_input, labels_or_premise, mode):
     if mode == "Zero-Shot Classification":
         labels = [label.strip() for label in labels_or_premise.split(',')]
@@ -126,8 +127,9 @@ def process_input(text_input, labels_or_premise, mode):
     else:  # Long Context NLI
         # Global prediction
         global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
-        global_results = {
-        global_label
+        global_results = {p['label']: p['score'] for p in global_pred}
+        global_label = max(global_results.items(), key=lambda x: x[1])[0]
+        global_confidence = max(global_results.values())

         # Sentence-level analysis
         sentences = sent_tokenize(text_input)
@@ -135,15 +137,18 @@ def process_input(text_input, labels_or_premise, mode):

         for sentence in sentences:
             sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
-
-
+            # Get the prediction and confidence for the sentence
+            pred_scores = [(p['label'], p['score']) for p in sent_pred]
+            max_pred = max(pred_scores, key=lambda x: x[1])
+            max_label, confidence = max_pred
+
             sentence_results.append({
                 'sentence': sentence,
                 'prediction': max_label,
-                '
+                'confidence': confidence
             })

-        analysis_html = create_analysis_html(sentence_results, global_label,global_confidence)
+        analysis_html = create_analysis_html(sentence_results, global_label, global_confidence)
         return global_results, analysis_html

 def update_interface(mode):
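Taken together, the hunks leave the "Long Context NLI" branch reading roughly as the sketch below. It is a minimal reconstruction, not the full app: it assumes `nli_classifier` is a Hugging Face `text-classification` pipeline over an NLI cross-encoder (the model name is a placeholder), that `sent_tokenize` comes from NLTK, and it returns the raw results instead of going through `create_analysis_html` and the Gradio interface; the `long_context_nli` helper name is invented for illustration.

# Sketch of the updated Long Context NLI logic, under the assumptions above.
import nltk
from nltk.tokenize import sent_tokenize
from transformers import pipeline

nltk.download("punkt", quiet=True)  # tokenizer data, if not already present

# Placeholder model: any NLI cross-encoder with entailment/neutral/contradiction labels.
nli_classifier = pipeline("text-classification", model="cross-encoder/nli-deberta-v3-base")


def long_context_nli(text_input, labels_or_premise):
    # Global prediction: score the full passage against the hypothesis.
    global_pred = nli_classifier(
        [{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True
    )[0]
    global_results = {p["label"]: p["score"] for p in global_pred}
    global_label = max(global_results.items(), key=lambda x: x[1])[0]
    global_confidence = max(global_results.values())

    # Sentence-level analysis: repeat the same scoring for each sentence.
    sentence_results = []
    for sentence in sent_tokenize(text_input):
        sent_pred = nli_classifier(
            [{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True
        )[0]
        pred_scores = [(p["label"], p["score"]) for p in sent_pred]
        max_label, confidence = max(pred_scores, key=lambda x: x[1])
        sentence_results.append(
            {"sentence": sentence, "prediction": max_label, "confidence": confidence}
        )

    return global_results, global_label, global_confidence, sentence_results

With that in place, `create_analysis_html(sentence_results, global_label, global_confidence)` receives one `{'sentence', 'prediction', 'confidence'}` entry per sentence plus the passage-level label and score, which is what the corrected call at the end of the last hunk passes through.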