output prob distribution
Browse files- handler.py +16 -5
handler.py
CHANGED
|
@@ -18,6 +18,7 @@ class EndpointHandler:
|
|
| 18 |
if not input_text:
|
| 19 |
return {"error": "No input provided."}
|
| 20 |
|
|
|
|
| 21 |
inputs = self.tokenizer(
|
| 22 |
input_text,
|
| 23 |
return_tensors="pt",
|
|
@@ -27,14 +28,24 @@ class EndpointHandler:
|
|
| 27 |
)
|
| 28 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 29 |
|
|
|
|
| 30 |
with torch.no_grad():
|
| 31 |
outputs = self.model(**inputs)
|
| 32 |
-
probs = torch.softmax(outputs.logits, dim=-1)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
return {
|
| 38 |
"label": top_class_label,
|
| 39 |
-
"
|
| 40 |
}
|
|
|
|
|
|
| 18 |
if not input_text:
|
| 19 |
return {"error": "No input provided."}
|
| 20 |
|
| 21 |
+
# Tokenization
|
| 22 |
inputs = self.tokenizer(
|
| 23 |
input_text,
|
| 24 |
return_tensors="pt",
|
|
|
|
| 28 |
)
|
| 29 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 30 |
|
| 31 |
+
# Forward pass
|
| 32 |
with torch.no_grad():
|
| 33 |
outputs = self.model(**inputs)
|
| 34 |
+
probs = torch.softmax(outputs.logits, dim=-1)[0] # shape: (num_classes,)
|
| 35 |
+
|
| 36 |
+
# Get top class
|
| 37 |
+
top_class_id = torch.argmax(probs).item()
|
| 38 |
+
top_class_label = self.id2label.get(top_class_id) or self.id2label.get(str(top_class_id))
|
| 39 |
+
top_class_prob = probs[top_class_id].item()
|
| 40 |
+
|
| 41 |
+
# Convert full distribution to label->probability dict
|
| 42 |
+
prob_distribution = {
|
| 43 |
+
self.id2label.get(i) or self.id2label.get(str(i)): round(p.item(), 4)
|
| 44 |
+
for i, p in enumerate(probs)
|
| 45 |
+
}
|
| 46 |
|
| 47 |
return {
|
| 48 |
"label": top_class_label,
|
| 49 |
+
"probabilities": prob_distribution
|
| 50 |
}
|
| 51 |
+
|