Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import dataclasses | |
| import enum | |
| import subprocess | |
| from typing import Callable, Optional, Sequence | |
| import numpy as np | |
| from pytriton.decorators import batch, sample | |
| from pytriton.model_config import ModelConfig, Tensor | |
class Framework(enum.Enum):
    """Deep-learning framework a catalogue model depends on.

    Specs for models that need no framework at all use ``None`` instead of a
    member of this enum.
    """

    TENSORFLOW = "tensorflow"  # value is the lowercase framework name
    PYTORCH = "pytorch"
@dataclasses.dataclass(frozen=True)
class TestModelSpec:
    """Declarative description of one model in the test catalogue.

    The dataclass decorator generates the keyword-argument constructor that the
    module-level ``*_MODEL`` specs rely on; without it ``TestModelSpec(name=...)``
    raises ``TypeError`` because the class would have no ``__init__`` accepting
    arguments. ``frozen=True`` prevents one consumer from mutating a shared
    module-level spec and affecting every other consumer.
    """

    # Model name, e.g. "AddSub"; presumably the name the model is bound under in Triton — confirm at call site.
    name: str
    # Framework the model needs, or None for models implemented with NumPy only.
    framework: Optional[Framework]
    # Zero-argument factory that prepares the model and returns its inference callable.
    create_infer_fn: Callable[..., Callable]
    # Input tensor specifications, in the order the inference callable consumes them.
    inputs: Sequence[Tensor]
    # Output tensor specifications.
    outputs: Sequence[Tensor]
    # PyTriton model configuration (e.g. max_batch_size / batching flag).
    model_config: ModelConfig
def _create_add_sub_fn() -> Callable:
    """Build the inference callable for the ``AddSub`` model.

    Returns:
        A PyTriton inference callable that, for two input batches, returns their
        element-wise sum under ``"add"`` and difference under ``"sub"``.
    """

    # PyTriton invokes an inference function with a list of requests; ``@batch``
    # merges them into one batched keyword argument per input, which is the
    # calling convention ``_add_sub(**inputs)`` is written for.
    @batch
    def _add_sub(**inputs):
        # The inputs are unnamed in the spec, so they are taken by position.
        a_batch, b_batch = inputs.values()
        add_batch = a_batch + b_batch
        sub_batch = a_batch - b_batch
        return {"add": add_batch, "sub": sub_batch}

    return _add_sub
# NumPy-only model: element-wise sum ("add") and difference ("sub") of two float32 vectors.
ADD_SUB_PYTHON_MODEL = TestModelSpec(
    name="AddSub",
    framework=None,
    model_config=ModelConfig(max_batch_size=128),
    create_infer_fn=_create_add_sub_fn,
    inputs=(
        Tensor(shape=(-1,), dtype=np.float32),
        Tensor(shape=(-1,), dtype=np.float32),
    ),
    outputs=(
        Tensor(name="add", shape=(-1,), dtype=np.float32),
        Tensor(name="sub", shape=(-1,), dtype=np.float32),
    ),
)
def _create_identity_fn() -> Callable:
    """Build the inference callable for the ``Identity`` model.

    Returns:
        A PyTriton inference callable that returns its single input batch
        unchanged under the ``"identity"`` output.
    """

    # ``@batch`` converts PyTriton's list of requests into keyword arguments,
    # matching the ``**inputs`` signature below.
    @batch
    def _identity(**inputs):
        # Exactly one input is expected; unpacking enforces that.
        (a_batch,) = inputs.values()
        return {"identity": a_batch}

    return _identity
