|
|
from typing import Dict, List, Any |
|
|
import onnxruntime as ort |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
import os |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint around an EdgeSAM ONNX encoder/decoder pair.

    The handler loads ``edge_sam_3x_encoder.onnx`` and
    ``edge_sam_3x_decoder.onnx`` from the model directory once, and on each
    call turns an image plus point prompts into a binary segmentation mask.
    """

    # Side length (pixels) of the square image fed to the encoder; the
    # returned mask is produced at this same resolution.
    INPUT_SIZE = 1024

    def __init__(self, path=""):
        """Create the encoder and decoder inference sessions.

        Args:
            path: Directory holding the two ``.onnx`` files. An empty value
                means the current working directory.
        """
        model_path = path if path else "."

        # Request only providers this onnxruntime build actually offers, so a
        # CPU-only install does not ask for CUDA.
        preferred = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        available = set(ort.get_available_providers())
        providers = [name for name in preferred if name in available]
        if not providers:
            providers = ['CPUExecutionProvider']

        self.encoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_encoder.onnx"),
            providers=providers,
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_decoder.onnx"),
            providers=providers,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment the object indicated by the point prompts.

        Args:
            data: Request payload. ``data["inputs"]`` is the image: a base64
                string (optionally a ``data:`` URI), raw encoded image bytes,
                or a PIL image. When the key is absent, ``data`` itself is
                used as the input. ``data["parameters"]`` may contain:

                * ``point_coords``: one or more ``[x, y]`` pairs, passed to
                  the decoder unchanged (default ``[[512, 512]]``, the centre
                  of the resized image).
                * ``point_labels``: one label per point; a single label is
                  applied to every point (default: ``1`` for each point).
                * ``return_mask_image``: include the mask as a base64 PNG
                  under ``"mask"`` (default ``True``).

        Returns:
            A one-element list. On success the element holds ``mask_shape``,
            ``has_object`` and optionally ``mask``. On any failure it holds
            ``error``, ``type`` and ``traceback`` instead; exceptions are
            never propagated to the caller.
        """
        try:
            inputs = data.get("inputs", data)
            # ``"parameters": null`` in the payload must behave like "absent".
            params = data.get("parameters") or {}

            # Validate the cheap prompt arguments before decoding the image
            # and running the encoder.
            coords, labels = self._parse_points(params)

            image = self._load_image(inputs)
            img_array = self._preprocess(image)

            embeddings = self.encoder.run(None, {'image': img_array})[0]

            decoder_outputs = self.decoder.run(None, {
                'image_embeddings': embeddings,
                'point_coords': coords,
                'point_labels': labels,
            })

            # The second decoder output is treated as the mask tensor; the
            # first mask of the first batch entry is the one returned.
            masks = decoder_outputs[1]
            mask = self._postprocess_mask(masks[0, 0])

            result = {
                "mask_shape": list(mask.shape),
                "has_object": bool(mask.max() > 0),
            }

            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                # A 2-D uint8 array is inferred as mode 'L'; the explicit
                # ``mode`` argument is deprecated in recent Pillow releases.
                Image.fromarray(mask).save(buffer, format='PNG')
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()

            return [result]

        except Exception as e:
            import traceback

            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }]

    @staticmethod
    def _parse_points(params):
        """Return ``(coords, labels)`` shaped ``(1, N, 2)`` and ``(1, N)``."""
        raw_coords = params.get("point_coords")
        if raw_coords is None:
            raw_coords = [[512, 512]]
        raw_labels = params.get("point_labels")

        coords = np.asarray(raw_coords, dtype=np.float32)
        if coords.size == 0 or coords.size % 2 != 0:
            raise ValueError(
                "point_coords must contain one or more (x, y) pairs"
            )
        coords = coords.reshape(1, -1, 2)
        num_points = coords.shape[1]

        if raw_labels is None:
            # No labels supplied: give every point the default label 1.
            labels = np.ones(num_points, dtype=np.float32)
        else:
            labels = np.asarray(raw_labels, dtype=np.float32).reshape(-1)
            if labels.size == 1:
                # A single label applies to all points.
                labels = np.repeat(labels, num_points)
            elif labels.size != num_points:
                raise ValueError(
                    f"point_labels has {labels.size} entries but "
                    f"point_coords has {num_points} points"
                )

        return coords, labels.reshape(1, -1)

    @staticmethod
    def _load_image(inputs):
        """Turn a base64 string, encoded bytes or PIL image into a PIL image."""
        if isinstance(inputs, str):
            # Strip a "data:image/...;base64," prefix if one is present.
            if inputs.startswith("data:") and "," in inputs:
                inputs = inputs.split(",", 1)[1]
            inputs = base64.b64decode(inputs)

        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(inputs))
        if isinstance(inputs, Image.Image):
            return inputs

        raise TypeError(
            "inputs must be a base64 string, image bytes or a PIL image, "
            f"got {type(inputs).__name__}"
        )

    @classmethod
    def _preprocess(cls, image):
        """Convert a PIL image to a float32 ``(1, 3, S, S)`` array in [0, 1]."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize((cls.INPUT_SIZE, cls.INPUT_SIZE), Image.BILINEAR)

        img_array = np.asarray(image, dtype=np.float32) / 255.0
        # HWC -> CHW, then add the batch axis.
        img_array = img_array.transpose(2, 0, 1)[np.newaxis, :]
        return np.ascontiguousarray(img_array)

    @classmethod
    def _postprocess_mask(cls, mask):
        """Resize raw mask values to ``S x S`` and binarise them to 0/255."""
        # Cast first so Pillow can represent the array (mode 'F').
        mask = np.asarray(mask, dtype=np.float32)
        resized = Image.fromarray(mask).resize(
            (cls.INPUT_SIZE, cls.INPUT_SIZE), Image.BILINEAR
        )
        # Values above 0 are foreground.
        return (np.array(resized) > 0.0).astype(np.uint8) * 255
|
|
|