hazarri commited on
Commit
b397916
·
verified ·
1 Parent(s): 880ab0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -12
app.py CHANGED
@@ -1,26 +1,64 @@
1
  from transformers import pipeline
2
  import gradio as gr
3
 
4
- # Load zero-shot classification pipeline
5
- classifier = pipeline("zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0")
 
 
 
6
 
7
- # Define possible classes
8
- labels = ["dangerous", "mild", "neutral"]
9
 
10
- def classify_side_effect(text):
 
 
 
 
 
 
 
 
11
  if not text.strip():
12
  return {"error": "Empty input"}
 
 
 
13
  result = classifier(text, candidate_labels=labels)
14
- scores = {label: float(score) for label, score in zip(result["labels"], result["scores"])}
15
- return scores
 
 
 
 
 
16
 
17
- # Gradio interface
18
- iface = gr.Interface(
19
  fn=classify_side_effect,
 
 
 
 
 
 
 
 
 
 
 
 
20
  inputs=gr.Textbox(label="Enter a side effect or sentence"),
21
- outputs=gr.Label(label="Classification Result"),
22
  title="Zero-Shot ADR Severity Classifier",
23
- description="Classifies a given sentence (e.g. side effect) as dangerous, mild, or neutral using DeBERTa v3 Large Zero-Shot."
 
 
 
 
 
 
24
  )
25
 
26
- iface.launch()
 
 
1
  from transformers import pipeline
2
  import gradio as gr
3
 
4
+ # Load the DeBERTa zero-shot classifier
5
+ classifier = pipeline(
6
+ "zero-shot-classification",
7
+ model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
8
+ )
9
 
10
+ # Default candidate labels (can be overridden by API)
11
+ DEFAULT_LABELS = ["mild", "moderate", "severe", "life-threatening", "death"]
12
 
13
+ def classify_side_effect(text, candidate_labels=None):
14
+ """
15
+ Classify the severity of a side effect using zero-shot classification.
16
+ Args:
17
+ text (str): Input text describing the side effect.
18
+ candidate_labels (list[str]): Optional list of labels.
19
+ Returns:
20
+ dict: Predicted labels and scores.
21
+ """
22
  if not text.strip():
23
  return {"error": "Empty input"}
24
+
25
+ labels = candidate_labels if candidate_labels else DEFAULT_LABELS
26
+
27
  result = classifier(text, candidate_labels=labels)
28
+ # Return structured output for API
29
+ return {
30
+ "labels": result["labels"],
31
+ "scores": [float(s) for s in result["scores"]],
32
+ "top_label": result["labels"][0],
33
+ "top_score": float(result["scores"][0])
34
+ }
35
 
36
+ # Define the API endpoint (for programmatic use)
37
+ api = gr.Interface(
38
  fn=classify_side_effect,
39
+ inputs=[
40
+ gr.Textbox(label="Side Effect Text"),
41
+ gr.Textbox(label="Candidate Labels (comma-separated, optional)")
42
+ ],
43
+ outputs=gr.JSON(label="Classification Result"),
44
+ title="Zero-Shot ADR Severity Classifier API",
45
+ description="Predicts the severity level of a side effect using DeBERTa-v3 large zero-shot classification."
46
+ )
47
+
48
+ # Add a user-friendly UI for manual testing
49
+ demo = gr.Interface(
50
+ fn=lambda text: classify_side_effect(text),
51
  inputs=gr.Textbox(label="Enter a side effect or sentence"),
52
+ outputs=gr.Label(label="Top Predicted Severity"),
53
  title="Zero-Shot ADR Severity Classifier",
54
+ description="Classifies a side effect sentence into severity categories using DeBERTa v3."
55
+ )
56
+
57
+ # Combine both: API and Demo
58
+ demo_and_api = gr.TabbedInterface(
59
+ [demo, api],
60
+ ["Demo", "API"]
61
  )
62
 
63
+ if __name__ == "__main__":
64
+ demo_and_api.launch()