# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from functools import partial
import json
import logging
import os
import sys
from typing import List, Optional

import torch
from torch.nn.functional import one_hot, softmax

import dinov2.distributed as distributed
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.transforms import make_classification_eval_transform
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features


logger = logging.getLogger("dinov2")


def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of nearest neighbors to use (20 usually works best)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        help="Whether to gather the train features on cpu, slower "
        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number of samples to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[10, 20, 100, 200],
        temperature=0.07,
        batch_size=256,
        n_per_class_list=[-1],
        n_tries=1,
    )
    return parser
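

# Example invocation (a sketch; the exact launcher and the checkpoint/config/output-dir flags
# contributed by the shared setup parser in `dinov2.eval.setup` depend on your installation):
#
#     python -m dinov2.eval.knn \
#         --train-dataset ImageNet:split=TRAIN \
#         --val-dataset ImageNet:split=VAL \
#         --nb_knn 10 20 100 200 \
#         --temperature 0.07 \
#         --batch-size 256
#
# Only the arguments defined in this file are shown above; model and output options come from the
# parent parser returned by `get_setup_args_parser`.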


class KnnModule(torch.nn.Module):
    """
    Gets knn of test features from all processes on a chunk of the train features

    Each rank gets a chunk of the train features as well as a chunk of the test features.
    In `compute_neighbors`, for each rank one after the other, its chunk of test features
    is sent to all devices, partial knns are computed with each chunk of train features
    then collated back on the original device.
    """

    def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
        super().__init__()

        self.global_rank = distributed.get_global_rank()
        self.global_size = distributed.get_global_size()

        self.device = device
        self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
        self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device)

        self.nb_knn = nb_knn
        self.max_k = max(self.nb_knn)
        self.T = T
        self.num_classes = num_classes

    def _get_knn_sims_and_labels(self, similarity, train_labels):
        topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
        neighbors_labels = torch.gather(train_labels, 1, indices)
        return topk_sims, neighbors_labels

    def _similarity_for_rank(self, features_rank, source_rank):
        # Send the features from `source_rank` to all ranks
        broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
        torch.distributed.broadcast(broadcast_shape, source_rank)

        broadcasted = features_rank
        if self.global_rank != source_rank:
            broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
        torch.distributed.broadcast(broadcasted, source_rank)

        # Compute the neighbors for `source_rank` among `train_features_rank_T`
        similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
        candidate_labels = self.candidates.expand(len(similarity_rank), -1)
        return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)

    def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
        # Gather all neighbors for `target_rank`
        topk_sims_rank = retrieved_rank = None
        if self.global_rank == target_rank:
            topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
            retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]

        torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
        torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)

        if self.global_rank == target_rank:
            # Perform a second top-k on the k * global_size retrieved neighbors
            topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
            retrieved_rank = torch.cat(retrieved_rank, dim=1)
            results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
            return results
        return None
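
    # Round-robin over ranks: for each rank r, every process scores r's chunk of test features
    # against its own chunk of train features, the partial top-k results are gathered back to r,
    # and r reduces them with a second top-k. Each process therefore ends up with the global
    # neighbors for its own chunk of test features.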
    def compute_neighbors(self, features_rank):
        for rank in range(self.global_size):
            topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
            results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
            if results is not None:
                topk_sims_rank, neighbors_labels_rank = results
        return topk_sims_rank, neighbors_labels_rank

    def forward(self, features_rank):
        """
        Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
        """
        assert all(k <= self.max_k for k in self.nb_knn)

        topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
        batch_size = neighbors_labels.shape[0]
        topk_sims_transform = softmax(topk_sims / self.T, 1)
        matmul = torch.mul(
            one_hot(neighbors_labels, num_classes=self.num_classes),
            topk_sims_transform.view(batch_size, -1, 1),
        )
        probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
        return probas_for_k
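

# Illustration of the weighted voting in `KnnModule.forward` (a minimal sketch, not executed here):
# with `topk_sims` and `neighbors_labels` both of shape (batch_size, max_k),
#     weights = softmax(topk_sims / T, dim=1)                                # (B, max_k)
#     votes = one_hot(neighbors_labels, num_classes) * weights[..., None]    # (B, max_k, C)
#     probas_for_k[k] = votes[:, :k, :].sum(dim=1)                           # (B, C)
# i.e. each of the k nearest neighbors votes for its own label, weighted by the softmax of its
# similarity to the query (features are normalized by `ModelWithNormalize` in `eval_knn`, so the
# dot products are cosine similarities).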


class DictKeysModule(torch.nn.Module):
    def __init__(self, keys):
        super().__init__()
        self.keys = keys

    def forward(self, features_dict, targets):
        for k in self.keys:
            features_dict = features_dict[k]
        return {"preds": features_dict, "target": targets}


def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    modules = {}
    mapping = create_class_indices_mapping(train_labels)
    for npc in n_per_class_list:
        if npc < 0:  # Only one try needed when using the full data
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue
        all_tries = {}
        for t in range(n_tries):
            final_indices = filter_train(mapping, npc, seed=t)
            k_list = list(set(nb_knn + [npc]))
            k_list = sorted([el for el in k_list if el <= npc])
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=k_list,
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)

    return ModuleDictWithForward(modules)
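

# The result is a two-level module dict, e.g.
#     {"full": {"1": KnnModule}}                            when the full train set is used, or
#     {"10 per class": {"0": KnnModule, "1": KnnModule}}    for subsampled settings with n_tries=2,
# so a forward pass produces a nested dict indexed by (n_per_class, try, k).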


def filter_train(mapping, n_per_class, seed):
    torch.manual_seed(seed)
    final_indices = []
    for k in mapping.keys():
        index = torch.randperm(len(mapping[k]))[:n_per_class]
        final_indices.append(mapping[k][index])
    return torch.cat(final_indices).squeeze()


def create_class_indices_mapping(labels):
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))}
    return mapping
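

# Minimal sketch of the two helpers above, on toy labels:
#     labels = torch.tensor([0, 0, 1, 1, 1])
#     mapping = create_class_indices_mapping(labels)    # class -> indices of its samples
#     filter_train(mapping, n_per_class=1, seed=0)      # one randomly chosen index per class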


class ModuleDictWithForward(torch.nn.ModuleDict):
    def forward(self, *args, **kwargs):
        return {k: module(*args, **kwargs) for k, module in self._modules.items()}
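

# Unlike a plain nn.ModuleDict, this container is itself callable: it runs every child module on
# the same inputs and returns their outputs keyed by child name, which is what produces the nested
# result dicts consumed by DictKeysModule.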


def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=[-1],
    n_tries=1,
):
    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
    )
    num_classes = train_labels.max() + 1
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            postprocessors = {
                **postprocessors,
                **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn},
            }
            metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}}
    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Averaging the results over the n tries for each value of n_per_class
    for n_per_class, knn_module in knn_module_dict.items():
        first_try = list(knn_module.keys())[0]
        k_list = knn_module[first_try].nb_knn
        for k in k_list:
            keys = results_dict[(n_per_class, first_try, k)].keys()  # keys are e.g. `top-1` and `top-5`
            results_dict[(n_per_class, k)] = {
                key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
                for key in keys
            }
            for t in knn_module.keys():
                del results_dict[(n_per_class, t, k)]

    return results_dict
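

# After the averaging step, `results_dict` is keyed by (n_per_class, k) and each value maps metric
# names (e.g. "top-1", "top-5") to scalar tensors averaged over the `n_tries` random subsamplings.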


def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=[-1],
    n_tries=1,
):
    transform = transform or make_classification_eval_transform()

    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=transform,
    )
    val_dataset = make_dataset(
        dataset_str=val_dataset_str,
        transform=transform,
    )

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_ in results_dict_knn.keys():
            top1 = results_dict_knn[knn_]["top-1"].item() * 100.0
            top5 = results_dict_knn[knn_]["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

        metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
        with open(metrics_file_path, "a") as f:
            for k, v in results_dict.items():
                f.write(json.dumps({k: v}) + "\n")

    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict
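

# Sketch of programmatic usage (assumes a distributed environment is already initialized, since
# feature extraction and KnnModule rely on torch.distributed collectives; the backbone and dataset
# strings below are placeholders):
#
#     backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
#     eval_knn_with_model(
#         model=backbone.cuda(),
#         output_dir="./knn_eval",
#         train_dataset_str="ImageNet:split=TRAIN",
#         val_dataset_str="ImageNet:split=VAL",
#     )
#
# On the main process the per-configuration top-1/top-5 accuracies are also appended, one JSON
# object per line, to `<output_dir>/results_eval_knn.json`.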


def main(args):
    model, autocast_dtype = setup_and_build_model(args)
    eval_knn_with_model(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        num_workers=5,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
    )
    return 0


if __name__ == "__main__":
    description = "DINOv2 k-NN evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))