| import json | |
| import os | |
| from typing import Dict, List, Any | |
| import torch | |
| from transformers import pipeline | |
# Chat prompt template: ``{inputs}`` is filled in with the request text via
# ``str.format``. The leading/trailing newlines and the space before
# ``<|end|>`` are part of the template and are kept exactly.
PROMPT_FORMAT = "\n".join((
    "",
    "<|user|>",
    "{inputs} <|end|>",
    "<|assistant|>",
    "",
))
class EndpointHandler():
    """Endpoint handler that serves a text-generation pipeline.

    The pipeline is built once in ``__init__`` and reused for every request.
    Each call wraps the request's ``inputs`` in ``PROMPT_FORMAT`` and runs
    the pipeline with sampling disabled.
    """

    # Model repository the pipeline is loaded from.
    MODEL_REPO = "MrOvkill/Phi-3-Instruct-Bloated"
    # Token budget used when a request does not supply ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 1024

    def __init__(self, data):
        """Build the text-generation pipeline.

        Args:
            data: Accepted to keep the constructor signature; it is not
                used, the model is always loaded from ``MODEL_REPO``.
        """
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model repository to run at load time -- confirm the repo is trusted.
        self.pipe = pipeline(
            "text-generation",
            self.MODEL_REPO,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is inserted into
                ``PROMPT_FORMAT``. The optional ``data["max_new_tokens"]``
                must be convertible with ``int()``; it defaults to
                ``DEFAULT_MAX_NEW_TOKENS``.

        Returns:
            Whatever ``self.pipe`` returns for the formatted prompt.
            If ``max_new_tokens`` cannot be converted to an int, a
            JSON-encoded string with ``status`` and ``reason`` keys is
            returned instead (kept as a string for backward compatibility,
            even though it does not match the return annotation).

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        # The pipeline from __init__ is reused here; rebuilding it per
        # request would reload the model on every call.
        requested = data.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
        try:
            max_new_tokens = int(requested)
        except (TypeError, ValueError, OverflowError):
            # Only the failures int() can raise are treated as bad input.
            return json.dumps({
                "status": "error",
                "reason": "max_new_tokens was passed as something that was absolutely not a plain old int"
            })

        prompt = PROMPT_FORMAT.format(inputs=data['inputs'])
        return self.pipe(
            prompt,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )
