"""Request handler that runs point-prompted segmentation with ONNX models."""

import base64
import io
import os
import traceback
from typing import Any, Dict, List, Tuple

import numpy as np
import onnxruntime as ort
from PIL import Image


class EndpointHandler:
    """Segment an image from point prompts using an ONNX encoder/decoder pair.

    The handler loads ``edge_sam_3x_encoder.onnx`` and
    ``edge_sam_3x_decoder.onnx`` from the model directory once at
    construction time and reuses both sessions for every call.
    """

    # Side length, in pixels, of the square the input image is resized to
    # before encoding and of the binary mask that is returned.
    _INPUT_SIZE = 1024

    def __init__(self, path=""):
        """Create the encoder and decoder inference sessions.

        Args:
            path: Directory containing the two ``.onnx`` files. An empty
                value falls back to the current working directory.
        """
        model_dir = path or "."
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

        self.encoder = ort.InferenceSession(
            os.path.join(model_dir, "edge_sam_3x_encoder.onnx"),
            providers=providers,
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_dir, "edge_sam_3x_decoder.onnx"),
            providers=providers,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run segmentation for one request.

        Args:
            data: Request payload. ``data["inputs"]`` (or ``data`` itself when
                the key is absent) carries the image as a base64 string, an
                optional ``data:`` URL, raw encoded bytes, or a PIL image.
                ``data["parameters"]`` may contain:

                * ``point_coords``: list of ``[x, y]`` pairs; defaults to the
                  centre of the resized image.
                * ``point_labels``: one label per point; defaults to ``1``
                  for every point.
                * ``return_mask_image``: when true (default) the mask is
                  included as a base64-encoded PNG under ``"mask"``.

        Returns:
            A single-element list. On success the element holds
            ``mask_shape``, ``has_object`` and optionally ``mask``. On any
            failure it holds ``error``, ``type`` and ``traceback`` instead;
            no exception propagates to the caller.
        """
        try:
            inputs = data.get("inputs", data)
            # `or {}` also covers an explicit `"parameters": null` payload.
            params = data.get("parameters") or {}

            image = self._load_image(inputs)
            embeddings = self.encoder.run(
                None, {"image": self._preprocess(image)}
            )[0]

            coords, labels = self._build_prompts(params)
            decoder_outputs = self.decoder.run(
                None,
                {
                    "image_embeddings": embeddings,
                    "point_coords": coords,
                    "point_labels": labels,
                },
            )

            # Per the original author's note: output 0 holds IoU scores
            # (1, 4) and output 1 holds masks (1, 4, 256, 256).
            # NOTE(review): shapes are not verifiable from this file.
            mask = self._postprocess(decoder_outputs[1])

            result = {
                "mask_shape": list(mask.shape),
                "has_object": bool(mask.max() > 0),
            }
            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                # A 2-D uint8 array is inferred as mode "L"; the explicit
                # `mode` argument of `fromarray` is deprecated in Pillow.
                Image.fromarray(mask).save(buffer, format="PNG")
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()
            return [result]
        except Exception as e:
            # Top-level boundary: report the failure in the response body
            # rather than raising.
            # NOTE(review): the traceback is returned to the caller; kept for
            # backward compatibility, but consider logging it server-side
            # instead if clients are untrusted.
            return [
                {
                    "error": str(e),
                    "type": type(e).__name__,
                    "traceback": traceback.format_exc(),
                }
            ]

    @staticmethod
    def _load_image(inputs: Any) -> "Image.Image":
        """Turn the request input into a PIL image.

        Raises:
            TypeError: If ``inputs`` is not a string, bytes-like object or
                PIL image.
        """
        if isinstance(inputs, Image.Image):
            return inputs

        if isinstance(inputs, str):
            payload = inputs.strip()
            if payload.startswith("data:"):
                # Drop a "data:image/png;base64," style prefix. The lenient
                # base64 decoder would otherwise treat its letters and "/"
                # as payload characters and corrupt the image bytes.
                payload = payload.partition(",")[2]
            raw = base64.b64decode(payload)
        elif isinstance(inputs, (bytes, bytearray)):
            raw = bytes(inputs)
        else:
            raise TypeError(
                "Unsupported input type "
                f"{type(inputs).__name__!r}: expected a base64 string, "
                "raw image bytes or a PIL image"
            )
        return Image.open(io.BytesIO(raw))

    def _preprocess(self, image: "Image.Image") -> np.ndarray:
        """Convert an image to the float32 array fed to the encoder.

        The image is converted to RGB, resized to a fixed square (aspect
        ratio is not preserved), scaled to ``[0, 1]`` and laid out as
        ``(1, 3, size, size)``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        size = self._INPUT_SIZE
        image = image.resize((size, size), Image.BILINEAR)
        pixels = np.array(image).astype(np.float32) / 255.0
        return pixels.transpose(2, 0, 1)[np.newaxis, :]

    def _build_prompts(
        self, params: Dict[str, Any]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Build the point-prompt arrays for the decoder.

        Returns:
            ``(coords, labels)`` as float32 arrays shaped ``(1, N, 2)`` and
            ``(1, N)``.

        Raises:
            ValueError: If the number of labels differs from the number of
                points, or the coordinates cannot be arranged in pairs.
        """
        centre = self._INPUT_SIZE // 2
        coords = np.array(
            params.get("point_coords", [[centre, centre]]), dtype=np.float32
        ).reshape(1, -1, 2)

        raw_labels = params.get("point_labels")
        if raw_labels is None:
            # One foreground label per point, so several points can be sent
            # without spelling out the labels.
            labels = np.ones((1, coords.shape[1]), dtype=np.float32)
        else:
            labels = np.array(raw_labels, dtype=np.float32).reshape(1, -1)

        if coords.shape[1] != labels.shape[1]:
            raise ValueError(
                f"point_coords has {coords.shape[1]} point(s) but "
                f"point_labels has {labels.shape[1]} label(s)"
            )
        return coords, labels

    def _postprocess(self, masks: np.ndarray) -> np.ndarray:
        """Upscale the first predicted mask and binarise it to 0/255 uint8."""
        size = self._INPUT_SIZE
        low_res = masks[0, 0]
        upscaled = Image.fromarray(low_res).resize((size, size), Image.BILINEAR)
        # Positive values are foreground.
        return (np.array(upscaled) > 0.0).astype(np.uint8) * 255