# NumPy-only model that echoes its single float32 input back as "identity".
IDENTITY_PYTHON_MODEL = TestModelSpec(
    name="Identity",
    framework=None,
    model_config=ModelConfig(max_batch_size=128),
    create_infer_fn=_create_identity_fn,
    inputs=(
        Tensor(shape=(-1,), dtype=np.float32),
    ),
    outputs=(
        Tensor(name="identity", shape=(-1,), dtype=np.float32),
    ),
)
def _create_tfhub_tensorflow_efficientdet_lite0_detection_fn() -> Callable:
    """Install ``tensorflow_hub``, load EfficientDet-Lite0 and wrap it for PyTriton.

    Side effects: runs ``pip install tensorflow_hub`` and loads the detector
    from TF Hub every time it is called.

    Returns:
        A batched PyTriton inference callable taking ``image`` and returning
        ``boxes``, ``scores``, ``classes`` and ``num_detections`` as NumPy arrays.

    Raises:
        subprocess.CalledProcessError: if the pip installation fails.
    """
    import sys

    # Run pip through the current interpreter so the package is installed into
    # the same environment that performs the import below; a bare "pip" found
    # on PATH may belong to a different Python installation.
    subprocess.run([sys.executable, "-m", "pip", "install", "tensorflow_hub"], check=True)
    import tensorflow_hub as hub  # pytype: disable=import-error

    detector = hub.load("https://tfhub.dev/tensorflow/efficientdet/lite0/detection/1")

    # ``@batch`` turns PyTriton's list of requests into a single ``image`` batch
    # and splits the returned dict back into per-request responses.
    @batch
    def _tfhub_tensorflow_efficientdet_lite0_detection(image: np.ndarray):
        boxes, scores, classes, num_detections = detector(image)
        # Convert the framework tensors returned by the detector to NumPy arrays.
        return {
            "boxes": boxes.numpy(),
            "scores": scores.numpy(),
            "classes": classes.numpy(),
            "num_detections": num_detections.numpy(),
        }

    return _tfhub_tensorflow_efficientdet_lite0_detection
# TensorFlow EfficientDet-Lite0 detector from TF Hub; input is a uint8 array of shape (-1, -1, 3).
EFFICIENTDET_DETECTION_TF_MODEL = TestModelSpec(
    name="EfficientDetDetector",
    framework=Framework.TENSORFLOW,
    model_config=ModelConfig(max_batch_size=128),
    create_infer_fn=_create_tfhub_tensorflow_efficientdet_lite0_detection_fn,
    inputs=(Tensor(name="image", shape=(-1, -1, 3), dtype=np.uint8),),
    outputs=(
        Tensor(name="boxes", shape=(-1, 4), dtype=np.float32),
        Tensor(name="scores", shape=(-1,), dtype=np.float32),
        Tensor(name="classes", shape=(-1,), dtype=np.int32),
        Tensor(name="num_detections", shape=(1,), dtype=np.int32),
    ),
)
def _create_hf_pytorch_bart_fn() -> Callable:
    """Install ``transformers`` and wrap a BART-MNLI zero-shot classifier for PyTriton.

    Side effects: runs ``pip install transformers`` and builds a
    ``zero-shot-classification`` pipeline on device 0 each time it is called.

    Returns:
        A non-batched PyTriton inference callable taking ``sequence`` (the text
        to classify, first element used) and ``labels`` (candidate labels) as
        UTF-8 encoded object arrays, and returning ``scores`` where
        ``scores[i]`` is the score of ``labels[i]``.

    Raises:
        subprocess.CalledProcessError: if the pip installation fails.
    """
    import sys

    # Run pip through the current interpreter so the package is installed into
    # the environment that performs the import below.
    subprocess.run([sys.executable, "-m", "pip", "install", "transformers"], check=True)
    import transformers  # pytype: disable=import-error

    classifier = transformers.pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0)

    # The model spec disables batching, so PyTriton delivers single requests;
    # ``@sample`` unpacks one request into the keyword arguments below.
    @sample
    def _hf_pytorch_bart_fn(sequence: np.ndarray, labels: np.ndarray):
        sequence = sequence[0].decode("utf-8")
        labels = [label.decode("utf-8") for label in labels]
        classification_result = classifier(sequence, labels)
        # The pipeline reports labels sorted by descending score, not in the
        # order they were passed in. Re-align the scores with the caller's
        # label order so the returned array can be matched back to the labels.
        score_by_label = dict(zip(classification_result["labels"], classification_result["scores"]))
        scores_batch = np.array([score_by_label[label] for label in labels], dtype=np.float32)
        return {"scores": scores_batch}

    return _hf_pytorch_bart_fn
