Spaces:
Runtime error
Runtime error
feat(model):
Browse files
- add support for new model (model_7) for image classification task
- extend CLASS_NAMES to support the new model's class labels
- add resource loading for the new model
- access the new model in the predict_image function
- update combined_results and combined_outputs to include the new model's output
- add a new tile in the HTML results for the new model
♻️ style(frontend):
- increase w-24 to w-30 in HTML CSS snippet
Note: Prefer "feat" when adding a new model, even if the change is just adding a URL, and "chore" for any small tutorials added.
app.py
CHANGED
|
@@ -27,7 +27,8 @@ MODEL_PATHS = {
|
|
| 27 |
"model_4": "cmckinle/sdxl-flux-detector",
|
| 28 |
"model_5": "prithivMLmods/Deep-Fake-Detector-v2-Model",
|
| 29 |
"model_5b": "prithivMLmods/Deepfake-Detection-Exp-02-22",
|
| 30 |
-
"model_6": "ideepankarsharma2003/AI_ImageClassification_MidjourneyV6_SDXL"
|
|
|
|
| 31 |
}
|
| 32 |
|
| 33 |
CLASS_NAMES = {
|
|
@@ -38,6 +39,7 @@ CLASS_NAMES = {
|
|
| 38 |
"model_5": ['Realism', 'Deepfake'],
|
| 39 |
"model_5b": ['Real', 'Deepfake'],
|
| 40 |
"model_6": ['ai_gen', 'human'],
|
|
|
|
| 41 |
|
| 42 |
}
|
| 43 |
|
|
@@ -63,9 +65,13 @@ def load_models():
|
|
| 63 |
model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
|
| 64 |
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6
|
|
|
|
|
|
|
| 69 |
|
| 70 |
@spaces.GPU(duration=10)
|
| 71 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
|
@@ -118,6 +124,7 @@ def predict_image(img, confidence_threshold):
|
|
| 118 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
| 119 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
| 120 |
label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
|
|
|
|
| 121 |
|
| 122 |
combined_results = {
|
| 123 |
"SwinV2/detect": label_1,
|
|
@@ -126,11 +133,12 @@ def predict_image(img, confidence_threshold):
|
|
| 126 |
"Swin/SDXL-FLUX": label_4,
|
| 127 |
"prithivMLmods": label_5,
|
| 128 |
"prithivMLmods-2-22": label_5b,
|
| 129 |
-
"SwinMidSDXL": label_6
|
|
|
|
| 130 |
}
|
| 131 |
print(combined_results)
|
| 132 |
|
| 133 |
-
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput, result_6output]
|
| 134 |
return img_pil, combined_outputs
|
| 135 |
# Define a function to generate the HTML content
|
| 136 |
|
|
@@ -159,7 +167,7 @@ def generate_results_html(results):
|
|
| 159 |
class="-m-4 h-24 {header_colors[0]} rounded-sm rounded-b-none transition border group-hover:border-gray-100 group-hover:shadow-lg group-hover:{header_colors[4]}">
|
| 160 |
<span class="text-gray-300 font-mono tracking-widest p-4 pb-3 block text-xs text-center">MODEL {index + 1}:</span>
|
| 161 |
<span
|
| 162 |
-
class="flex w-
|
| 163 |
>
|
| 164 |
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="3" stroke="currentColor" class="w-4 h-4 mr-2 -ml-3 group-hover:animate group-hover:animate-pulse">
|
| 165 |
{'<path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75 11.25 15 15 9.75M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />' if label == 'REAL' else '<path stroke-linecap="round" stroke-linejoin="round" d="m9.75 9.75 4.5 4.5m0-4.5-4.5 4.5M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />'}
|
|
@@ -207,7 +215,8 @@ def generate_results_html(results):
|
|
| 207 |
{generate_tile_html(3, results[3], "SDXL + FLUX", "cmckinle", MODEL_PATHS["model_4"])}
|
| 208 |
{generate_tile_html(4, results[4], "Vit Based", "prithivMLmods", MODEL_PATHS["model_5"])}
|
| 209 |
{generate_tile_html(5, results[5], "Vit Based, Newer Dataset", "prithivMLmods", MODEL_PATHS["model_5b"])}
|
| 210 |
-
{generate_tile_html(6, results[6], "Swin,
|
|
|
|
| 211 |
</div>
|
| 212 |
</div>
|
| 213 |
"""
|
|
|
|
| 27 |
"model_4": "cmckinle/sdxl-flux-detector",
|
| 28 |
"model_5": "prithivMLmods/Deep-Fake-Detector-v2-Model",
|
| 29 |
"model_5b": "prithivMLmods/Deepfake-Detection-Exp-02-22",
|
| 30 |
+
"model_6": "ideepankarsharma2003/AI_ImageClassification_MidjourneyV6_SDXL",
|
| 31 |
+
"model_7": "date3k2/vit-real-fake-classification-v4"
|
| 32 |
}
|
| 33 |
|
| 34 |
CLASS_NAMES = {
|
|
|
|
| 39 |
"model_5": ['Realism', 'Deepfake'],
|
| 40 |
"model_5b": ['Real', 'Deepfake'],
|
| 41 |
"model_6": ['ai_gen', 'human'],
|
| 42 |
+
"model_7": ['Fake', 'Real'],
|
| 43 |
|
| 44 |
}
|
| 45 |
|
|
|
|
| 65 |
model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
|
| 66 |
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
| 67 |
|
| 68 |
+
image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
|
| 69 |
+
model_7 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_7"]).to(device)
|
| 70 |
+
clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
|
| 71 |
|
| 72 |
+
return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6, model_7, clf_7
|
| 73 |
+
|
| 74 |
+
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6, model_7, clf_7 = load_models()
|
| 75 |
|
| 76 |
@spaces.GPU(duration=10)
|
| 77 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
|
|
|
| 124 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
| 125 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
| 126 |
label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
|
| 127 |
+
label_7, result_7output = predict_with_model(img_pilvits, clf_7, CLASS_NAMES["model_7"], confidence_threshold, "Vit", 7)
|
| 128 |
|
| 129 |
combined_results = {
|
| 130 |
"SwinV2/detect": label_1,
|
|
|
|
| 133 |
"Swin/SDXL-FLUX": label_4,
|
| 134 |
"prithivMLmods": label_5,
|
| 135 |
"prithivMLmods-2-22": label_5b,
|
| 136 |
+
"SwinMidSDXL": label_6,
|
| 137 |
+
"Vit": label_7
|
| 138 |
}
|
| 139 |
print(combined_results)
|
| 140 |
|
| 141 |
+
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput, result_6output, result_7output]
|
| 142 |
return img_pil, combined_outputs
|
| 143 |
# Define a function to generate the HTML content
|
| 144 |
|
|
|
|
| 167 |
class="-m-4 h-24 {header_colors[0]} rounded-sm rounded-b-none transition border group-hover:border-gray-100 group-hover:shadow-lg group-hover:{header_colors[4]}">
|
| 168 |
<span class="text-gray-300 font-mono tracking-widest p-4 pb-3 block text-xs text-center">MODEL {index + 1}:</span>
|
| 169 |
<span
|
| 170 |
+
class="flex w-30 mx-auto tracking-wide items-center justify-center rounded-full {header_colors[2]} px-1 py-0.5 {header_colors[3]}"
|
| 171 |
>
|
| 172 |
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="3" stroke="currentColor" class="w-4 h-4 mr-2 -ml-3 group-hover:animate group-hover:animate-pulse">
|
| 173 |
{'<path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75 11.25 15 15 9.75M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />' if label == 'REAL' else '<path stroke-linecap="round" stroke-linejoin="round" d="m9.75 9.75 4.5 4.5m0-4.5-4.5 4.5M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />'}
|
|
|
|
| 215 |
{generate_tile_html(3, results[3], "SDXL + FLUX", "cmckinle", MODEL_PATHS["model_4"])}
|
| 216 |
{generate_tile_html(4, results[4], "Vit Based", "prithivMLmods", MODEL_PATHS["model_5"])}
|
| 217 |
{generate_tile_html(5, results[5], "Vit Based, Newer Dataset", "prithivMLmods", MODEL_PATHS["model_5b"])}
|
| 218 |
+
{generate_tile_html(6, results[6], "Swin, Midj + SDXL", "ideepankarsharma2003", MODEL_PATHS["model_6"])}
|
| 219 |
+
{generate_tile_html(7, results[7], "ViT", "temp", MODEL_PATHS["model_7"])}
|
| 220 |
</div>
|
| 221 |
</div>
|
| 222 |
"""
|