fix(bug): feature extractor added back in
Browse files
app.py
CHANGED
|
@@ -61,10 +61,19 @@ def load_models():
|
|
| 61 |
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b = load_models()
|
| 62 |
|
| 63 |
@spaces.GPU(duration=10)
|
| 64 |
-
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id):
|
| 65 |
try:
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
result_output = [model_id, model_name, result.get(class_names[1], 0.0), result.get(class_names[0], 0.0)]
|
| 69 |
logger.info(result_output)
|
| 70 |
for class_name in class_names:
|
|
@@ -97,8 +106,8 @@ def predict_image(img, confidence_threshold):
|
|
| 97 |
|
| 98 |
label_1, result_1output = predict_with_model(img_pil, clf_1, CLASS_NAMES["model_1"], confidence_threshold, "SwinV2-base", 1)
|
| 99 |
label_2, result_2output = predict_with_model(img_pilvits, clf_2, CLASS_NAMES["model_2"], confidence_threshold, "ViT-base Classifier", 2)
|
| 100 |
-
label_3, result_3output = predict_with_model(img_pil, model_3, CLASS_NAMES["model_3"], confidence_threshold, "SDXL-Trained", 3)
|
| 101 |
-
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4)
|
| 102 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
| 103 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
| 104 |
|
|
@@ -113,7 +122,6 @@ def predict_image(img, confidence_threshold):
|
|
| 113 |
|
| 114 |
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput]
|
| 115 |
return img_pil, combined_outputs
|
| 116 |
-
|
| 117 |
# Define a function to generate the HTML content
|
| 118 |
def generate_results_html(results):
|
| 119 |
def get_header_color(label):
|
|
|
|
| 61 |
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b = load_models()
|
| 62 |
|
| 63 |
@spaces.GPU(duration=10)
|
| 64 |
+
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
| 65 |
try:
|
| 66 |
+
if feature_extractor:
|
| 67 |
+
inputs = feature_extractor(img_pil, return_tensors="pt").to(device)
|
| 68 |
+
with torch.no_grad():
|
| 69 |
+
outputs = clf(**inputs)
|
| 70 |
+
logits = outputs.logits
|
| 71 |
+
probabilities = softmax(logits.cpu().numpy()[0])
|
| 72 |
+
result = {class_names[i]: probabilities[i] for i in range(len(class_names))}
|
| 73 |
+
else:
|
| 74 |
+
prediction = clf(img_pil)
|
| 75 |
+
result = {pred['label']: pred['score'] for pred in prediction}
|
| 76 |
+
|
| 77 |
result_output = [model_id, model_name, result.get(class_names[1], 0.0), result.get(class_names[0], 0.0)]
|
| 78 |
logger.info(result_output)
|
| 79 |
for class_name in class_names:
|
|
|
|
| 106 |
|
| 107 |
label_1, result_1output = predict_with_model(img_pil, clf_1, CLASS_NAMES["model_1"], confidence_threshold, "SwinV2-base", 1)
|
| 108 |
label_2, result_2output = predict_with_model(img_pilvits, clf_2, CLASS_NAMES["model_2"], confidence_threshold, "ViT-base Classifier", 2)
|
| 109 |
+
label_3, result_3output = predict_with_model(img_pil, model_3, CLASS_NAMES["model_3"], confidence_threshold, "SDXL-Trained", 3, feature_extractor_3)
|
| 110 |
+
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
|
| 111 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
| 112 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
| 113 |
|
|
|
|
| 122 |
|
| 123 |
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput]
|
| 124 |
return img_pil, combined_outputs
|
|
|
|
| 125 |
# Define a function to generate the HTML content
|
| 126 |
def generate_results_html(results):
|
| 127 |
def get_header_color(label):
|