from __future__ import annotations

import json
import pathlib
import re
from typing import List, Tuple, Union

import numpy as np
import torch
from dateutil import parser as dateparser
from PIL import Image
from torchvision import transforms
from torchvision.ops import box_iou
from word2number import w2n

from vision_processes import forward

def load_json(path: str):
    if isinstance(path, str):
        path = pathlib.Path(path)
    if path.suffix != '.json':
        path = path.with_suffix('.json')
    with open(path, 'r') as f:
        data = json.load(f)
    return data

class ImagePatch:
    """A Python class containing a crop of an image centered around a particular object, as well as relevant
    information.

    Attributes
    ----------
    cropped_image : array_like
        An array-like of the cropped image taken from the original image.
    left : int
        An int describing the position of the left border of the crop's bounding box in the original image.
    lower : int
        An int describing the position of the bottom border of the crop's bounding box in the original image.
    right : int
        An int describing the position of the right border of the crop's bounding box in the original image.
    upper : int
        An int describing the position of the top border of the crop's bounding box in the original image.

    Methods
    -------
    find(object_name: str) -> List[ImagePatch]
        Returns a list of new ImagePatch objects containing crops of the image centered around any objects found
        in the image matching the object_name.
    exists(object_name: str) -> bool
        Returns True if the object specified by object_name is found in the image, and False otherwise.
    verify_property(object_name: str, attribute: str) -> bool
        Returns True if the object possesses the attribute, and False otherwise.
    best_text_match(option_list: List[str], prefix: str) -> str
        Returns the string that best matches the image.
    simple_query(question: str = None) -> str
        Returns the answer to a basic question asked about the image. If no question is provided, returns the
        answer to "What is this?".
    compute_depth() -> float
        Returns the median depth of the image crop.
    crop(left: int, lower: int, right: int, upper: int) -> ImagePatch
        Returns a new ImagePatch object containing a crop of the image at the given coordinates.
    """
    def __init__(self, image: Union[Image.Image, torch.Tensor, np.ndarray], left: int = None, lower: int = None,
                 right: int = None, upper: int = None, parent_left=0, parent_lower=0, queues=None,
                 parent_img_patch=None):
        """Initializes an ImagePatch object by cropping the image at the given coordinates and stores the
        coordinates as attributes. If no coordinates are provided, the image is left unmodified, and the
        coordinates are set to the dimensions of the image.

        Parameters
        ----------
        image : array_like
            An array-like of the original image.
        left : int
            An int describing the position of the left border of the crop's bounding box in the original image.
        lower : int
            An int describing the position of the bottom border of the crop's bounding box in the original image.
        right : int
            An int describing the position of the right border of the crop's bounding box in the original image.
        upper : int
            An int describing the position of the top border of the crop's bounding box in the original image.
        """
        if isinstance(image, Image.Image):
            image = transforms.ToTensor()(image)
        elif isinstance(image, np.ndarray):
            # A numpy image is assumed to be HWC; move channels first to match the CHW layout used below.
            image = torch.tensor(image).permute(2, 0, 1)
        elif isinstance(image, torch.Tensor) and image.dtype == torch.uint8:
            image = image / 255

        if left is None and right is None and upper is None and lower is None:
            self.cropped_image = image
            self.left = 0
            self.lower = 0
            self.right = image.shape[2]  # width
            self.upper = image.shape[1]  # height
        else:
            # Coordinates use a lower-left origin, so the vertical slice is flipped relative to tensor indexing.
            self.cropped_image = image[:, image.shape[1] - upper:image.shape[1] - lower, left:right]
            self.left = left + parent_left
            self.upper = upper + parent_lower
            self.right = right + parent_left
            self.lower = lower + parent_lower

        self.height = self.cropped_image.shape[1]
        self.width = self.cropped_image.shape[2]

        self.cache = {}
        self.queues = (None, None) if queues is None else queues
        self.parent_img_patch = parent_img_patch

        self.horizontal_center = (self.left + self.right) / 2
        self.vertical_center = (self.lower + self.upper) / 2

        if self.cropped_image.shape[1] == 0 or self.cropped_image.shape[2] == 0:
            raise ValueError("ImagePatch has no area")

        self.possible_options = load_json('./useful_lists/possible_options.json')
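
    # A minimal usage sketch (hypothetical; assumes an image file "sample.jpg" exists and that the
    # vision_processes backends are running). Note the lower-left-origin coordinate system:
    # `lower` and `upper` are measured from the bottom of the image, not the top.
    #
    #     img = Image.open("sample.jpg")
    #     whole = ImagePatch(img)                    # full image; left == lower == 0
    #     corner = ImagePatch(img, 0, 0, 100, 100)   # 100x100 crop at the bottom-left corner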
    def forward(self, model_name, *args, **kwargs):
        return forward(model_name, *args, **kwargs)
        # return forward(model_name, *args, queues=self.queues, **kwargs)

    @property
    def original_image(self):
        # `original_image` is accessed as an attribute elsewhere (e.g. in compute_depth),
        # so it must be a property rather than a plain method.
        if self.parent_img_patch is None:
            return self.cropped_image
        else:
            return self.parent_img_patch.original_image
    def find(self, object_name: str, confidence_threshold: float = None, return_confidence: bool = False) -> List:
        """Returns a list of ImagePatch objects matching object_name contained in the crop, if any are found.
        Otherwise, returns an empty list.

        Parameters
        ----------
        object_name : str
            the name of the object to be found
        confidence_threshold : float, optional
            minimum detector confidence for a detection to be returned
        return_confidence : bool
            if True, return (patch, score) tuples instead of bare patches

        Returns
        -------
        List[ImagePatch]
            a list of ImagePatch objects matching object_name contained in the crop
        """
        if confidence_threshold is not None:
            confidence_threshold = float(confidence_threshold)
        if object_name in ["object", "objects"]:
            all_object_coordinates, all_object_scores = self.forward('maskrcnn', self.cropped_image,
                                                                     confidence_threshold=confidence_threshold)
            all_object_coordinates = all_object_coordinates[0]
            all_object_scores = all_object_scores[0]
        else:
            if object_name == 'person':
                object_name = 'people'  # GLIP does better at "people" than "person"
            all_object_coordinates, all_object_scores = self.forward('glip', self.cropped_image, object_name,
                                                                     confidence_threshold=confidence_threshold)
        if len(all_object_coordinates) == 0:
            return []

        # Optional relative-area filter; disabled by default (threshold = 0 keeps every detection).
        threshold = 0.0
        if threshold > 0:
            area_im = self.width * self.height
            all_areas = torch.tensor([(coord[2] - coord[0]) * (coord[3] - coord[1]) / area_im
                                      for coord in all_object_coordinates])
            mask = all_areas > threshold
            # if not mask.any():
            #     mask = all_areas == all_areas.max()  # At least return one element
            all_object_coordinates = all_object_coordinates[mask]
            all_object_scores = all_object_scores[mask]

        boxes = [self.crop(*coordinates) for coordinates in all_object_coordinates]
        if return_confidence:
            return [(box, float(score)) for box, score in zip(boxes, all_object_scores.reshape(-1))]
        else:
            return boxes
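
    # Example (hypothetical, continuing the sketch above; requires the GLIP / Mask R-CNN backends
    # from vision_processes):
    #
    #     cars = whole.find("car")                                   # List[ImagePatch]
    #     scored = whole.find("car", confidence_threshold=0.5,
    #                         return_confidence=True)                # List[(ImagePatch, float)]
    #     leftmost = min(cars, key=lambda p: p.horizontal_center) if cars else None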
    def exists(self, object_name) -> bool:
        """Returns True if the object specified by object_name is found in the image, and False otherwise.

        Parameters
        ----------
        object_name : str
            A string describing the name of the object to be found in the image.
        """
        if object_name.isdigit() or object_name.lower().startswith("number"):
            object_name = object_name.lower().replace("number", "").strip()
            object_name = w2n.word_to_num(object_name)
            answer = self.simple_query("What number is written in the image (in digits)?")
            return w2n.word_to_num(answer) == object_name

        patches = self.find(object_name)

        filtered_patches = []
        for patch in patches:
            if "yes" in patch.simple_query(f"Is this a {object_name}?"):
                filtered_patches.append(patch)
        return len(filtered_patches) > 0
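
    # Example (hypothetical): `exists` double-checks every detection with a yes/no query,
    # so it is stricter than `len(self.find(...)) > 0`:
    #
    #     if whole.exists("umbrella"):
    #         ...
    #     whole.exists("2")  # numeric names are routed through the digit-reading branch above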
    def _score(self, category: str, negative_categories=None, model='clip') -> float:
        """
        Returns a binary score for the similarity between the image and the category.
        The negative categories are used to compare to (the score is relative to the scores of the
        negative categories).
        """
        if model == 'clip':
            res = self.forward('clip', self.cropped_image, category, task='score',
                               negative_categories=negative_categories)
        elif model == 'tcl':
            res = self.forward('tcl', self.cropped_image, category, task='score')
        else:  # xvlm
            task = 'binary_score' if negative_categories is not None else 'score'
            res = self.forward('xvlm', self.cropped_image, category, task=task,
                               negative_categories=negative_categories)
        res = res.item()
        return res

    def _detect(self, category: str, thresh, negative_categories=None, model='clip') -> Tuple[bool, float]:
        score = self._score(category, negative_categories, model)
        return score > thresh, float(score)
    def verify_property(self, object_name: str, attribute: str, return_confidence: bool = False):
        """Returns True if the object possesses the property, and False otherwise.
        Differs from 'exists' in that it presupposes the existence of the object specified by object_name,
        instead checking whether the object possesses the property.

        Parameters
        ----------
        object_name : str
            A string describing the name of the object to be found in the image.
        attribute : str
            A string describing the property to be checked.
        """
        name = f"{attribute} {object_name}"
        model = "xvlm"
        negative_categories = [f"{att} {object_name}" for att in self.possible_options['attributes']]
        # if model == 'clip':
        #     ret, score = self._detect(name, negative_categories=negative_categories,
        #                               thresh=config.verify_property.thresh_clip, model='clip')
        # elif model == 'tcl':
        #     ret, score = self._detect(name, thresh=config.verify_property.thresh_tcl, model='tcl')
        # else:  # 'xvlm'
        ret, score = self._detect(name, negative_categories=negative_categories, thresh=0.6, model='xvlm')
        if return_confidence:
            return ret, score
        else:
            return ret
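
    # Example (hypothetical; the 0.6 threshold above is what turns the relative XVLM score into a bool):
    #
    #     shirt = whole.find("shirt")[0]
    #     shirt.verify_property("shirt", "red")                          # -> bool
    #     shirt.verify_property("shirt", "red", return_confidence=True)  # -> (bool, float)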
    def best_text_match(self, option_list: List[str] = None, prefix: str = None) -> str:
        """Returns the string from option_list that best matches the image.

        Parameters
        ----------
        option_list : List[str]
            A list with the names of the different options
        prefix : str
            A string with the prefix to prepend to each option
        """
        option_list_to_use = option_list
        if prefix is not None:
            option_list_to_use = [prefix + " " + option for option in option_list]

        model_name = "xvlm"
        image = self.cropped_image
        text = option_list_to_use
        if model_name in ('clip', 'tcl'):
            selected = self.forward(model_name, image, text, task='classify')
        elif model_name == 'xvlm':
            res = self.forward(model_name, image, text, task='score')
            selected = res.argmax().item()
        else:
            raise NotImplementedError

        return option_list[selected]
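
    # Example (hypothetical): scores every option against the crop and returns the best one.
    # The prefix is only used for scoring; the returned string is the bare option.
    #
    #     food = whole.find("food")[0]
    #     food.best_text_match(["pizza", "salad", "soup"], prefix="a photo of")  # -> e.g. "pizza"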
    def simple_query(self, question: str = None, return_confidence: bool = False):
        """Returns the answer to a basic question asked about the image. If no question is provided, returns
        the answer to "What is this?". The questions are about basic perception, and are not meant to be used
        for complex reasoning or external knowledge.

        Parameters
        ----------
        question : str
            A string describing the question to be asked.
        """
        if question is None:
            question = "What is this?"  # default promised by the docstring
        text, score = self.forward('blip', self.cropped_image, question, task='qa')
        if return_confidence:
            return text, score
        else:
            return text
    def compute_depth(self):
        """Returns the median depth of the image crop.

        Returns
        -------
        float
            the median depth of the image crop
        """
        original_image = self.original_image
        depth_map = self.forward('depth', original_image)
        # Index with the same lower-left-origin convention used for cropping.
        depth_map = depth_map[original_image.shape[1] - self.upper:original_image.shape[1] - self.lower,
                              self.left:self.right]
        return depth_map.median()  # Ideally some kind of mode, but median is good enough for now
    def crop(self, left: int, lower: int, right: int, upper: int) -> ImagePatch:
        """Returns a new ImagePatch containing a crop of the original image at the given coordinates.

        Parameters
        ----------
        left : int
            the position of the left border of the crop's bounding box in the original image
        lower : int
            the position of the bottom border of the crop's bounding box in the original image
        right : int
            the position of the right border of the crop's bounding box in the original image
        upper : int
            the position of the top border of the crop's bounding box in the original image

        Returns
        -------
        ImagePatch
            a new ImagePatch containing a crop of the original image at the given coordinates
        """
        # Make all inputs ints
        left = int(left)
        lower = int(lower)
        right = int(right)
        upper = int(upper)

        # Pad the crop by 10 pixels on each side (clamped to the image bounds) for extra context.
        left = max(0, left - 10)
        lower = max(0, lower - 10)
        right = min(self.width, right + 10)
        upper = min(self.height, upper + 10)

        return ImagePatch(self.cropped_image, left, lower, right, upper, self.left, self.lower,
                          queues=self.queues, parent_img_patch=self)
    def overlaps_with(self, left, lower, right, upper):
        """Returns True if a crop with the given coordinates overlaps with this one, else False.

        Parameters
        ----------
        left : int
            the left border of the crop to be checked
        lower : int
            the lower border of the crop to be checked
        right : int
            the right border of the crop to be checked
        upper : int
            the upper border of the crop to be checked

        Returns
        -------
        bool
            True if a crop with the given coordinates overlaps with this one, else False
        """
        return self.left <= right and self.right >= left and self.lower <= upper and self.upper >= lower
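
    # Worked example: a patch spanning (left=0, lower=0, right=50, upper=50) overlaps a query box
    # (40, 40, 100, 100) because 0 <= 100, 50 >= 40, 0 <= 100 and 50 >= 40, so overlaps_with
    # returns True; moving the query box to (60, 60, 100, 100) makes `self.right >= left` fail.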
    def llm_query(self, question: str, long_answer: bool = True) -> str:
        return llm_query(question, None, long_answer)

    # def print_image(self, size: tuple[int, int] = None):
    #     show_single_image(self.cropped_image, size)

    def __repr__(self):
        return ("ImagePatch(left={}, right={}, upper={}, lower={}, height={}, width={}, "
                "horizontal_center={}, vertical_center={})").format(
            self.left, self.right, self.upper, self.lower, self.height, self.width,
            self.horizontal_center, self.vertical_center)
        # return "ImagePatch({}, {}, {}, {})".format(self.left, self.lower, self.right, self.upper)

def best_image_match(list_patches: List[ImagePatch], content: List[str], return_index: bool = False) \
        -> Union[ImagePatch, int, None]:
    """Returns the patch most likely to contain the content.

    Parameters
    ----------
    list_patches : List[ImagePatch]
    content : List[str]
        the object of interest
    return_index : bool
        if True, returns the index of the patch most likely to contain the object

    Returns
    -------
    Union[ImagePatch, int, None]
        the patch most likely to contain the object, or its index if return_index is True;
        None if list_patches is empty
    """
    if len(list_patches) == 0:
        return None

    model = "xvlm"

    scores = []
    for cont in content:
        if model == 'clip':
            res = list_patches[0].forward(model, [p.cropped_image for p in list_patches], cont, task='compare',
                                          return_scores=True)
        else:
            res = list_patches[0].forward(model, [p.cropped_image for p in list_patches], cont, task='score')
        scores.append(res)
    scores = torch.stack(scores).mean(dim=0)
    index = scores.argmax().item()  # Argmax over all image patches

    if return_index:
        return index
    return list_patches[index]
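
# Example (hypothetical): rank candidate patches by how well they match a description.
#
#     people = whole.find("person")
#     best = best_image_match(people, ["person wearing a hat"])                    # -> ImagePatch
#     idx = best_image_match(people, ["person wearing a hat"], return_index=True)  # -> int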
def distance(patch_a: Union[ImagePatch, float], patch_b: Union[ImagePatch, float]) -> float:
    """
    Returns the distance between the edges of two ImagePatches, or between two floats.
    If the patches overlap, it returns a negative distance corresponding to the negative intersection over union.
    """
    if isinstance(patch_a, ImagePatch) and isinstance(patch_b, ImagePatch):
        a_min = np.array([patch_a.left, patch_a.lower])
        a_max = np.array([patch_a.right, patch_a.upper])
        b_min = np.array([patch_b.left, patch_b.lower])
        b_max = np.array([patch_b.right, patch_b.upper])

        # Per-axis gap between the boxes; zero on any axis where they overlap.
        u = np.maximum(0, a_min - b_max)
        v = np.maximum(0, b_min - a_max)

        dist = np.sqrt((u ** 2).sum() + (v ** 2).sum())

        if dist == 0:
            box_a = torch.tensor([patch_a.left, patch_a.lower, patch_a.right, patch_a.upper])[None]
            box_b = torch.tensor([patch_b.left, patch_b.lower, patch_b.right, patch_b.upper])[None]
            dist = -box_iou(box_a, box_b).item()
    else:
        dist = abs(patch_a - patch_b)

    return dist
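
# Worked example: two 10x10 boxes whose nearest edges are 30 apart horizontally and 40 apart
# vertically give distance sqrt(30**2 + 40**2) = 50.0; identical boxes give -IoU = -1.0, and
# plain floats fall through to abs(a - b):
#
#     distance(3.0, 7.5)  # -> 4.5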
def bool_to_yesno(bool_answer: bool) -> str:
    """Returns a yes/no answer to a question based on the boolean value of bool_answer.

    Parameters
    ----------
    bool_answer : bool
        a boolean value

    Returns
    -------
    str
        a yes/no answer to a question based on the boolean value of bool_answer
    """
    return "yes" if bool_answer else "no"
def llm_query(query, context=None, long_answer=True, queues=None):
    """Answers a text question using GPT-3. The input question is always a formatted string with a variable in it.

    Parameters
    ----------
    query : str
        the text question to ask. Must not contain any reference to 'the image' or 'the photo', etc.
    context : str, optional
        extra context passed along to the short-answer QA model
    long_answer : bool
        if True, use the general prompt model; otherwise use the QA model
    """
    if long_answer:
        return forward(model_name='gpt3_general', prompt=query, queues=queues)
    else:
        return forward(model_name='gpt3_qa', prompt=[query, context], queues=queues)


def process_guesses(prompt, guess1=None, guess2=None, queues=None):
    return forward(model_name='gpt3_guess', prompt=[prompt, guess1, guess2], queues=queues)
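
# Example (hypothetical; requires the GPT-3 backends configured in vision_processes):
#
#     llm_query("What country is the Eiffel Tower in?")                  # long-form answer
#     llm_query("What is the capital of France?", long_answer=False)     # short QA-style answer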
def coerce_to_numeric(string, no_string=False):
    """
    This function takes a string as input and returns a numeric value after removing any non-numeric characters.
    If the input string contains a range (e.g. "10-15"), it returns the first value in the range.
    # TODO: Cases like '25to26' return 2526, which is not correct.
    """
    if any(month in string.lower() for month in ['january', 'february', 'march', 'april', 'may', 'june', 'july',
                                                 'august', 'september', 'october', 'november', 'december']):
        try:
            # dateutil returns a datetime, so take its year directly
            return dateparser.parse(string).year
        except Exception:  # parse error
            pass

    try:
        # If it is a word number (e.g. 'zero')
        numeric = w2n.word_to_num(string)
        return numeric
    except ValueError:
        pass

    # Remove any non-numeric characters except the decimal point and the negative sign
    string_re = re.sub(r"[^0-9.\-]", "", string)

    # Protect a leading minus sign so it is not mistaken for a range separator below
    if string_re.startswith('-'):
        string_re = '&' + string_re[1:]

    # Check if the string includes a range
    if "-" in string_re:
        # Split the string into parts based on the dash character
        parts = string_re.split("-")
        return coerce_to_numeric(parts[0].replace('&', '-'))
    else:
        string_re = string_re.replace('&', '-')

    try:
        # Convert the string to a float or int depending on whether it has a decimal point
        if "." in string_re:
            numeric = float(string_re)
        else:
            numeric = int(string_re)
    except Exception:
        if no_string:
            raise ValueError
        # No numeric values. Return input
        return string

    return numeric
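
# Examples of the coercion behavior (derived from the logic above):
#
#     coerce_to_numeric("about 25 km")  # -> 25
#     coerce_to_numeric("10-15")        # -> 10   (first value of the range)
#     coerce_to_numeric("-3.5")         # -> -3.5
#     coerce_to_numeric("seventeen")    # -> 17   (via word2number)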