# NOTE(review): removed stray export artifacts ("Spaces:", "Runtime error") —
# they were not valid Python and appear to be residue from a dashboard/export tool.
| import random | |
| import os | |
| import torch | |
| import pandas as pd | |
| from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection | |
| from PIL import Image | |
# Pretrained vision-language models, loaded once at module import time.
# NOTE(review): `from_pretrained` downloads large checkpoints on first use,
# so importing this module may be slow / require network access — confirm intended.
CLIPmodel_import = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
CLIPprocessor_import = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
DetrFeatureExtractor_import = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
DetrModel_import = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
# Import list of COCO example objects (one label per line, next to this file).
script_path = os.path.dirname(__file__)
# Use a context manager so the file handle is always closed (the original
# left it open), and splitlines() so a trailing newline does not produce an
# empty "" label that random.choice could later pick. Empty lines are dropped.
with open(os.path.join(script_path, "coco-labels-paper.txt"), "r") as _coco_file:
    coco_objects = [line for line in _coco_file.read().splitlines() if line]
| # Example image | |
| #test_image = Image.open('pages/Functions/test_image.png') | |
| #test_image = Image.open('pages/Functions/test_imageIV.png') | |
| ###### Helper functions | |
def Coco_object_set(included_object, set_length=6):
    '''
    Create a list of distinct COCO object labels that always contains
    ``included_object``.

    Parameters
    ----------
    included_object : str
        Label guaranteed to be a member of the returned list.
    set_length : int, optional
        Number of distinct labels to return (default 6).

    Returns
    -------
    list of str
        ``set_length`` distinct labels including ``included_object``.
    '''
    # Delegate to the general helper so both functions share one sampling
    # implementation: same behavior (random COCO labels, no exclusions).
    return Object_set_creator(included_object, set_length=set_length)
def Object_set_creator(included_object, list_of_all_objects=None, excluded_objects_list=None, set_length=6):
    '''
    Create a list of ``set_length`` distinct objects drawn at random from
    ``list_of_all_objects``; ``included_object`` is always a member.

    Parameters
    ----------
    included_object : str
        Object guaranteed to be in the returned list.
    list_of_all_objects : list of str, optional
        Pool to sample from; defaults to the module-level ``coco_objects``.
    excluded_objects_list : list of str, optional
        Objects that must not appear in the result.
    set_length : int, optional
        Number of distinct objects to return (default 6).

    Returns
    -------
    list of str

    Raises
    ------
    ValueError
        If ``included_object`` is excluded, or if there are not enough
        distinct eligible objects to fill the set (the original code
        would loop forever in that case).
    '''
    # Resolve defaults at call time: avoids the mutable-default pitfall and
    # the module global being frozen into the signature at definition time.
    if list_of_all_objects is None:
        list_of_all_objects = coco_objects
    excluded = set(excluded_objects_list or [])
    # The included object must not be contained in the excluded objects.
    if included_object in excluded:
        raise ValueError('The included_object can not be part of the excluded_objects list.')
    # Guard against an infinite loop when the pool cannot yield enough
    # distinct non-excluded objects.
    eligible = set(list_of_all_objects) - excluded
    eligible.add(included_object)
    if len(eligible) < set_length:
        raise ValueError('Not enough distinct eligible objects to build a set of the requested length.')
    curr_object_set = {included_object}
    while len(curr_object_set) < set_length:
        candidate = random.choice(list_of_all_objects)
        if candidate not in excluded:
            curr_object_set.add(candidate)
    return list(curr_object_set)
| ###### Single object recognition | |
def CLIP_single_object_classifier(img, object_class, task_specific_label=None):
    '''
    Test presence of ``object_class`` in ``img`` using the "red herring
    strategy": CLIP ranks the target label against random COCO distractor
    labels, and the object counts as present when the target gets the
    highest image-text similarity.

    Note that the task_specific_label is not used for this classifier.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to classify.
    object_class : str
        COCO label whose presence is tested.
    task_specific_label : optional
        Ignored; kept for a uniform classifier interface.

    Returns
    -------
    bool
        True if ``object_class`` is the top-ranked label.
    '''
    # Candidate labels: the target plus random COCO distractors.
    word_list = Coco_object_set(object_class)
    inputs = CLIPprocessor_import(text=word_list, images=img, return_tensors="pt", padding=True)
    # Inference only: disable autograd to avoid building an unneeded graph.
    with torch.no_grad():
        outputs = CLIPmodel_import(**inputs)
    # Image-text similarity scores -> probabilities over the label set.
    probs = outputs.logits_per_image.softmax(dim=1)
    # Present iff the target label received the highest probability.
    return word_list[probs.argmax().item()] == object_class
def CLIP_object_recognition(img, object_class, tested_classes):
    '''
    General CLIP object recognition: ``object_class`` counts as recognised
    when CLIP ranks it highest among ``tested_classes`` for ``img``.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to classify.
    object_class : str
        Label whose presence is tested; must be in ``tested_classes``.
    tested_classes : list of str
        Candidate labels CLIP chooses between.

    Returns
    -------
    bool
        True if ``object_class`` is the top-ranked label.

    Raises
    ------
    ValueError
        If ``object_class`` is not in ``tested_classes``.
    '''
    if object_class not in tested_classes:
        raise ValueError('The object_class has to be part of the tested_classes list.')
    inputs = CLIPprocessor_import(text=tested_classes, images=img, return_tensors="pt", padding=True)
    # Inference only: disable autograd to avoid building an unneeded graph.
    with torch.no_grad():
        outputs = CLIPmodel_import(**inputs)
    # Image-text similarity scores -> probabilities over the candidate set.
    probs = outputs.logits_per_image.softmax(dim=1)
    # Recognised iff the target label received the highest probability.
    return tested_classes[probs.argmax().item()] == object_class
