from typing import Dict, Any
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import base64
import io
from PIL import Image
import logging
import requests
import subprocess
from moviepy.editor import VideoFileClip
import traceback  # For formatting exception tracebacks

class EndpointHandler():
    """
    Handler class for the Qwen2-VL-7B-Instruct model on Hugging Face Inference Endpoints.

    This handler processes text, image, and video inputs, leveraging the Qwen2-VL model
    for multimodal understanding and generation. It includes a runtime workaround to
    install FFmpeg if it's not available in the environment.
    """

    def __init__(self, path=""):
        """
        Initializes the handler, installs FFmpeg, and loads the Qwen2-VL model.

        Args:
            path (str, optional): The path to the Qwen2-VL model directory. Defaults to "".
        """
        self.model_dir = path

        # Install FFmpeg at runtime (this will run once during container initialization)
        try:
            subprocess.run(["apt-get", "update"], check=True)
            subprocess.run(["apt-get", "install", "-y", "ffmpeg"], check=True)
            logging.info("FFmpeg installed successfully.")
        except subprocess.CalledProcessError as e:
            logging.error(f"Error installing FFmpeg: {e}")

        # Load the Qwen2-VL model
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_dir, torch_dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(self.model_dir)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes the input data and returns the Qwen2-VL model's output.

        Args:
            data (Dict[str, Any]): A dictionary containing the input data.
                - "inputs" (str): The input text, including image/video references.
                - "max_new_tokens" (int, optional): Max tokens to generate (default: 128).

        Returns:
            Dict[str, Any]: A dictionary containing the generated text.
        """
        inputs = data.get("inputs")
        max_new_tokens = data.get("max_new_tokens", 128)

        # Construct the messages list from the input string
        messages = [{"role": "user", "content": self._parse_input(inputs)}]

        # Prepare for inference (using qwen_vl_utils)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        logging.debug(f"Image inputs: {image_inputs}")
        logging.debug(f"Video inputs: {video_inputs}")

        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Inference
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return {"generated_text": output_text}

    def _parse_input(self, input_string):
        """
        Parses the input string to identify image/video references and text.

        Args:
            input_string (str): The input string containing text, image, and video references.

        Returns:
            list: A list of dictionaries representing the parsed content.
        """
        content = []
        parts = input_string.split("<image>")
        for i, part in enumerate(parts):
            if i % 2 == 0:  # Text part (even indices fall outside the <image>...<image> markers)
                if part.strip():
                    content.append({"type": "text", "text": part.strip()})
            else:  # Image/video reference
                ref = part.strip()
                if ref.lower().startswith("video:"):
                    # Strip the "video:" prefix case-insensitively to get the path or URL
                    video_path = ref[len("video:"):].strip()
                    logging.debug(f"Video path: {video_path}")
                    video_frames = self._extract_video_frames(video_path)
                    logging.debug(f"Number of frames extracted: {len(video_frames) if video_frames else 0}")
                    if video_frames:
                        content.append({"type": "video", "video": video_frames, "fps": 1})
                else:
                    image = self._load_image(ref)
                    if image:
                        content.append({"type": "image", "image": image})
        return content

    def _load_image(self, image_data):
        """
        Loads an image from a URL or base64 encoded string.

        Args:
            image_data (str): The image data, either a URL or a base64 encoded string.

        Returns:
            PIL.Image.Image or None: The loaded image, or None if loading fails.
        """
        if image_data.startswith("http"):
            try:
                response = requests.get(image_data, stream=True)
                response.raise_for_status()  # Surface HTTP errors instead of a cryptic PIL failure
                image = Image.open(response.raw)
            except Exception as e:
                logging.error(f"Error loading image from URL: {e}")
                return None
        elif image_data.startswith("data:image"):
            try:
                image_data = image_data.split(",")[1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))
            except Exception as e:
                logging.error(f"Error loading image from base64: {e}")
                return None
        else:
            logging.error("Invalid image data format. Must be URL or base64 encoded.")
            return None
        return image

    def _extract_video_frames(self, video_path, fps=1):
        """
        Extracts frames from a video at the specified FPS using MoviePy.

        Args:
            video_path (str): The path or URL of the video file.
            fps (int, optional): The desired frames per second. Defaults to 1.

        Returns:
            list or None: A list of PIL Images representing the extracted frames, 
                          or None if extraction fails.
        """
        try:
            logging.debug(f"Attempting to load video from: {video_path}")
            video = VideoFileClip(video_path)
            logging.debug(f"Video loaded: {video}")

            # Sample frames at the requested FPS and convert each to a PIL RGB image
            frames = [
                Image.fromarray(frame.astype('uint8'), 'RGB')
                for frame in video.iter_frames(fps=fps)
            ]
            logging.debug(f"Number of frames: {len(frames)}")
            logging.debug(f"Frame type: {type(frames[0]) if frames else None}")
            logging.debug(f"Frame size: {frames[0].size if frames else None}")
            video.close()
            return frames
        except Exception as e:
            error_message = f"Error extracting video frames: {e}\n{traceback.format_exc()}"
            logging.error(error_message)  # Log the formatted error message
            return None
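
# Minimal local smoke test (a sketch, not part of the Inference Endpoints request flow).
# The model path and image URL below are placeholders; point them at a real Qwen2-VL
# checkpoint and a reachable image before running.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    handler = EndpointHandler(path="Qwen/Qwen2-VL-7B-Instruct")
    result = handler({
        "inputs": "What is shown in this picture? <image>https://example.com/sample.jpg<image>",
        "max_new_tokens": 64,
    })
    print(result["generated_text"])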