Pontonkid committed
Commit c490753 · verified · 1 Parent(s): 929059f

Create app.py

Files changed (1)
  1. app.py +269 -0
app.py ADDED
@@ -0,0 +1,269 @@
import gradio as gr
from transformers import pipeline
import re

# Custom sentence tokenizer
def sent_tokenize(text):
    sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(\s|$)')
    sentences = sentence_endings.split(text)
    return [s.strip() for s in sentences if s.strip()]

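# Illustrative behaviour (added comment, not in the original commit):
#   sent_tokenize("It works. Try it!")  ->  ['It works.', 'Try it!']
# The negative lookbehinds keep abbreviations such as "e.g." from triggering a split.
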
# Initialize the classifiers
zero_shot_classifier = pipeline("zero-shot-classification", model="tasksource/ModernBERT-base-nli", device="cpu")
nli_classifier = pipeline("text-classification", model="tasksource/ModernBERT-base-nli", device="cpu")

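# Note (added comment, not in the original commit): both pipelines load the same
# checkpoint; the text-classification pipeline is expected to emit the NLI labels
# 'entailment', 'neutral' and 'contradiction' used by get_label_color() below.
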
# Example inputs for each mode
zero_shot_examples = [
    ["I absolutely love this product, it's amazing!", "positive, negative, neutral"],
    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
    ["I love playing video games", "entertainment, sports, education, business"],
    ["The car won't start", "transportation, art, cooking, literature"]
]

nli_examples = [
    ["A man is sleeping on a couch", "The man is awake"],
    ["The restaurant's waiting area is bustling, but several tables remain vacant", "The establishment is at maximum capacity"],
    ["The child is methodically arranging blocks while frowning in concentration", "The kid is experiencing joy"],
    ["Dark clouds are gathering and the pavement shows scattered wet spots", "It's been raining heavily all day"],
    ["A German Shepherd is exhibiting defensive behavior towards someone approaching the property", "The animal making noise is feline"]
]

long_context_examples = [
    ["""A company's environmental policy typically has a profound impact upon its standing in the community. There are legal regulations, with stiff penalties attached, which compel managers to ensure that any waste products are disposed of without contaminating the air or water supplies. In addition, employees can be educated about the inevitable commercial and social benefits of recycling paper and other substances produced as by-products of the manufacturing process. One popular method for gaining staff co-operation is the internal incentive scheme. These often target teams rather than individuals, since the interdependence of staff organising any reprocessing, masks the importance of a given player's role.""",
     "The regard held for an organisation may be affected by its commitment to environmental issues."]
]

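# Format note (added comment, not in the original commit): each example is a
# [text_input, labels_or_premise] pair, matching the two input components wired
# to gr.Examples further down.
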
def get_label_color(label, confidence=1.0):
    """Return color based on NLI label with confidence-based saturation."""
    base_colors = {
        'entailment': 'rgb(144, 238, 144)',    # Light green
        'neutral': 'rgb(255, 229, 180)',       # Peach
        'contradiction': 'rgb(255, 182, 193)'  # Light pink
    }

    # Convert RGB color to RGBA with confidence-based alpha
    if label in base_colors:
        rgb = base_colors[label].replace('rgb(', '').replace(')', '').split(',')
        r, g, b = map(int, rgb)
        # Adjust the color based on confidence
        alpha = 0.3 + (0.7 * confidence)  # Range from 0.3 to 1.0
        return f"rgba({r},{g},{b},{alpha})"
    return '#FFFFFF'

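# Illustrative usage (added comment, not in the original commit):
#   get_label_color('contradiction', 0.5)  ~  'rgba(255,182,193,0.65)'  (alpha = 0.3 + 0.7 * 0.5)
#   get_label_color('unknown')             ->  '#FFFFFF'                (fallback for unexpected labels)
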
def create_analysis_html(sentence_results, global_label, global_confidence):
    """Create HTML table for sentence analysis with color coding and confidence."""
    html = """
    <style>
        .analysis-table {
            width: 100%;
            border-collapse: collapse;
            margin: 20px 0;
            font-family: Arial, sans-serif;
        }
        .analysis-table th, .analysis-table td {
            padding: 12px;
            border: 1px solid #ddd;
            text-align: left;
        }
        .analysis-table th {
            background-color: #f5f5f5;
        }
        .global-prediction {
            padding: 15px;
            margin: 20px 0;
            border-radius: 5px;
            font-weight: bold;
        }
        .confidence {
            font-size: 0.9em;
            color: #666;
        }
    </style>
    """

    # Add global prediction box with confidence
    html += f"""
    <div class="global-prediction" style="background-color: {get_label_color(global_label, global_confidence)}">
        Global Prediction: {global_label}
        <span class="confidence">(Confidence: {global_confidence:.2%})</span>
    </div>
    """

    # Create table
    html += """
    <table class="analysis-table">
        <tr>
            <th>Sentence</th>
            <th>Prediction</th>
            <th>Confidence</th>
        </tr>
    """

    # Add rows for each sentence
    for result in sentence_results:
        html += f"""
        <tr style="background-color: {get_label_color(result['prediction'], result['confidence'])}">
            <td>{result['sentence']}</td>
            <td>{result['prediction']}</td>
            <td>{result['confidence']:.2%}</td>
        </tr>
        """

    html += "</table>"
    return html


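# Note (added comment, not in the original commit): the string built above is
# rendered by the gr.HTML "Sentence Analysis" output component defined in the
# interface below, so all styling is inlined in the <style> block.
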
def process_input(text_input, labels_or_premise, mode):
    if mode == "Zero-Shot Classification":
        labels = [label.strip() for label in labels_or_premise.split(',')]
        prediction = zero_shot_classifier(text_input, labels)
        results = {label: score for label, score in zip(prediction['labels'], prediction['scores'])}
        return results, ''
    elif mode == "Natural Language Inference":
        pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
        results = {p['label']: p['score'] for p in pred}
        return results, ''
    else:  # Long Context NLI
        # Global prediction
        global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
        global_results = {p['label']: p['score'] for p in global_pred}
        global_label = max(global_results.items(), key=lambda x: x[1])[0]
        global_confidence = max(global_results.values())

        # Sentence-level analysis
        sentences = sent_tokenize(text_input)
        sentence_results = []

        for sentence in sentences:
            sent_pred = nli_classifier([{"text": sentence, "text_pair": labels_or_premise}], return_all_scores=True)[0]
            # Get the prediction and confidence for the sentence
            pred_scores = [(p['label'], p['score']) for p in sent_pred]
            max_pred = max(pred_scores, key=lambda x: x[1])
            max_label, confidence = max_pred

            sentence_results.append({
                'sentence': sentence,
                'prediction': max_label,
                'confidence': confidence
            })

        analysis_html = create_analysis_html(sentence_results, global_label, global_confidence)
        return global_results, analysis_html

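# Return shape (added comment, not in the original commit): every branch returns
# (label_scores_dict, html_string); the empty string in the first two modes
# simply leaves the sentence-analysis HTML panel blank.
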
def update_interface(mode):
    if mode == "Zero-Shot Classification":
        return (
            gr.update(
                label="🏷️ Categories",
                placeholder="Enter comma-separated categories...",
                value=zero_shot_examples[0][1]
            ),
            gr.update(value=zero_shot_examples[0][0])
        )
    elif mode == "Natural Language Inference":
        return (
            gr.update(
                label="🔎 Hypothesis",
                placeholder="Enter a hypothesis to compare with the premise...",
                value=nli_examples[0][1]
            ),
            gr.update(value=nli_examples[0][0])
        )
    else:  # Long Context NLI
        return (
            gr.update(
                label="🔎 Hypothesis",
                placeholder="Enter a hypothesis to test against the full context...",
                value=long_context_examples[0][1]
            ),
            gr.update(value=long_context_examples[0][0])
        )

def update_visibility(mode):
    return (
        gr.update(visible=(mode == "Zero-Shot Classification")),
        gr.update(visible=(mode == "Natural Language Inference")),
        gr.update(visible=(mode == "Long Context NLI"))
    )
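
# Wiring note (added comment, not in the original commit): update_interface
# returns updates in the order (labels_or_premise, text_input), and
# update_visibility toggles the three example panels; both are connected to
# mode.change() below in that order.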

# Now define the Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("""
# tasksource/ModernBERT-nli demonstration

This space uses [tasksource/ModernBERT-base-nli](https://huggingface.co/tasksource/ModernBERT-base-nli),
fine-tuned from [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
on tasksource classification tasks.
This NLI model achieves high accuracy on categorization, logical reasoning and long-context NLI, outperforming Llama 3 8B on ConTRoL (long-context NLI) and FOLIO (logical reasoning).
""")

    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference", "Long Context NLI"],
        label="Select Mode",
        value="Zero-Shot Classification"
    )

    with gr.Column():
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3,
            value=zero_shot_examples[0][0]
        )

        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2,
            value=zero_shot_examples[0][1]
        )

        submit_btn = gr.Button("Submit")

        outputs = [
            gr.Label(label="📊 Results"),
            gr.HTML(label="📈 Sentence Analysis")
        ]

        with gr.Column(variant="panel") as zero_shot_examples_panel:
            gr.Examples(
                examples=zero_shot_examples,
                inputs=[text_input, labels_or_premise],
                label="Zero-Shot Classification Examples",
            )

        with gr.Column(variant="panel") as nli_examples_panel:
            gr.Examples(
                examples=nli_examples,
                inputs=[text_input, labels_or_premise],
                label="Natural Language Inference Examples",
            )

        with gr.Column(variant="panel") as long_context_examples_panel:
            gr.Examples(
                examples=long_context_examples,
                inputs=[text_input, labels_or_premise],
                label="Long Context NLI Examples",
            )

    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise, text_input]
    )

    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel, long_context_examples_panel]
    )

    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()
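
# Usage note (added comment, not in the original commit): running `python app.py`
# locally starts the Gradio demo (by default at http://127.0.0.1:7860).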