Spaces:
Sleeping
Sleeping
Enhance prediction function with validation checks and improved error handling
Browse files- utils/prediction.py +26 -26
utils/prediction.py
CHANGED
|
@@ -46,43 +46,43 @@ def predict_sentence(model, sentence, tokenizer, label_encoder):
|
|
| 46 |
"""
|
| 47 |
Make prediction for a single sentence with label validation.
|
| 48 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
model.eval()
|
| 50 |
|
| 51 |
# Tokenize
|
| 52 |
-
encoding = tokenizer(
|
| 53 |
-
sentence,
|
| 54 |
-
add_special_tokens=True,
|
| 55 |
-
max_length=512,
|
| 56 |
-
padding='max_length',
|
| 57 |
-
truncation=True,
|
| 58 |
-
return_tensors='pt'
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
with torch.no_grad():
|
| 63 |
-
# Get model outputs
|
| 64 |
outputs = model(encoding['input_ids'], encoding['attention_mask'])
|
| 65 |
probabilities = torch.softmax(outputs, dim=1)
|
| 66 |
-
|
| 67 |
-
# Get prediction and probability
|
| 68 |
prob, pred_idx = torch.max(probabilities, dim=1)
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
# Validate prediction index
|
| 71 |
-
if pred_idx.item() >= len(label_encoder.classes_):
|
| 72 |
-
print(f"Warning: Model predicted invalid label index {pred_idx.item()}")
|
| 73 |
-
return "Unknown", 0.0
|
| 74 |
-
|
| 75 |
-
# Convert to label
|
| 76 |
-
try:
|
| 77 |
-
predicted_class = label_encoder.classes_[pred_idx.item()]
|
| 78 |
-
return predicted_class, prob.item()
|
| 79 |
-
except IndexError:
|
| 80 |
-
print(f"Warning: Invalid label index {pred_idx.item()}")
|
| 81 |
-
return "Unknown", 0.0
|
| 82 |
-
|
| 83 |
except Exception as e:
|
| 84 |
print(f"Prediction error: {str(e)}")
|
| 85 |
-
return "Error", 0.0
|
| 86 |
|
| 87 |
def print_labels(label_encoder, show_counts=False):
|
| 88 |
"""Print all labels and their corresponding indices"""
|
|
|
|
| 46 |
"""
|
| 47 |
Make prediction for a single sentence with label validation.
|
| 48 |
"""
|
| 49 |
+
# Validation checks
|
| 50 |
+
if model is None:
|
| 51 |
+
print("Error: Model not loaded")
|
| 52 |
+
return "Error: Model not loaded", 0.0
|
| 53 |
+
if tokenizer is None:
|
| 54 |
+
print("Error: Tokenizer not loaded")
|
| 55 |
+
return "Error: Tokenizer not loaded", 0.0
|
| 56 |
+
if label_encoder is None:
|
| 57 |
+
print("Error: Label encoder not loaded")
|
| 58 |
+
return "Error: Label encoder not loaded", 0.0
|
| 59 |
+
|
| 60 |
+
# Force CPU device
|
| 61 |
+
device = torch.device('cpu')
|
| 62 |
+
model = model.to(device)
|
| 63 |
model.eval()
|
| 64 |
|
| 65 |
# Tokenize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
try:
|
| 67 |
+
encoding = tokenizer(
|
| 68 |
+
sentence,
|
| 69 |
+
add_special_tokens=True,
|
| 70 |
+
max_length=512,
|
| 71 |
+
padding='max_length',
|
| 72 |
+
truncation=True,
|
| 73 |
+
return_tensors='pt'
|
| 74 |
+
).to(device)
|
| 75 |
+
|
| 76 |
with torch.no_grad():
|
|
|
|
| 77 |
outputs = model(encoding['input_ids'], encoding['attention_mask'])
|
| 78 |
probabilities = torch.softmax(outputs, dim=1)
|
|
|
|
|
|
|
| 79 |
prob, pred_idx = torch.max(probabilities, dim=1)
|
| 80 |
+
predicted_label = label_encoder.classes_[pred_idx.item()]
|
| 81 |
+
return predicted_label, prob.item()
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
except Exception as e:
|
| 84 |
print(f"Prediction error: {str(e)}")
|
| 85 |
+
return f"Error: {str(e)}", 0.0
|
| 86 |
|
| 87 |
def print_labels(label_encoder, show_counts=False):
|
| 88 |
"""Print all labels and their corresponding indices"""
|