| from typing import Dict, List, Any | |
| from tempfile import TemporaryDirectory | |
| from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
| from PIL import Image | |
| import torch | |
| import requests | |
class EndpointHandler:
    """Custom request handler for a Hugging Face Inference Endpoint.

    This is currently a stub: model loading is disabled and ``__call__`` only
    echoes the payload's ``inputs`` back. The intended LLaVA-NeXT
    image-description pipeline is kept below as a commented-out reference.
    """

    # Checkpoint used by the (disabled) reference implementation below.
    MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"

    def __init__(self, path: str = ""):
        """Create the handler.

        Args:
            path: Model directory. The inference toolkit constructs custom
                handlers with the model directory as a positional argument,
                so this parameter must be accepted even though the stub does
                not use it yet.
        """
        self.path = path

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one inference request.

        Stub behavior: returns ``data["inputs"]`` unchanged, or an error list
        when that key is missing or its value is falsy (e.g. ``""``, ``{}``).

        Args:
            data: Request payload; only the ``"inputs"`` key is read.

        Returns:
            The ``inputs`` value as-is, or ``[{"error": "No inputs provided"}]``.
        """
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No inputs provided"}]
        # Echo the payload unchanged; whatever is returned gets serialized
        # into the endpoint's response.
        return inputs

    # --- Disabled reference implementation ----------------------------------
    # Swap these in for the stub methods above to serve LLaVA-NeXT. Compared
    # with the earlier draft:
    #   * torch's GPU device name is "cuda", not "gpu";
    #   * the image is decoded from memory via io.BytesIO (add `import io`)
    #     instead of being written to a temporary file;
    #   * the download has a timeout so a stalled URL cannot hang the worker;
    #   * generate()'s output tensor is no longer truth-tested, since
    #     `bool(tensor)` raises for tensors with more than one element.
    #
    # def __init__(self, path: str = ""):
    #     self.path = path
    #     self.device = "cuda" if torch.cuda.is_available() else "cpu"
    #     self.processor = LlavaNextProcessor.from_pretrained(self.MODEL_ID)
    #     self.model = LlavaNextForConditionalGeneration.from_pretrained(
    #         self.MODEL_ID,
    #         torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
    #         low_cpu_mem_usage=True,
    #     ).to(self.device)
    #
    # def __call__(self, data: Dict[str, Any]) -> Any:
    #     inputs = data.get("inputs")
    #     if not inputs:
    #         return f"Inputs not in payload got {data}"
    #     prompt = inputs.get("prompt")
    #     image_url = inputs.get("image")
    #     if image_url is None:
    #         return "You need to upload an image URL for LLaVA to work."
    #     if prompt is None:
    #         prompt = (
    #             "Can you describe this picture focusing on specifics visual "
    #             "artifacts and ambiance (objects, colors, person, athmosphere..). "
    #             "Please stay concise only output keywords and concepts detected."
    #         )
    #     response = requests.get(image_url, timeout=30)
    #     if response.status_code != 200:
    #         return "Failed to download the image."
    #     image = Image.open(io.BytesIO(response.content)).convert("RGB")
    #     model_inputs = self.processor(
    #         f"[INST] <image>\n{prompt} [/INST]", image, return_tensors="pt"
    #     ).to(self.device)
    #     output = self.model.generate(**model_inputs, max_new_tokens=100)
    #     # NOTE(review): output[0] also contains the prompt tokens; slice off
    #     # model_inputs["input_ids"].shape[-1] first if only the answer is wanted.
    #     return self.processor.decode(output[0], skip_special_tokens=True)