| """ | |
| Copyright (c) 2024-present Naver Cloud Corp. | |
| This source code is based on code from the Segment Anything Model (SAM) | |
| (https://github.com/facebookresearch/segment-anything). | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import torch | |
| from typing import Any, Callable | |
| import onnxruntime | |
def np2tensor(np_array, device):
    """Wrap *np_array* in a torch tensor and move it to *device*.

    The array is converted with ``torch.from_numpy`` and the result is then
    passed through ``Tensor.to(device)``.
    """
    as_tensor = torch.from_numpy(np_array)
    return as_tensor.to(device)
def tensor2np(torch_tensor):
    """Return the contents of *torch_tensor* as a NumPy array.

    The tensor is first detached from the autograd graph and moved to the
    CPU, since ``Tensor.numpy()`` is called on the resulting tensor.
    """
    detached = torch_tensor.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()
class ZIM_Encoder:
    """Run an encoder ONNX model through an onnxruntime session.

    The session is created on the CPU execution provider; call ``cuda()``
    to switch it to the CUDA execution provider afterwards.
    """

    def __init__(self, onnx_path, num_threads=16):
        """Create the inference session for the model at *onnx_path*.

        *num_threads* is applied to both the intra-op and the inter-op
        thread settings of the session options.
        """
        self.onnx_path = onnx_path

        session_options = onnxruntime.SessionOptions()
        session_options.intra_op_num_threads = num_threads
        session_options.inter_op_num_threads = num_threads

        self.ort_session = onnxruntime.InferenceSession(
            onnx_path,
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )

    def cuda(self, device_id=0):
        """Re-register the session on the CUDA provider for *device_id*."""
        cuda_provider = ("CUDAExecutionProvider", {"device_id": device_id})
        self.ort_session.set_providers([cuda_provider])

    def forward(self, image):
        """Encode *image* and return ``(image_embeddings, (D0, D1, D2))``.

        The input tensor is converted to a NumPy array and fed to the
        session under the input name ``"image"``. The session must return
        exactly four outputs; each one is converted back to a torch tensor
        on the device *image* came from.
        """
        # Remember the caller's device so the outputs land on it again.
        target_device = image.device

        raw_outputs = self.ort_session.run(None, {"image": tensor2np(image)})
        embeddings_np, d0_np, d1_np, d2_np = raw_outputs

        image_embeddings = np2tensor(embeddings_np, target_device)
        features = tuple(
            np2tensor(feat, target_device) for feat in (d0_np, d1_np, d2_np)
        )
        return image_embeddings, features

    # Calling the instance directly is the same as calling forward().
    __call__: Callable[..., Any] = forward