import ast
import copy
import json
import math
import os
import random
import re
from typing import Dict

import torch
import transformers
import yaml
from qwen_vl_utils import process_vision_info, smart_resize
from torch.utils.data import Dataset

from gui_actor.constants import (
    IGNORE_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_POINTER_START_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
    DEFAULT_POINTER_END_TOKEN,
    ACTION_PATTENS_XY,
    ADDITIONAL_SPECIAL_TOKENS,
    assistant_template,
    chat_template,
    grounding_system_message,
)
from gui_actor.trainer import rank0_print

def reformat_coordinates(text):
    """
    (1) Find all the coordinates in the text.
    (2) Replace the coordinates with the special tokens.
    (3) Return the new text and the coordinates as a list of (x, y),
        where x and y are both in [0, 1].
    """
    epsilon = 0.001

    def adjust_coord(c):
        """Nudge a coordinate away from the boundaries 0 and 1."""
        if abs(c) < epsilon:
            return epsilon
        elif abs(c - 1) < epsilon:
            return 1 - epsilon
        return c
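    # e.g. adjust_coord(0.0005) -> 0.001 and adjust_coord(0.9999) -> 0.999,
    # keeping every target strictly inside the image.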
    all_matches = []
    for pattern in ACTION_PATTENS_XY:
        matches = list(re.finditer(pattern, text))
        for match in matches:
            all_matches.append((match.start(), match.groups()))
        if pattern == ACTION_PATTENS_XY[0]:
            target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
        else:
            target_text = (
                f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}, "
                f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
            )
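        # Assumption: the first pattern in ACTION_PATTENS_XY captures a single
        # (x, y) point and is replaced by one pointer triple; the remaining
        # patterns capture two points (e.g. a drag or a box) and are replaced
        # by two comma-separated triples.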
        text = re.sub(pattern, target_text, text)

    coordinates = []
    all_matches.sort(key=lambda x: x[0])
    # Extract coordinates in order of appearance.
    for _, groups in all_matches:
        # Two captured values parse as one (x, y) pair.
        if len(groups) == 2:
            x_str, y_str = groups
            x = adjust_coord(ast.literal_eval(x_str))
            y = adjust_coord(ast.literal_eval(y_str))
            coordinates.append((x, y))
        # Four captured values parse as two pairs.
        elif len(groups) == 4:
            x1_str, y1_str, x2_str, y2_str = groups
            x1 = adjust_coord(ast.literal_eval(x1_str))
            y1 = adjust_coord(ast.literal_eval(y1_str))
            x2 = adjust_coord(ast.literal_eval(x2_str))
            y2 = adjust_coord(ast.literal_eval(y2_str))
            coordinates.append((x1, y1))
            coordinates.append((x2, y2))
    return text, coordinates
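
# Hypothetical example, assuming ACTION_PATTENS_XY matches decimal coordinate
# literals such as "(0.32, 0.58)":
#   text, coords = reformat_coordinates("click(0.32, 0.58)")
#   # text   -> "click(<start><pad><end>)" built from the pointer special tokens
#   # coords -> [(0.32, 0.58)]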

def get_token_index(image_processor, image, point_x, point_y):
    """
    Get the index of the visual token that contains the point (point_x, point_y).

    Args:
        image_processor: the image processor
        image: a list holding a single image in PIL format
        point_x: the x coordinate of the point, in [0, 1]
        point_y: the y coordinate of the point, in [0, 1]
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")
    # Get the (already resized) image size and the pixel position of the point.
    image = image[0]
    w, h = image.size
    px, py = w * point_x, h * point_y
    # Map the pixel position to the merged patch that contains it.
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = math.floor(px / merge_patch_size)
    y_index = math.floor(py / merge_patch_size)
    visual_token_index = y_index * (w // merge_patch_size) + x_index
    return visual_token_index
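# Worked example, assuming Qwen2-VL's default patch_size=14 and merge_size=2
# (merge_patch_size=28): on a 1008x728 image, the point (0.5, 0.5) gives
# px=504, py=364, x_index=18, y_index=13; the grid is 1008 // 28 = 36 tokens
# wide, so visual_token_index = 13 * 36 + 18 = 486.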

def get_multi_patch_labels(image_processor, image, bbox_gt):
    """
    Get the multi-patch labels for the bounding box: a binary mask over the
    visual-token grid marking every patch that overlaps the box.

    Args:
        image_processor: the image processor
        image: a list holding a single image in PIL format
        bbox_gt: the bounding box (x_min, y_min, x_max, y_max), normalized to [0, 1]
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")
    # Get the (already resized) image size and scale the box to pixels.
    image = image[0]
    w, h = image.size
    bbox_gt = [bbox_gt[0] * w, bbox_gt[1] * h, bbox_gt[2] * w, bbox_gt[3] * h]
    # Clamp the box to the image bounds.
    x_min, y_min, x_max, y_max = bbox_gt
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(w, x_max)
    y_max = min(h, y_max)
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    assert w % merge_patch_size == 0 and h % merge_patch_size == 0, \
        f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
    grid_h, grid_w = h // merge_patch_size, w // merge_patch_size
    binary_mask = torch.zeros(grid_h * grid_w)
    # Iterate through all patches and mark those overlapping the bounding box.
    for y_idx in range(grid_h):
        for x_idx in range(grid_w):
            # Patch boundaries in pixels.
            patch_x_min = x_idx * merge_patch_size
            patch_y_min = y_idx * merge_patch_size
            patch_x_max = patch_x_min + merge_patch_size
            patch_y_max = patch_y_min + merge_patch_size
            # Overlap test, then index into the flattened grid.
            if not (patch_x_max <= x_min or patch_x_min >= x_max or
                    patch_y_max <= y_min or patch_y_min >= y_max):
                patch_idx = y_idx * grid_w + x_idx
                binary_mask[patch_idx] = 1
    return binary_mask
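# Worked example, again with merge_patch_size=28 on a 1008x728 image (a 36x26
# grid): bbox_gt=(0.25, 0.25, 0.5, 0.5) maps to the pixel box (252, 182, 504, 364)
# and marks patches with x_idx 9..17 and y_idx 6..12, i.e. 9 * 7 = 63 ones.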

def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
    """Inverse of get_token_index: return the pixel center of a visual token."""
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = visual_token_index % (image_width // merge_patch_size)
    y_index = visual_token_index // (image_width // merge_patch_size)
    px = x_index * merge_patch_size + merge_patch_size / 2
    py = y_index * merge_patch_size + merge_patch_size / 2
    return px, py
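# Round trip of the example above: index 486 on a 1008-pixel-wide image gives
# x_index = 486 % 36 = 18 and y_index = 486 // 36 = 13, so (px, py) = (518.0, 378.0),
# the center of the 28x28 patch containing the original point (504, 364).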

class LazySupervisedDataset(Dataset):
    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        data_path: str,
        data_args,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.list_data_dict = []
        self.list_image_path = []
        self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
        self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
        self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
        # Handle multiple JSON files specified in the data_path.
        if "{" in data_path and "}" in data_path:
            base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
            file_names = file_pattern.split(",")
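            # e.g. a (hypothetical) data_path "data/{train_a,train_b}.json"
            # expands to ["data/train_a.json", "data/train_b.json"].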
| rank0_print(f"Loading {file_names} from {base_path}") | |
| data_args.dataset_paths = [] | |
| for file_name in file_names: | |
| data_args.dataset_paths.append(f"{base_path}{file_name}.json") | |
| full_path = f"{base_path}{file_name}.json" | |
| rank0_print(f"Loading {full_path}") | |
| with open(full_path) as file: | |
| cur_data_dict = json.load(file) | |
| rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}") | |
| self.list_data_dict.extend(cur_data_dict) | |
        elif data_path.endswith(".yaml"):
            with open(data_path) as file:
                yaml_data = yaml.safe_load(file)
            datasets = yaml_data.get("datasets")
            # The file should be in the format of:
            # datasets:
            #   - json_path: xxxx1.json
            #     sampling_strategy: first:1000
            #   - json_path: xxxx2.json
            #     sampling_strategy: end:3000
            #   - json_path: xxxx3.json
            #     sampling_strategy: random:999
            data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
            for dataset in datasets:
                json_path = dataset.get("json_path")
                sampling_strategy = dataset.get("sampling_strategy", "all")
                # images_folder comes from the yaml; default to "" so path joins stay valid.
                images_folder = dataset.get("images_folder", "")
                sampling_number = None
                rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
                if json_path.endswith(".jsonl"):
                    cur_data_dict = []
                    with open(json_path) as json_file:
                        for line in json_file:
                            cur_data_dict.append(json.loads(line.strip()))
                elif json_path.endswith(".json"):
                    # NOTE: we only use json_path with .json now.
                    with open(json_path) as json_file:
                        cur_data_dict = json.load(json_file)
                else:
                    raise ValueError(f"Unsupported file type: {json_path}")
                if ":" in sampling_strategy:
                    sampling_strategy, sampling_number = sampling_strategy.split(":")
                    if "%" in sampling_number:
                        sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                    else:
                        sampling_number = int(sampling_number)
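                # e.g. "first:30%" on a 1,000-sample file keeps ceil(30 * 1000 / 100) = 300 samples.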
                # Apply the sampling strategy.
                if sampling_strategy == "first" and sampling_number is not None:
                    cur_data_dict = cur_data_dict[:sampling_number]
                elif sampling_strategy == "end" and sampling_number is not None:
                    cur_data_dict = cur_data_dict[-sampling_number:]
                elif sampling_strategy == "random" and sampling_number is not None:
                    random.shuffle(cur_data_dict)
                    cur_data_dict = cur_data_dict[:sampling_number]
                rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                self.list_data_dict.extend(cur_data_dict)
                self.list_image_path.extend([images_folder] * len(cur_data_dict))
        else:
            data_args.dataset_paths = [data_path]
            rank0_print(f"Loading {data_path}")
            with open(data_path) as file:
                cur_data_dict = json.load(file)
            rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
            self.list_data_dict.extend(cur_data_dict)
            self.list_image_path.extend([""] * len(cur_data_dict))  # NOTE: the image subfolder is empty...
        rank0_print(f"Loaded {len(self.list_data_dict)} samples in total")
        rank0_print("Formatting inputs... Skipped in lazy mode.")
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            # Approximate each image as 1200 tokens; use .get so text-only samples don't raise a KeyError.
            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample.get("image"), list) else 1200 if "image" in sample else 0
            )
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            assert cur_len > 0, f"Conversation length is 0 for {sample}"
            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample.get("image"), list) else 1200 if "image" in sample else 0
            )
            if "image" in sample or "video" in sample or self.data_args.early_mix_text:
                length_list.append(cur_len + img_tokens)
            else:
                # Negative lengths mark text-only samples for modality-grouped batching.
                length_list.append(-cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        try:
            sample = self._get_item(i)
        except Exception as e:
            print(f"Failed to fetch sample {i}. Exception:", e)
            sample = None
        if sample is None:
            # Fall back to a random sample so one bad record does not abort training.
            new_index = random.randint(0, len(self.list_data_dict) - 1)
            return self.__getitem__(new_index)
        return sample

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i])
        image_list = []
        if "image" in sources:
            image_file = self.list_data_dict[i]["image"]
            if isinstance(image_file, list):
                image_list = [os.path.join(image_path, f) for f in image_file]
            else:
                image_list = [os.path.join(image_path, image_file)]
            sources = copy.deepcopy(sources["conversations"])
        elif "video" in sources:
            raise NotImplementedError("Video is not supported for Qwen2VL")
        else:
            sources = copy.deepcopy(sources["conversations"])
        item_id = self.list_data_dict[i].get("id", i)
        data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id)
        if isinstance(i, int):
            data_dict = {
                "input_ids": data_dict["input_ids"][0],
                "labels": data_dict["labels"][0],
                "coordinates": data_dict["coordinates"][0],
                "visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0],
                "pixel_values": data_dict["pixel_values"],
                "image_grid_thw": data_dict["image_grid_thw"],
                "multi_patch_labels": data_dict["multi_patch_labels"][0],
            }
        data_dict["id"] = item_id
        # Return None if the sequence (text plus image tokens) exceeds model_max_length.
        n_image_tokens = (
            data_dict["image_grid_thw"][0][0]
            * data_dict["image_grid_thw"][0][1]
            * data_dict["image_grid_thw"][0][2]
            / self.processor.image_processor.merge_size
            / self.processor.image_processor.merge_size
        )
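        # e.g. image_grid_thw[0] = [1, 52, 72] with merge_size=2 gives
        # 1 * 52 * 72 / 2 / 2 = 936 visual tokens.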
        if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length:
            rank0_print(
                f"=== Removed data_dict {i} because it is longer than the model_max_length: "
                f"{len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}"
            )
            return None
        return data_dict

    def preprocess_qwen2vl(
        self,
        source,  # conversations
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        image: list,
        system_message: str = grounding_system_message,
        agent_mode: bool = True,
        chat_template: str = chat_template,
        assistant_template: str = assistant_template,
        id: int = None,
    ) -> Dict:
        roles = {"human": "user", "gpt": "assistant", "system": "system"}
        assistant_template = assistant_template if agent_mode else chat_template
        processor.tokenizer = tokenizer
        assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS

        # Apply prompt templates.
        pixel_values, image_grid_thw = None, None
        input_id, target = [], []
        coordinates = []
        visual_token_indices_of_coordinates = []
        multi_patch_labels = []
        image_list = []
        image_index = 0

        # Prepare the system message.
        if roles[source[0]["from"]] == "system":
            system_message = source[0]["value"]
            source = source[1:self.data_args.max_conv_turns]
        # else: use the constant system message.
        system_input_id = tokenizer.apply_chat_template(
            conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
            chat_template=chat_template,
        )
        input_id += system_input_id
        target += [IGNORE_INDEX] * len(system_input_id)
        # Prepare the user-assistant conversation turns.
        for conv in source:
            # Regularize the conversation format.
            try:
                role = conv["role"]
                content = conv["content"]
            except Exception:
                role = conv["from"]
                content = conv["value"]
            role = roles.get(role, role)
            # Count the number of <image> tokens in the content.
            image_count = content.count(DEFAULT_IMAGE_TOKEN)
            if image_count > 0:
                assert role == "user", "Images are only supported for user messages"
                # Attach the image information belonging to the current turn.
                image_placeholders = []
                for _ in range(image_count):
                    image_placeholders.append({
                        "type": "image",
                        "image": image[image_index],
                        "min_pixels": self.processor.image_processor.min_pixels,
                        "max_pixels": self.processor.image_processor.max_pixels,
                    })
                    image_index += 1
                content = content.replace(DEFAULT_IMAGE_TOKEN, "")
                conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}
                image_inputs, _ = process_vision_info([conv])  # list of PIL.Image.Image
                image_list.extend(image_inputs)
                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv], chat_template=chat_template, tokenize=False
                )
                inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")
                if pixel_values is None and image_grid_thw is None:
                    pixel_values = inputs["pixel_values"]
                    image_grid_thw = inputs["image_grid_thw"]
                else:
                    pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
                    image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
            else:
                if role in ["user", "system"]:
                    conv = {"role": role, "content": [{"type": "text", "text": content}]}
                else:  # assistant
                    conv = {
                        "role": role,
                        "content": [{"type": "text", "text": content}],
                        "recipient": conv.get("recipient", "os"),
                        "end_turn": conv.get("end_turn", True),
                        "bbox_gt": conv.get("bbox_gt", None),
                    }
                    if conv["recipient"] == "os":
                        if len(image_list) == 0:
                            raise ValueError("No image found for visual grounding")
                        # Replace the coordinates with the special tokens.
                        text, coord = reformat_coordinates(conv["content"][0]["text"])
                        conv["content"][0]["text"] = text
                        # Map each coordinate to the visual token that contains it.
                        coordinates.extend(coord)
                        for (point_x, point_y) in coord:
                            visual_token_index = get_token_index(
                                processor.image_processor,
                                image_list,
                                point_x,
                                point_y,
                            )
                            visual_token_indices_of_coordinates.append(visual_token_index)
                        if conv["bbox_gt"] is not None:
                            patch_mask = get_multi_patch_labels(
                                processor.image_processor,
                                image_list,
                                conv["bbox_gt"],
                            )
                            multi_patch_labels.append(patch_mask)
                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv],
                    chat_template=assistant_template,
                    tokenize=False,
                )
                inputs = processor(text=[templated_conv], return_tensors="pt")
            encode_id = inputs.input_ids[0].tolist()
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id
        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        # Mask the labels of all pointer_end tokens with IGNORE_INDEX.
        target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]
        input_ids = torch.tensor([input_id], dtype=torch.long)
        targets = torch.tensor([target], dtype=torch.long)
        visual_token_indices_of_coordinates = (
            torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long)
            if len(visual_token_indices_of_coordinates) > 0
            else [None]
        )
        coordinates = [coordinates] if len(coordinates) > 0 else [None]
        # Process multi_patch_labels.
        if len(multi_patch_labels) > 0:
            multi_patch_labels = [torch.stack(multi_patch_labels)]
        else:
            multi_patch_labels = [None]
        data_dict = {
            "input_ids": input_ids,  # tensor(bs x seq_len)
            "labels": targets,  # tensor(bs x seq_len)
        }
        if pixel_values is not None:
            data_dict["pixel_values"] = pixel_values
            data_dict["image_grid_thw"] = image_grid_thw
        data_dict["coordinates"] = coordinates
        data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates
        data_dict["multi_patch_labels"] = multi_patch_labels
        return data_dict
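

# --- Illustrative usage sketch (commented out; not part of the module) ---
# A minimal sketch of how this dataset might be constructed, assuming a
# Qwen2-VL checkpoint and a data_args namespace exposing the fields used above
# (image_folder, max_conv_turns, early_mix_text, dataset_paths). The model
# name and paths are placeholders, not the project's actual configuration.
#
# from types import SimpleNamespace
# from transformers import AutoProcessor, AutoTokenizer
#
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# tokenizer.add_special_tokens({"additional_special_tokens": ADDITIONAL_SPECIAL_TOKENS})
# data_args = SimpleNamespace(image_folder="/data/images", max_conv_turns=None,
#                             early_mix_text=False)
# dataset = LazySupervisedDataset(tokenizer, processor, "/data/train.json", data_args)
# sample = dataset[0]  # dict with input_ids, labels, pixel_values, coordinates, ...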