from typing import Dict, Any

import base64
import io
import logging
import shutil
import subprocess
import traceback  # For formatting exception tracebacks

import requests
import torch
from moviepy.editor import VideoFileClip
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor


class EndpointHandler:
    """
    Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.
    This handler processes text, image, and video inputs, leveraging the Qwen2-VL model
    for multimodal understanding and generation. It includes a runtime workaround to
    install FFmpeg if it's not available in the environment.
    """
    def __init__(self, path=""):
        """
        Initializes the handler, installs FFmpeg, and loads the Qwen2-VL model.
        Args:
            path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
        """
        self.model_dir = path
        # Install FFmpeg at runtime if it is not already available
        # (this runs once, during container initialization).
        if shutil.which("ffmpeg") is None:
            try:
                subprocess.run(["apt-get", "update"], check=True)
                subprocess.run(["apt-get", "install", "-y", "ffmpeg"], check=True)
                logging.info("FFmpeg installed successfully.")
            except subprocess.CalledProcessError as e:
                logging.error(f"Error installing FFmpeg: {e}")
        # Load the Qwen2-VL model
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes the input data and returns the Qwen2-VL model's output.
        Args:
            data (Dict[str, Any]): A dictionary containing the input data.
                - "inputs" (str): The input text, including image/video references.
                - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).
        Returns:
            Dict[str, Any]: A dictionary containing the generated text.
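        Example payload (the URL is illustrative, not a real asset):
            {"inputs": "Describe this <image>https://example.com/cat.jpg<image>",
             "max_new_tokens": 64}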
        """
        inputs = data.get("inputs")
        max_new_tokens = data.get("max_new_tokens", 128)
        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(inputs)}]
        # Prepare for inference (using qwen_vl_utils)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        logging.debug(f"Image inputs: {image_inputs}")
        logging.debug(f"Video inputs: {video_inputs}")
        # Use a separate name so the raw "inputs" string above is not shadowed
        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to("cuda" if torch.cuda.is_available() else "cpu")
        # Inference
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # Trim the prompt tokens so only newly generated text is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return {"generated_text": output_text}
    def _parse_input(self, input_string):
        """
        Parses the input string to identify image/video references and text.
        Args:
            input_string (str): The input string containing text, image, and video references.
        Returns:
            list: A list of dictionaries representing the parsed content.
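        Example (illustrative URL; a reference sits between a pair of <image> markers):
            "Compare <image>https://example.com/a.jpg<image> with the caption"
            -> [{"type": "text", ...}, {"type": "image", ...}, {"type": "text", ...}]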
        """
        content = []
        # References are expected between pairs of <image> markers, so after
        # splitting, even-indexed parts are plain text and odd-indexed parts
        # are image/video references.
        parts = input_string.split("<image>")
        for i, part in enumerate(parts):
            if i % 2 == 0:  # Text part
                if part.strip():  # Skip empties left by adjacent markers
                    content.append({"type": "text", "text": part.strip()})
            else:  # Image/video reference
                if part.lower().startswith("video:"):
                    # Slice off the prefix so "video:" and "Video:" are both handled
                    video_path = part[len("video:"):].strip()
                    logging.debug(f"Video path: {video_path}")
                    video_frames = self._extract_video_frames(video_path)
                    logging.debug(
                        f"Frames extracted: {len(video_frames) if video_frames else 0}"
                    )
                    if video_frames:
                        content.append({"type": "video", "video": video_frames, "fps": 1})
                else:
                    image = self._load_image(part.strip())
                    if image:
                        content.append({"type": "image", "image": image})
        return content

    def _load_image(self, image_data):
        """
        Loads an image from a URL or base64 encoded string.
        Args:
            image_data (str): The image data, either a URL or a base64 encoded string.
        Returns:
            PIL.Image.Image or None: The loaded image, or None if loading fails.
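        Example inputs (both illustrative):
            "https://example.com/cat.jpg"
            "data:image/png;base64,iVBORw0KGgo..."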
        """
        if image_data.startswith("http"):
            try:
                image = Image.open(requests.get(image_data, stream=True).raw)
            except Exception as e:
                logging.error(f"Error loading image from URL: {e}")
                return None
        elif image_data.startswith("data:image"):
            try:
                # Drop the "data:image/...;base64," header, keep the payload
                image_data = image_data.split(",")[1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                logging.error(f"Error loading image from base64: {e}")
                return None
        else:
            logging.error("Invalid image data format. Must be URL or base64 encoded.")
            return None
        return image

    def _extract_video_frames(self, video_path, fps=1):
        """
        Extracts frames from a video at the specified FPS using MoviePy.
        Args:
            video_path (str): The path or URL of the video file.
            fps (int, optional): The desired frames per second. Defaults to 1.
        Returns:
            list or None: A list of PIL Images representing the extracted frames, 
                          or None if extraction fails.
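        Example (illustrative URL):
            frames = self._extract_video_frames("https://example.com/clip.mp4", fps=1)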
        """
        try:
            print(f"Attempting to load video from: {video_path}")
            video = VideoFileClip(video_path)
            print(f"Video loaded: {video}")
            frames = [
                Image.fromarray(frame.astype('uint8'), 'RGB')
                for frame in video.iter_frames(fps=fps)
            ]
            print(f"Number of frames: {len(frames)}")
            print(f"Frame type: {type(frames[0]) if frames else None}")
            print(f"Frame size: {frames[0].size if frames else None}")
            video.close()
            return frames
        except Exception as e:
            error_message = f"Error extracting video frames: {e}\n{traceback.format_exc()}"
            logging.error(error_message)  # Log the formatted error message
            return None
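

if __name__ == "__main__":
    # Minimal local smoke test: a sketch that assumes model weights in
    # "./model", a CUDA GPU or ample CPU RAM, and network access. The image
    # URL is illustrative, not a real asset.
    logging.basicConfig(level=logging.INFO)
    handler = EndpointHandler(path="./model")
    result = handler({
        "inputs": "What is in this picture? <image>https://example.com/sample.jpg<image>",
        "max_new_tokens": 64,
    })
    print(result["generated_text"])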
