feat: add Gradio API integration and ONNX preprocessing functions
Browse files- app.py +2 -92
- utils/onnx_helpers.py +45 -0
- utils/utils.py +43 -0
app.py
CHANGED
|
@@ -15,7 +15,8 @@ import concurrent.futures
|
|
| 15 |
import ast
|
| 16 |
import torch
|
| 17 |
|
| 18 |
-
from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar
|
|
|
|
| 19 |
from forensics.gradient import gradient_processing
|
| 20 |
from forensics.minmax import minmax_process
|
| 21 |
from forensics.ela import ELA
|
|
@@ -90,48 +91,6 @@ CLASS_NAMES = {
|
|
| 90 |
}
|
| 91 |
|
| 92 |
|
| 93 |
-
def infer_gradio_api(image_path):
|
| 94 |
-
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
| 95 |
-
result_dict = client.predict(
|
| 96 |
-
input_image=handle_file(image_path),
|
| 97 |
-
api_name="/simple_predict"
|
| 98 |
-
)
|
| 99 |
-
logger.info(f"Debug: Raw result_dict from Gradio API (model_8): {result_dict}, type: {type(result_dict)}")
|
| 100 |
-
# result_dict is already a dictionary, no need for ast.literal_eval
|
| 101 |
-
fake_probability = result_dict.get('Fake Probability', 0.0)
|
| 102 |
-
logger.info(f"Debug: Parsed result_dict: {result_dict}, Extracted fake_probability: {fake_probability}")
|
| 103 |
-
return {"probabilities": np.array([fake_probability])} # Return as a numpy array with one element
|
| 104 |
-
|
| 105 |
-
# New preprocess function for Gradio API
|
| 106 |
-
def preprocess_gradio_api(image: Image.Image):
|
| 107 |
-
# The Gradio API expects a file path, so we need to save the PIL Image to a temporary file.
|
| 108 |
-
temp_file_path = "./temp_gradio_input.png"
|
| 109 |
-
image.save(temp_file_path)
|
| 110 |
-
return temp_file_path
|
| 111 |
-
|
| 112 |
-
# New postprocess function for Gradio API (adapting postprocess_binary_output)
|
| 113 |
-
def postprocess_gradio_api(gradio_output, class_names):
|
| 114 |
-
# gradio_output is expected to be a dictionary like {"probabilities": np.array([fake_prob])}
|
| 115 |
-
probabilities_array = None
|
| 116 |
-
if isinstance(gradio_output, dict) and "probabilities" in gradio_output:
|
| 117 |
-
probabilities_array = gradio_output["probabilities"]
|
| 118 |
-
elif isinstance(gradio_output, np.ndarray):
|
| 119 |
-
probabilities_array = gradio_output
|
| 120 |
-
else:
|
| 121 |
-
logger.warning(f"Unexpected output type for Gradio API post-processing: {type(gradio_output)}. Expected dict with 'probabilities' or numpy.ndarray.")
|
| 122 |
-
return {class_names[0]: 0.0, class_names[1]: 1.0}
|
| 123 |
-
|
| 124 |
-
logger.info(f"Debug: Probabilities array entering postprocess_gradio_api: {probabilities_array}, type: {type(probabilities_array)}, shape: {probabilities_array.shape}")
|
| 125 |
-
|
| 126 |
-
if probabilities_array is None or probabilities_array.size == 0:
|
| 127 |
-
logger.warning("Probabilities array is None or empty after extracting from Gradio API output. Returning default scores.")
|
| 128 |
-
return {class_names[0]: 0.0, class_names[1]: 1.0}
|
| 129 |
-
|
| 130 |
-
# It should always be a single element array for fake probability
|
| 131 |
-
fake_prob = float(probabilities_array.item())
|
| 132 |
-
real_prob = 1.0 - fake_prob
|
| 133 |
-
|
| 134 |
-
return {class_names[0]: fake_prob, class_names[1]: real_prob}
|
| 135 |
|
| 136 |
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
|
| 137 |
entry = ModelEntry(model, preprocess, postprocess, class_names, display_name=display_name, contributor=contributor, model_path=model_path, architecture=architecture, dataset=dataset)
|
|
@@ -178,27 +137,6 @@ def get_onnx_model_from_cache(hf_model_id):
|
|
| 178 |
_onnx_model_cache[hf_model_id] = load_onnx_model_and_preprocessor(hf_model_id)
|
| 179 |
return _onnx_model_cache[hf_model_id]
|
| 180 |
|
| 181 |
-
def preprocess_onnx_input(image: Image.Image, preprocessor_config: dict):
|
| 182 |
-
# Preprocess image for ONNX model based on preprocessor_config
|
| 183 |
-
if image.mode != 'RGB':
|
| 184 |
-
image = image.convert('RGB')
|
| 185 |
-
|
| 186 |
-
# Get image size and normalization values from preprocessor_config or use defaults
|
| 187 |
-
# Use 'size' for initial resize and 'crop_size' for center cropping
|
| 188 |
-
initial_resize_size = preprocessor_config.get('size', {'height': 224, 'width': 224})
|
| 189 |
-
crop_size = preprocessor_config.get('crop_size', initial_resize_size['height'])
|
| 190 |
-
mean = preprocessor_config.get('image_mean', [0.485, 0.456, 0.406])
|
| 191 |
-
std = preprocessor_config.get('image_std', [0.229, 0.224, 0.225])
|
| 192 |
-
|
| 193 |
-
transform = transforms.Compose([
|
| 194 |
-
transforms.Resize((initial_resize_size['height'], initial_resize_size['width'])),
|
| 195 |
-
transforms.CenterCrop(crop_size), # Apply center crop
|
| 196 |
-
transforms.ToTensor(),
|
| 197 |
-
transforms.Normalize(mean=mean, std=std),
|
| 198 |
-
])
|
| 199 |
-
input_tensor = transform(image)
|
| 200 |
-
# ONNX expects numpy array with batch dimension (1, C, H, W)
|
| 201 |
-
return input_tensor.unsqueeze(0).cpu().numpy()
|
| 202 |
|
| 203 |
def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
|
| 204 |
try:
|
|
@@ -229,34 +167,6 @@ def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
|
|
| 229 |
# Return a structure consistent with other model errors
|
| 230 |
return {"logits": np.array([]), "probabilities": np.array([])}
|
| 231 |
|
| 232 |
-
def postprocess_onnx_output(onnx_output, model_config):
|
| 233 |
-
# Get class names from model_config
|
| 234 |
-
# Prioritize id2label, then check num_classes, otherwise default
|
| 235 |
-
class_names_map = model_config.get('id2label')
|
| 236 |
-
if class_names_map:
|
| 237 |
-
class_names = [class_names_map[k] for k in sorted(class_names_map.keys())]
|
| 238 |
-
elif model_config.get('num_classes') == 1: # Handle models that output a single value (e.g., probability of 'Fake')
|
| 239 |
-
class_names = ['Fake', 'Real'] # Assume first class is 'Fake' and second 'Real'
|
| 240 |
-
else:
|
| 241 |
-
class_names = {0: 'Fake', 1: 'Real'} # Default to Fake/Real if not found or not 1 class
|
| 242 |
-
class_names = [class_names[i] for i in sorted(class_names.keys())]
|
| 243 |
-
|
| 244 |
-
probabilities = onnx_output.get("probabilities")
|
| 245 |
-
|
| 246 |
-
if probabilities is not None:
|
| 247 |
-
if model_config.get('num_classes') == 1 and len(probabilities) == 2: # Special handling for single output models
|
| 248 |
-
# The single output is the probability of the 'Fake' class
|
| 249 |
-
fake_prob = float(probabilities[0])
|
| 250 |
-
real_prob = float(probabilities[1])
|
| 251 |
-
return {class_names[0]: fake_prob, class_names[1]: real_prob}
|
| 252 |
-
elif len(probabilities) == len(class_names):
|
| 253 |
-
return {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
|
| 254 |
-
else:
|
| 255 |
-
logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
|
| 256 |
-
return {name: 0.0 for name in class_names}
|
| 257 |
-
else:
|
| 258 |
-
logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
|
| 259 |
-
return {name: 0.0 for name in class_names}
|
| 260 |
|
| 261 |
# Register the ONNX quantized model
|
| 262 |
# Dummy entry for ONNX model to be loaded dynamically
|
|
|
|
| 15 |
import ast
|
| 16 |
import torch
|
| 17 |
|
| 18 |
+
from utils.utils import softmax, augment_image, preprocess_resize_256, preprocess_resize_224, postprocess_pipeline, postprocess_logits, postprocess_binary_output, to_float_scalar, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
|
| 19 |
+
from utils.onnx_helpers import preprocess_onnx_input, postprocess_onnx_output
|
| 20 |
from forensics.gradient import gradient_processing
|
| 21 |
from forensics.minmax import minmax_process
|
| 22 |
from forensics.ela import ELA
|
|
|
|
| 91 |
}
|
| 92 |
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
|
| 96 |
entry = ModelEntry(model, preprocess, postprocess, class_names, display_name=display_name, contributor=contributor, model_path=model_path, architecture=architecture, dataset=dataset)
|
|
|
|
| 137 |
_onnx_model_cache[hf_model_id] = load_onnx_model_and_preprocessor(hf_model_id)
|
| 138 |
return _onnx_model_cache[hf_model_id]
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
|
| 142 |
try:
|
|
|
|
| 167 |
# Return a structure consistent with other model errors
|
| 168 |
return {"logits": np.array([]), "probabilities": np.array([])}
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
# Register the ONNX quantized model
|
| 172 |
# Dummy entry for ONNX model to be loaded dynamically
|
utils/onnx_helpers.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
def _torchvision_size(spec):
    """Translate a preprocessor size spec into a torchvision size argument.

    Returns ``(height, width)`` for ``{"height": h, "width": w}``, a single
    int for ``{"shortest_edge": n}``, and ``None`` for anything else.
    """
    if isinstance(spec, dict):
        if 'height' in spec and 'width' in spec:
            return (spec['height'], spec['width'])
        if 'shortest_edge' in spec:
            return spec['shortest_edge']
    return None