| ###### Multi object recognition | |
| #list_of_objects = ['cat','apple','cow'] | |
def CLIP_multi_object_recognition(img, list_of_objects):
    '''
    CLIP-based test that every object in ``list_of_objects`` is present in
    ``img``. Each object is tested individually against a candidate set
    from which all of the other requested objects are excluded.
    Currently has a debugging print call in.

    Returns True only if every object is recognised; stops at the first
    object CLIP fails to recognise and returns False.
    '''
    for target in list_of_objects:
        # None of the other requested objects may appear in the candidate
        # set used to test the current target.
        remaining = [obj for obj in list_of_objects if obj != target]
        candidates = Object_set_creator(included_object=target, excluded_objects_list=remaining)
        recognised = CLIP_object_recognition(img, target, candidates)
        print(f"{target}{recognised}")  # debugging output
        if not recognised:
            # A single miss fails the whole multi-object test.
            return False
    return True
def CLIP_multi_object_recognition_DSwrapper(img, representations, task_specific_label=None):
    '''
    Dashboard wrapper of CLIP_multi_object_recognition: converts the
    comma-separated ``representations`` string into a list of objects.
    Note that the task_specific_label is not used for this classifier.
    '''
    return CLIP_multi_object_recognition(img, representations.split(', '))
| ###### Negation | |
def CLIP_object_negation(img, present_object, absent_object):
    '''
    CLIP-based test for negation prompts: succeeds only when
    ``present_object`` is recognised in ``img`` and ``absent_object``
    is not.

    Returns
    -------
    bool
    '''
    # Build one candidate set per target; each target is excluded from the
    # other's set so the two tests cannot interfere.
    # NOTE(review): the absent-object set uses set_length=10 — presumably to
    # reduce spurious top-ranks for the absent object; confirm intent.
    present_candidates = Object_set_creator(
        included_object=present_object, excluded_objects_list=[absent_object])
    absent_candidates = Object_set_creator(
        included_object=absent_object, excluded_objects_list=[present_object], set_length=10)
    present_detected = CLIP_object_recognition(img, present_object, present_candidates)
    absent_detected = CLIP_object_recognition(img, absent_object, absent_candidates)
    return present_detected and not absent_detected
| ###### Counting / arithmetic | |
| ''' | |
| test_image = Image.open('pages/Functions/test_imageIII.jpeg') | |
| object_classes = ['cat','remote'] | |
| object_counts = [2,2] | |
| ''' | |
def DETR_multi_object_counting(img, object_classes, object_counts, confidence_treshold=0.5):
    '''
    Check with DETR object detection that ``img`` contains exactly the
    requested number of instances of each class.
    Currently has debugging print calls in.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to analyse.
    object_classes : list of str
        COCO class names to count.
    object_counts : list
        Expected instance count per class, aligned with ``object_classes``;
        values may be strings (dashboard input) and are cast to int.
    confidence_treshold : float, optional
        Minimum detection confidence for post-processing (default 0.5).

    Returns
    -------
    bool
        True iff every requested class is detected with exactly the
        expected count.
    '''
    # Run DETR inference; autograd is unnecessary for evaluation.
    inputs = DetrFeatureExtractor_import(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = DetrModel_import(**inputs)
    # Convert outputs (bounding boxes and class logits) to COCO API format.
    target_sizes = torch.tensor([img.size[::-1]])
    results = DetrFeatureExtractor_import.post_process_object_detection(
        outputs, threshold=confidence_treshold, target_sizes=target_sizes)[0]
    # Map detected label ids to class names and count occurrences per name.
    detected_names = pd.Series(results['labels'].numpy()).map(DetrModel_import.config.id2label)
    count_dict = detected_names.value_counts().to_dict()
    # Expected counts per class (dict(zip(...)) keeps the last entry when a
    # class is listed twice, matching the original behavior).
    label_dict = dict(zip(object_classes, object_counts))
    for label, expected in label_dict.items():
        # Class not detected at all -> fail immediately.
        if label not in count_dict:
            return False
        # Cast both sides: expected counts may arrive as strings.
        if int(count_dict[label]) == int(expected):
            print(str((label, expected)) + '_true')  # debugging output
        else:
            print(str((label, expected)) + '_false')  # debugging output
            print("oberserved: " + str(count_dict[label]))
            return False
    # All requested classes matched their expected counts.
    return True
def DETR_multi_object_counting_DSwrapper(img, representations, Task_specific_label):
    '''
    Dashboard wrapper of DETR_multi_object_counting: converts the
    comma-separated ``representations`` and ``Task_specific_label``
    strings into parallel lists of class names and expected counts.
    '''
    classes = representations.split(', ')
    counts = Task_specific_label.split(', ')
    return DETR_multi_object_counting(img, classes, counts, confidence_treshold=0.5)