# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
import os
from threading import Thread

import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
from qwen_vl_utils import process_vision_info

class MiMoVLInfer:
    def __init__(self, checkpoint_path, **kwargs):
        # Load on CPU first; the model is moved to the GPU on demand via to_device().
        dtype = torch.float16
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            device_map={"": "cpu"},
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
        self._on_cuda = False
    def to_device(self, device: str):
        # Move the model between CPU and GPU, tracking its current placement
        # so that repeated calls are no-ops.
        if device == "cuda" and not self._on_cuda:
            self.model.to("cuda")
            self._on_cuda = True
        elif device == "cpu" and self._on_cuda:
            self.model.to("cpu")
            self._on_cuda = False
    def __call__(self, inputs: dict, history: list | None = None, temperature: float = 1.0):
        # Avoid the mutable-default-argument pitfall: use a fresh list per call.
        history = history or []
        messages = self.construct_messages(inputs)
        updated_history = history + messages
        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(updated_history)
        model_inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
        ).to(self.model.device)
        tokenizer = self.processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
        temp = float(temperature or 0.0)
        # Treat near-zero temperatures as greedy decoding.
        do_sample = temp > 1e-3
        if do_sample:
            samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
        else:
            samp_args = {"do_sample": False}
        gen_kwargs = {
            "max_new_tokens": max_new,  # was hardcoded to 1024, ignoring MAX_NEW_TOKENS
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
            "pad_token_id": self.model.config.eos_token_id,
            **model_inputs,
            **samp_args,
        }
        # Run generation in a background thread so tokens can be streamed as they arrive.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()
        partial_response = ""
        for new_text in streamer:
            partial_response += new_text
            # Each yield carries the full response so far plus the updated chat history.
            yield partial_response, updated_history + [{
                'role': 'assistant',
                'content': [{'type': 'text', 'text': partial_response}]
            }]
    def _is_video_file(self, filename):
        return any(filename.lower().endswith(ext) for ext in
                   ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg'])

    def construct_messages(self, inputs: dict) -> list:
        # Build a single user turn: attached files become image/video entries,
        # followed by the optional text query.
        content = []
        for path in inputs.get('files', []):
            if self._is_video_file(path):
                content.append({"type": "video", "video": f'file://{path}'})
            else:
                content.append({"type": "image", "image": f'file://{path}'})
        query = inputs.get('text', '')
        if query:
            content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]