def preprocess_onnx_input(image, preprocessor_config):
    """Turn an image into a batched, normalised NCHW numpy array.

    The pipeline is: convert to RGB, resize, center-crop, convert to a tensor,
    normalise with mean/std, then add a leading batch dimension.

    Args:
        image: Image object exposing ``mode`` and ``convert`` and accepted by
            the torchvision transforms (a PIL image in practice).
        preprocessor_config: Mapping that may contain:
            ``size`` - ``{"height": h, "width": w}`` or
            ``{"shortest_edge": n}``; defaults to 224x224.
            ``crop_size`` - an int, a sequence, or a dict in the same formats
            as ``size``; defaults to the resize height (or shortest edge).
            ``image_mean`` / ``image_std`` - per-channel normalisation values;
            default to the ImageNet statistics.

    Returns:
        The result of ``tensor.unsqueeze(0).cpu().numpy()`` applied to the
        transformed image, i.e. a numpy array with a batch axis of size 1.

    Raises:
        ValueError: If ``size`` is not in one of the supported dict formats.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    size_spec = preprocessor_config.get('size', {'height': 224, 'width': 224})
    resize_arg = _torchvision_size(size_spec)
    if resize_arg is None:
        raise ValueError(
            f"Unsupported 'size' in preprocessor config: {size_spec!r}. "
            "Expected {'height': h, 'width': w} or {'shortest_edge': n}."
        )

    crop_spec = preprocessor_config.get('crop_size')
    if crop_spec is None:
        # Same default as before: a square crop of the resize height (or of
        # the shortest edge when the resize keeps the aspect ratio).
        crop_arg = resize_arg[0] if isinstance(resize_arg, tuple) else resize_arg
    elif isinstance(crop_spec, dict):
        # Dict-valued crop sizes cannot be handed to CenterCrop directly.
        crop_arg = _torchvision_size(crop_spec)
        if crop_arg is None:
            raise ValueError(
                f"Unsupported 'crop_size' in preprocessor config: {crop_spec!r}."
            )
    else:
        # Ints and sequences are already valid CenterCrop arguments.
        crop_arg = crop_spec

    mean = preprocessor_config.get('image_mean', [0.485, 0.456, 0.406])
    std = preprocessor_config.get('image_std', [0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.Resize(resize_arg),
        transforms.CenterCrop(crop_arg),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    input_tensor = transform(image)
    # ONNX runtime is fed a numpy array with an explicit batch dimension.
    return input_tensor.unsqueeze(0).cpu().numpy()
|
| 21 |
+
|
| 22 |
+
def _label_sort_key(label_id):
    """Order id2label keys numerically when possible, textually otherwise."""
    try:
        return (0, int(label_id), "")
    except (TypeError, ValueError):
        return (1, 0, str(label_id))


def postprocess_onnx_output(onnx_output, model_config):
    """Map the probabilities in an ONNX inference result onto class names.

    Args:
        onnx_output: Mapping expected to hold a ``"probabilities"`` sequence.
        model_config: Mapping that may contain ``id2label`` (class id ->
            class name) and ``num_classes``. Without ``id2label`` the classes
            default to ``['Fake', 'Real']``.

    Returns:
        A dict of class name -> float probability. When ``"probabilities"``
        is missing, or its length does not match the class names, every class
        is mapped to ``0.0`` and a warning is logged.
    """
    logger = logging.getLogger(__name__)

    id2label = model_config.get('id2label')
    if id2label:
        # Config files loaded from JSON carry string ids; a plain sort would
        # order them "0", "1", "10", "2", ... and misalign names and scores.
        ordered_ids = sorted(id2label, key=_label_sort_key)
        class_names = [id2label[label_id] for label_id in ordered_ids]
    else:
        class_names = ['Fake', 'Real']

    probabilities = onnx_output.get("probabilities")
    if probabilities is None:
        logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
        return {name: 0.0 for name in class_names}

    prob_count = len(probabilities)
    single_output_model = model_config.get('num_classes') == 1

    # Models declared with one class may still yield a two-value vector; only
    # the first two class names are used. The length guard prevents an
    # IndexError when id2label supplies fewer than two names.
    if single_output_model and prob_count == 2 and len(class_names) >= 2:
        return {
            class_names[0]: float(probabilities[0]),
            class_names[1]: float(probabilities[1]),
        }

    if prob_count == len(class_names):
        return {name: float(prob) for name, prob in zip(class_names, probabilities)}

    logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
    return {name: 0.0 for name in class_names}
|
utils/utils.py
CHANGED
|
@@ -1,3 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def preprocess_resize_256(image):
|
| 2 |
if image.mode != 'RGB':
|
| 3 |
image = image.convert('RGB')
|
|
|
|
| 1 |
+
def infer_gradio_api(image_path):
    """Query the remote Gradio Space for a fake-probability score.

    Sends the file at ``image_path`` to the ``/simple_predict`` endpoint of
    the ``aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview`` Space
    and reads the ``'Fake Probability'`` entry of the response (``0.0`` when
    the entry is absent).

    Args:
        image_path: Path of the image file to upload.

    Returns:
        ``{"probabilities": <numpy array holding the single score>}``.
    """
    from gradio_client import Client, handle_file
    import numpy as np
    import logging

    log = logging.getLogger(__name__)

    remote_space = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
    uploaded_image = handle_file(image_path)
    result_dict = remote_space.predict(input_image=uploaded_image, api_name="/simple_predict")
    log.info(f"Debug: Raw result_dict from Gradio API (model_8): {result_dict}, type: {type(result_dict)}")

    fake_probability = result_dict.get('Fake Probability', 0.0)
    log.info(f"Debug: Parsed result_dict: {result_dict}, Extracted fake_probability: {fake_probability}")

    # Wrapped in a one-element array so the shape matches other model outputs.
    score_vector = np.array([fake_probability])
    return {"probabilities": score_vector}
|
| 15 |
+
|
| 16 |
+
def preprocess_gradio_api(image):
    """Save ``image`` to a unique temporary PNG file and return its path.

    The Gradio client uploads files by path, so the in-memory image has to be
    written to disk first. Each call gets its own file in the system temp
    directory; a single fixed path would let concurrent requests overwrite
    each other's input before it is uploaded.

    Args:
        image: Object exposing ``save(path)`` (a PIL image in practice).

    Returns:
        Path of the written ``.png`` file. The caller owns the file and is
        responsible for deleting it once it is no longer needed.

    Raises:
        Whatever ``image.save`` raises; the temp file is removed first.
    """
    import os
    import tempfile

    fd, temp_file_path = tempfile.mkstemp(prefix="gradio_input_", suffix=".png")
    # Only the reserved name is needed; the image writer reopens the path.
    os.close(fd)
    try:
        image.save(temp_file_path)
    except BaseException:
        # Do not leave an empty or partial file behind on failure.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    return temp_file_path
|
| 20 |
+
|
| 21 |
+
def postprocess_gradio_api(gradio_output, class_names):
    """Convert a Gradio API result into per-class scores.

    Args:
        gradio_output: Either a dict with a ``"probabilities"`` entry or a
            numpy array. The first value is taken as the score for
            ``class_names[0]``.
        class_names: Indexable holding at least two class names; index 0
            receives the extracted score and index 1 its complement.

    Returns:
        ``{class_names[0]: score, class_names[1]: 1.0 - score}``. When the
        input has an unexpected type, is empty, or does not contain a finite
        number, ``{class_names[0]: 0.0, class_names[1]: 1.0}`` is returned
        and a warning is logged.
    """
    import numpy as np
    import logging

    logger = logging.getLogger(__name__)

    def default_scores():
        return {class_names[0]: 0.0, class_names[1]: 1.0}

    if isinstance(gradio_output, dict) and "probabilities" in gradio_output:
        probabilities_array = gradio_output["probabilities"]
    elif isinstance(gradio_output, np.ndarray):
        probabilities_array = gradio_output
    else:
        logger.warning(f"Unexpected output type for Gradio API post-processing: {type(gradio_output)}. Expected dict with 'probabilities' or numpy.ndarray.")
        return default_scores()

    logger.info(f"Debug: Probabilities array entering postprocess_gradio_api: {probabilities_array}, type: {type(probabilities_array)}, shape: {getattr(probabilities_array, 'shape', None)}")

    if probabilities_array is None:
        logger.warning("Probabilities array is None or empty after extracting from Gradio API output. Returning default scores.")
        return default_scores()

    # Accept lists, tuples and scalars as well as arrays of any shape.
    try:
        flat_scores = np.asarray(probabilities_array, dtype=float).ravel()
    except (TypeError, ValueError):
        logger.warning("Probabilities from Gradio API output are not numeric. Returning default scores.")
        return default_scores()

    if flat_scores.size == 0:
        logger.warning("Probabilities array is None or empty after extracting from Gradio API output. Returning default scores.")
        return default_scores()

    if flat_scores.size > 1:
        # A single value is expected; extra values are ignored, not fatal.
        logger.warning(f"Expected a single probability from Gradio API output but got {flat_scores.size}; using the first value.")

    fake_prob = float(flat_scores[0])
    if not np.isfinite(fake_prob):
        logger.warning("Probability from Gradio API output is not a finite number. Returning default scores.")
        return default_scores()

    real_prob = 1.0 - fake_prob
    return {class_names[0]: fake_prob, class_names[1]: real_prob}
|
| 43 |
+
|
| 44 |
def preprocess_resize_256(image):
|
| 45 |
if image.mode != 'RGB':
|
| 46 |
image = image.convert('RGB')
|