from typing import Dict, List, Any
import onnxruntime as ort
import numpy as np
from PIL import Image
import io
import base64
import os


class EndpointHandler:
    def __init__(self, path=""):
        model_path = path if path else "."
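        # Prefer the CUDA execution provider when available, with CPU as fallback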
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        self.encoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_encoder.onnx"),
            providers=providers
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_decoder.onnx"),
            providers=providers
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        try:
            # Parse input
            inputs = data.get("inputs", data)
            params = data.get("parameters", {})

            # Load image
            if isinstance(inputs, str):
                image = Image.open(io.BytesIO(base64.b64decode(inputs)))
            else:
                image = inputs

            # Preprocess
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = image.resize((1024, 1024), Image.BILINEAR)
            img_array = np.array(image).astype(np.float32) / 255.0
            img_array = img_array.transpose(2, 0, 1)[np.newaxis, :]

            # Encode
            embeddings = self.encoder.run(None, {'image': img_array})[0]

            # Prepare prompts
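            # NOTE: point_coords are interpreted in the resized 1024x1024 image
            # space (the image is resized above), not the original resolution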
            coords = np.array(params.get("point_coords", [[512, 512]]), dtype=np.float32)
            labels = np.array(params.get("point_labels", [1]), dtype=np.float32)

            # Decode
            decoder_outputs = self.decoder.run(None, {
                'image_embeddings': embeddings,
                'point_coords': coords.reshape(1, -1, 2),
                'point_labels': labels.reshape(1, -1)
            })

            # decoder_outputs[0] is IoU scores (1, 4)
            # decoder_outputs[1] is masks (1, 4, 256, 256)
            masks = decoder_outputs[1]

            # Take first mask and resize to 1024x1024
            mask = masks[0, 0]  # Shape: (256, 256)
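            # (The decoder returns 4 candidate masks; index 0 is used here rather
            # than the mask with the highest predicted IoU score.)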
            mask = Image.fromarray(mask).resize((1024, 1024), Image.BILINEAR)
            mask = np.array(mask)
            mask = (mask > 0.0).astype(np.uint8) * 255

            # Return result
            result = {"mask_shape": list(mask.shape), "has_object": bool(mask.max() > 0)}

            if params.get("return_mask_image", True):
                buffer = io.BytesIO()
                Image.fromarray(mask, mode='L').save(buffer, format='PNG')
                result["mask"] = base64.b64encode(buffer.getvalue()).decode()

            return [result]

        except Exception as e:
            import traceback
            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc()
            }]
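

# ----------------------------------------------------------------------
# Local smoke test (not part of the deployed handler). A minimal sketch,
# assuming a hypothetical sample image "test.jpg" in the working directory
# and the two ONNX files named above sitting next to this script. The
# payload shape mirrors what the Inference Endpoint sends: a base64-encoded
# image under "inputs" and optional point prompts under "parameters".
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode()

    payload = {
        "inputs": encoded,
        "parameters": {
            "point_coords": [[512, 512]],  # in the resized 1024x1024 space
            "point_labels": [1],           # 1 = foreground point
            "return_mask_image": True,
        },
    }

    handler = EndpointHandler(path=".")
    output = handler(payload)[0]
    print("mask_shape:", output.get("mask_shape"))
    print("has_object:", output.get("has_object"))
    if "error" in output:
        print(output["traceback"])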