# PyTorch zero-shot text classifier; batching is disabled, so each request is served on its own.
BART_CLASSIFIER_PYTORCH_MODEL = TestModelSpec(
    name="BARTClassifier",
    framework=Framework.PYTORCH,
    model_config=ModelConfig(batching=False),
    create_infer_fn=_create_hf_pytorch_bart_fn,
    inputs=(
        Tensor(name="sequence", shape=(-1,), dtype=object),
        Tensor(name="labels", shape=(-1,), dtype=object),
    ),
    outputs=(Tensor(name="scores", shape=(-1,), dtype=np.float32),),
)
def _create_hf_jax_bert_fn() -> Callable:
    """Install JAX/Flax/transformers and wrap a Flax ``bert-base-uncased`` encoder for PyTriton.

    Side effects: runs pip to upgrade ``jax[cuda12_pip]`` (from the JAX CUDA
    release index) and to install ``transformers`` and ``flax``, then loads
    the BERT tokenizer and model each time it is called.

    Returns:
        A batched PyTriton inference callable taking one UTF-8 encoded string
        per batch row and returning the model's ``last_hidden_state`` per row.

    Raises:
        subprocess.CalledProcessError: if either pip invocation fails.
    """
    import sys

    # Run pip through the current interpreter so packages are installed into
    # the environment that performs the import below.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--upgrade",
            "jax[cuda12_pip]",
            "-f",
            "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
        ],
        check=True,
    )
    subprocess.run([sys.executable, "-m", "pip", "install", "transformers", "flax"], check=True)
    import transformers  # pytype: disable=import-error

    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    model = transformers.FlaxBertModel.from_pretrained("bert-base-uncased")

    # ``@batch`` converts PyTriton's list of requests into batched keyword
    # arguments and maps the returned list back onto the declared outputs.
    @batch
    def _infer_fn(**inputs: np.ndarray):
        (sequence_batch,) = inputs.values()
        sequence_batch = sequence_batch.tolist()
        last_hidden_states = []
        for sequence_item in sequence_batch:
            # [0] extracts the single bytes element of this row, which is then decoded as UTF-8.
            sequence_as_str = sequence_item[0].decode("utf-8")
            tokenized_sequence = tokenizer(sequence_as_str, return_tensors="jax")
            results = model(**tokenized_sequence)
            last_hidden_states.append(results.last_hidden_state)
        # NOTE(review): stacking only works when every row tokenizes to the same
        # length; ragged batches presumably fail here. The stacked rank also looks
        # larger than the declared (-1,) output shape — confirm.
        last_hidden_states = np.array(last_hidden_states, dtype=np.float32)
        return [last_hidden_states]

    return _infer_fn
# Flax/JAX BERT encoder returning the last hidden state for each input string.
# NOTE(review): tagged Framework.TENSORFLOW although the model runs on JAX
# (Framework has no JAX member) — confirm this grouping is intended.
BERT_JAX_MODEL = TestModelSpec(
    name="BERTJax",
    framework=Framework.TENSORFLOW,
    model_config=ModelConfig(max_batch_size=128),
    create_infer_fn=_create_hf_jax_bert_fn,
    inputs=(Tensor(name="sequence", shape=(1,), dtype=object),),
    outputs=(Tensor(name="last_hidden_state", shape=(-1,), dtype=np.float32),),
)
# All model specs defined in this module, in declaration order.
MODELS_CATALOGUE = [
    # models implemented with NumPy only
    ADD_SUB_PYTHON_MODEL,
    IDENTITY_PYTHON_MODEL,
    # framework-backed models; their create_infer_fn installs dependencies on demand
    EFFICIENTDET_DETECTION_TF_MODEL,
    BART_CLASSIFIER_PYTORCH_MODEL,
    BERT_JAX_MODEL,
]