EdgeSAM / handler.py
JingShiang Yang
Fix mask output: use decoder output[1] and resize to 1024x1024
2e11ff8
from typing import Dict, List, Any
import onnxruntime as ort
import numpy as np
from PIL import Image
import io
import base64
import os
class EndpointHandler:
    """Inference Endpoints handler that runs EdgeSAM-3x as two ONNX models.

    The image encoder turns a 1024x1024 RGB image into embeddings; the decoder
    turns those embeddings plus point prompts into mask logits, which are
    upsampled and thresholded into a binary mask.
    """

    # Per-channel RGB mean/std on the 0-255 scale (standard SAM statistics).
    # NOTE(review): assumes the exported encoder graph does not normalize its
    # input itself (as with EdgeSAM's reference ONNX predictor) — confirm if
    # the encoder is re-exported differently.
    PIXEL_MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    PIXEL_STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    # Side length of the square image fed to the encoder and of the returned mask.
    INPUT_SIZE = 1024

    def __init__(self, path=""):
        """Create ONNX Runtime sessions for the encoder and decoder.

        Args:
            path: Directory containing ``edge_sam_3x_encoder.onnx`` and
                ``edge_sam_3x_decoder.onnx``. Empty means the current directory.
        """
        model_path = path if path else "."
        # Prefer CUDA, but only request providers this onnxruntime build
        # actually offers; CPU is the guaranteed fallback.
        available = set(ort.get_available_providers())
        providers = [
            p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
            if p in available
        ] or ["CPUExecutionProvider"]
        self.encoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_encoder.onnx"),
            providers=providers,
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_decoder.onnx"),
            providers=providers,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment one image using point prompts.

        Request fields:
            inputs: base64-encoded image string (a ``data:...;base64,`` URL is
                also accepted), raw image bytes, or a PIL image. If the key is
                missing, ``data`` itself is used.
            parameters (optional dict):
                point_coords: list of ``[x, y]`` points in the 1024x1024
                    model-input frame (the image is stretched to that size);
                    default ``[[512, 512]]``.
                point_labels: one label per point, passed to the decoder as
                    float32 (SAM-style labels); default ``[1]``.
                return_mask_image: include a base64 PNG of the mask; default True.

        Returns:
            A one-element list. On success the dict holds ``mask_shape``,
            ``has_object`` and, optionally, ``mask`` (base64 PNG with values
            0/255). On any failure it holds ``error``, ``type`` and
            ``traceback`` instead of raising.
        """
        try:
            inputs = data.get("inputs", data)
            # Treat an explicit "parameters": null like a missing key.
            params = data.get("parameters") or {}

            image = self._load_image(inputs)
            embeddings = self.encoder.run(
                None, {"image": self._preprocess(image)}
            )[0]

            coords = np.asarray(
                params.get("point_coords", [[512, 512]]), dtype=np.float32
            )
            labels = np.asarray(params.get("point_labels", [1]), dtype=np.float32)
            decoder_outputs = self.decoder.run(None, {
                "image_embeddings": embeddings,
                "point_coords": coords.reshape(1, -1, 2),
                "point_labels": labels.reshape(1, -1),
            })
            # decoder_outputs[0] is IoU scores (1, 4);
            # decoder_outputs[1] is mask logits (1, 4, 256, 256).
            # Only the first of the four masks is used.
            mask = self._postprocess(decoder_outputs[1][0, 0])

            result = {
                "mask_shape": list(mask.shape),
                "has_object": bool(mask.max() > 0),
            }
            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                # A 2-D uint8 array maps to PIL mode "L" automatically.
                Image.fromarray(mask).save(buffer, format="PNG")
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()
            return [result]
        except Exception as e:
            # Endpoint boundary: report the failure to the client rather than raising.
            import traceback
            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }]

    @staticmethod
    def _load_image(inputs):
        """Return a PIL image from a base64 string, raw bytes, or a PIL image."""
        if isinstance(inputs, str):
            # Strip a data-URL prefix such as "data:image/png;base64,".
            if inputs.startswith("data:") and "," in inputs:
                inputs = inputs.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(inputs)))
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(inputs))
        # Anything else is assumed to already be a PIL image.
        return inputs

    def _preprocess(self, image):
        """Stretch ``image`` to INPUT_SIZE^2 RGB and return the encoder tensor."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize((self.INPUT_SIZE, self.INPUT_SIZE), Image.BILINEAR)
        return self._normalize(np.asarray(image))

    @classmethod
    def _normalize(cls, pixels):
        """Map an (H, W, 3) RGB array on the 0-255 scale to a (1, 3, H, W) float32 tensor.

        Each channel is shifted by PIXEL_MEAN and divided by PIXEL_STD.
        """
        pixels = np.asarray(pixels, dtype=np.float32)
        normalized = (pixels - cls.PIXEL_MEAN) / cls.PIXEL_STD
        # HWC -> NCHW, contiguous for ONNX Runtime.
        return np.ascontiguousarray(
            normalized.transpose(2, 0, 1)[np.newaxis], dtype=np.float32
        )

    def _postprocess(self, low_res_logits):
        """Upsample 2-D mask logits to INPUT_SIZE^2 and binarize to a 0/255 uint8 mask."""
        logits = np.asarray(low_res_logits, dtype=np.float32)
        # Interpolate the logits first, then threshold at 0.
        upsampled = Image.fromarray(logits).resize(
            (self.INPUT_SIZE, self.INPUT_SIZE), Image.BILINEAR
        )
        return (np.asarray(upsampled) > 0.0).astype(np.uint8) * 255