File size: 2,906 Bytes
594e70f d5f470e 594e70f d5f470e 594e70f 2e11ff8 d5f470e 2e11ff8 d5f470e 594e70f decc98d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import Dict, List, Any
import onnxruntime as ort
import numpy as np
from PIL import Image
import io
import base64
import os
class EndpointHandler:
    """Point-prompted segmentation handler backed by two ONNX sessions.

    The handler loads ``edge_sam_3x_encoder.onnx`` and
    ``edge_sam_3x_decoder.onnx`` from the model directory, encodes the input
    image once and decodes a single mask for the supplied point prompts.
    """

    # Side length the image is resized to before encoding, and the size the
    # decoded mask is upsampled to before thresholding.
    INPUT_SIZE = 1024

    def __init__(self, path=""):
        """Create the encoder and decoder inference sessions.

        Args:
            path: Directory containing the two ONNX files. An empty value
                falls back to the current working directory.
        """
        model_path = path if path else "."
        preferred = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        # Only request providers this onnxruntime build actually offers; keep
        # the original preference list if the intersection is somehow empty.
        available = set(ort.get_available_providers())
        providers = [name for name in preferred if name in available] or preferred
        self.encoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_encoder.onnx"),
            providers=providers
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_decoder.onnx"),
            providers=providers
        )

    @staticmethod
    def _prepare_prompts(params):
        """Build the decoder prompt arrays from the request parameters.

        Args:
            params: Mapping that may hold ``point_coords`` (list of ``[x, y]``
                pairs) and ``point_labels`` (one label per point).

        Returns:
            Tuple ``(coords, labels)`` of float32 arrays shaped ``(1, N, 2)``
            and ``(1, N)``.

        Raises:
            ValueError: If the coordinates are empty, are not ``(x, y)``
                pairs, or the label count differs from the point count.
        """
        raw_coords = params.get("point_coords")
        if raw_coords is None:
            raw_coords = [[512, 512]]
        coords = np.asarray(raw_coords, dtype=np.float32)
        if coords.size == 0 or coords.size % 2 != 0:
            raise ValueError(
                "point_coords must be a non-empty list of [x, y] pairs"
            )
        coords = coords.reshape(1, -1, 2)
        num_points = coords.shape[1]

        raw_labels = params.get("point_labels")
        if raw_labels is None:
            # Same default label (1) as the single-point default, extended to
            # every supplied point so the shapes always agree.
            labels = np.ones((1, num_points), dtype=np.float32)
        else:
            labels = np.asarray(raw_labels, dtype=np.float32).reshape(1, -1)
        if labels.shape[1] != num_points:
            raise ValueError(
                f"point_labels has {labels.shape[1]} entries but "
                f"point_coords has {num_points} points"
            )
        return coords, labels

    @staticmethod
    def _load_image(inputs):
        """Turn the request payload into a PIL image.

        Args:
            inputs: Either a base64 string (optionally carrying a
                ``data:...;base64,`` prefix) or an already decoded PIL image.

        Returns:
            The PIL image.

        Raises:
            TypeError: If ``inputs`` is neither a string nor a PIL image.
        """
        if isinstance(inputs, str):
            encoded = inputs
            # A data-URI header is made of valid base64 alphabet characters,
            # so it must be removed or it decodes into garbage bytes.
            if encoded.startswith("data:") and "," in encoded:
                encoded = encoded.split(",", 1)[1]
            raw = base64.b64decode(encoded)
            return Image.open(io.BytesIO(raw))
        if isinstance(inputs, Image.Image):
            return inputs
        raise TypeError(
            "inputs must be a base64-encoded image string or a PIL image, "
            f"got {type(inputs).__name__}"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment the object indicated by the point prompts.

        Args:
            data: Request payload. ``inputs`` holds the image (the payload
                itself is used when the key is absent); ``parameters`` may
                hold ``point_coords``, ``point_labels`` and
                ``return_mask_image``.

        Returns:
            A one-element list. On success the dict has ``mask_shape``,
            ``has_object`` and, unless ``return_mask_image`` is falsy,
            ``mask`` (base64-encoded PNG). On failure it has ``error``,
            ``type`` and ``traceback``; exceptions are not propagated.
        """
        try:
            # Parse input; an explicit null for "parameters" means "none".
            inputs = data.get("inputs", data)
            params = data.get("parameters") or {}

            # Validate the cheap prompt inputs before running the encoder.
            point_coords, point_labels = self._prepare_prompts(params)

            # Load and preprocess the image into a (1, 3, H, W) float32 array
            # with values scaled to [0, 1].
            image = self._load_image(inputs)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            size = (self.INPUT_SIZE, self.INPUT_SIZE)
            image = image.resize(size, Image.BILINEAR)
            img_array = np.array(image).astype(np.float32) / 255.0
            img_array = img_array.transpose(2, 0, 1)[np.newaxis, :]

            # Encode
            embeddings = self.encoder.run(None, {'image': img_array})[0]

            # Decode
            decoder_outputs = self.decoder.run(None, {
                'image_embeddings': embeddings,
                'point_coords': point_coords,
                'point_labels': point_labels
            })

            # Per the original author's notes: output 0 holds IoU scores
            # (1, 4) and output 1 holds masks (1, 4, 256, 256).
            masks = decoder_outputs[1]

            # Take the first mask. Cast to float32 so PIL can wrap it even if
            # the decoder emits another float width.
            mask = np.asarray(masks[0, 0], dtype=np.float32)
            mask = Image.fromarray(mask).resize(size, Image.BILINEAR)
            mask = np.array(mask)
            mask = (mask > 0.0).astype(np.uint8) * 255

            # Return result
            result = {
                "mask_shape": list(mask.shape),
                "has_object": bool(mask.max() > 0),
            }
            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                # A 2-D uint8 array is wrapped as mode 'L' automatically.
                Image.fromarray(mask).save(buffer, format='PNG')
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()
            return [result]
        except Exception as e:
            import traceback
            # NOTE(review): the traceback is returned to the caller, which
            # exposes server internals. Kept for backward compatibility;
            # consider logging it server-side instead.
            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }]
|