# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import argparse
import gc
import math
import os
import random
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple

import cv2
import matplotlib.pyplot as plt  # colormaps for the similarity/cluster maps below
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torch_kmeans import CosineSimilarity, KMeans  # used by get_cluster_map (pip install torch-kmeans)
from torchvision.utils import make_grid
from tqdm import tqdm

from datasets import load_dataset_builder, load_dataset
from datasets.distributed import split_dataset_by_node

# Helpers shipped alongside this script in the RADIO examples.
from common import rank_print, load_model, get_standard_transform, collate

# try:
#     import wandb
# except ImportError:
#     wandb = None

LAYER_STATS = dict()

def main(rank: int = 0, world_size: int = 1):
    '''
    Visualizes the spatial features produced by a RADIO model. For each input image, the
    patch features (and optionally a teacher adaptor's features) are projected to RGB via
    a robust PCA, and the original image, the feature visualization, and a side-by-side
    composite are written to the output directory.
    '''
    local_rank = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    cv2.setNumThreads(1)

    device = torch.device('cuda', local_rank)

    parser = argparse.ArgumentParser(description='Visualize RADIO spatial features')
    parser.add_argument('-v', '--model-version', default='radio_v2',
                        help='Which radio model to load.')
    parser.add_argument('-d', '--dataset', default='imagenet-1k',
                        help='The name of the dataset to classify')
    parser.add_argument('--split', default='validation',
                        help='The dataset split to use.')
    parser.add_argument('-n', default=10, type=int, help='The number of samples to load')
    parser.add_argument('-r', '--resolution', nargs='+', type=int, default=None,
                        help='The input image resolution.'
                             ' If one value is specified, the shortest dimension is resized to this.'
                             ' If two, the image is center cropped.'
                             ' Default: The RADIO model\'s preferred resolution.')
    parser.add_argument('--resize-multiple', type=int, default=None,
                        help='Resize images with dimensions a multiple of this value.'
                             ' This should be equal to the patch size of a ViT (e.g. RADIOv1)')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='The batch size. If the input is variable sized, then this argument becomes a maximum.')
    parser.add_argument('--workers', default=8, type=int, help='Number of loader workers to use')
    parser.add_argument('--vitdet-window-size', default=None, type=int,
                        help='Enable ViTDet at the specified window size')
    parser.add_argument('--output-dir', default='vis_denoise', type=str)
    parser.add_argument('--adaptor-name', default=None, type=str,
                        help='Generate features from a teacher adaptor')

    args, _ = parser.parse_known_args()

    torch.manual_seed(42 + rank)
    np.random.seed(42 + rank)
    random.seed(42 + rank)

    rank_print('Loading model...')
    model, preprocessor, info = load_model(args.model_version, vitdet_window_size=args.vitdet_window_size,
                                           adaptor_name=args.adaptor_name)
    model.to(device=device).eval()
    if isinstance(preprocessor, nn.Module):
        preprocessor.to(device).eval()
    rank_print('Done')

    rank_print('Loading dataset...')
    ds_builder = load_dataset_builder(args.dataset, trust_remote_code=True)

    if args.resolution is None:
        args.resolution = (model.preferred_resolution.height, model.preferred_resolution.width)

    patch_size = model.patch_size

    if args.resize_multiple is None:
        args.resize_multiple = getattr(model, 'min_resolution_step', model.patch_size)

    transform = get_standard_transform(args.resolution, args.resize_multiple)
    dataset = ds_builder.as_dataset(split=args.split)
    dataset = dataset.to_iterable_dataset(num_shards=world_size * max(1, args.workers))
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    dataset = dataset.map(lambda ex: dict(image=transform(ex['image']),
                                          label=torch.as_tensor(ex['label'], dtype=torch.int64)))

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, collate_fn=collate,
                        pin_memory=args.workers > 0,
                        drop_last=False)
    rank_print('Done')
    rank_print(f'Description: {ds_builder.info.description}')

    dirs = dict(
        orig=os.path.join(args.output_dir, 'orig'),
        viz=os.path.join(args.output_dir, 'viz'),
        sbs=os.path.join(args.output_dir, 'sbs'),
    )

    for d in dirs.values():
        os.makedirs(d, exist_ok=True)

    ctr = 0
    for batches in loader:
        if ctr >= args.n:
            break
        for images, _ in batches:
            images = images.to(device=device, non_blocking=True)

            all_feat = []
            with torch.autocast(device.type, dtype=torch.bfloat16):
                p_images = preprocessor(images)
                output = model(p_images)
                if args.adaptor_name:
                    all_feat = [
                        output['backbone'].features,
                        output[args.adaptor_name].features,
                    ]
                else:
                    all_feat = [output[1]]

            all_feat = torch.stack(all_feat, dim=1)

            num_rows = images.shape[-2] // patch_size
            num_cols = images.shape[-1] // patch_size

            all_feat = rearrange(all_feat, 'b m (h w) c -> b m h w c', h=num_rows, w=num_cols).float()

            for i, feats in enumerate(all_feat):
                colored = []
                for features in feats:
                    color = get_pca_map(features, images.shape[-2:])
                    colored.append(color)

                orig = cv2.cvtColor(images[i].permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2BGR)
                cv2.imwrite(f'{dirs["orig"]}/vis_{ctr}.jpg', orig * 255)
                cv2.imwrite(f'{dirs["viz"]}/vis_{ctr}.jpg', colored[-1] * 255)

                op = np.concatenate([orig] + colored, axis=1) * 255
                cv2.imwrite(f'{dirs["sbs"]}/vis_{ctr}.jpg', op)
                ctr += 1

def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    # features: (N, C)
    # m: a hyperparameter controlling how many median absolute deviations out a value
    #    must be to count as an outlier
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()

    d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    try:
        rins = colors[fg_mask][s[:, 0] < m, 0]
        gins = colors[fg_mask][s[:, 1] < m, 1]
        bins = colors[fg_mask][s[:, 2] < m, 2]
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
    except Exception:
        # Fall back to the un-filtered range if outlier rejection leaves a channel empty.
        rins = colors
        gins = colors
        bins = colors
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])

    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
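
# Usage sketch for get_robust_pca (hypothetical shapes, not part of the pipeline above):
# fit the 3-component projection once, then reuse it so colors stay comparable across
# views of the same image.
#
#   feats = torch.randn(1024, 768)                 # stand-in for (N, C) patch features
#   reduct_mat, lo, hi = get_robust_pca(feats)
#   rgb = ((feats @ reduct_mat) - lo) / (hi - lo)  # (N, 3), roughly in [0, 1] minus outliers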

def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """
    feature_map: (1, h, w, C) is the feature map of a single image.
    """
    feature_map = feature_map.float()
    if feature_map.shape[0] != 1:
        # make it (1, h, w, C)
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1])
        )
    else:
        reduct_mat, color_min, color_max = pca_stats
    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    pca_color = pca_color.clamp(0, 1)
    pca_color = pca_color.cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
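
# Usage sketch for get_pca_map (hypothetical shapes): map an (h, w, C) feature grid to an
# RGB image at the input resolution, optionally reusing the PCA stats across frames so a
# sequence is colored consistently (`other_feats` is an illustrative second feature map).
#
#   feats = torch.randn(32, 32, 1280)
#   rgb, stats = get_pca_map(feats, (512, 512), return_pca_stats=True)
#   rgb2 = get_pca_map(other_feats, (512, 512), pca_stats=stats)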

def get_scale_map(
    scalar_map: torch.Tensor,
    img_size,
    interpolation="nearest",
):
    """
    scalar_map: (1, h, w, 1) is a scalar-valued map of a single image.
    """
    if scalar_map.shape[0] != 1:
        scalar_map = scalar_map[None]
    scalar_map = (scalar_map - scalar_map.min()) / (
        scalar_map.max() - scalar_map.min() + 1e-6
    )
    scalar_map = F.interpolate(
        scalar_map.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    # cmap = plt.get_cmap("viridis")
    # scalar_map = cmap(scalar_map)[..., :3]
    # make it 3 channels
    scalar_map = torch.cat([scalar_map] * 3, dim=-1)
    scalar_map = scalar_map.cpu().numpy().squeeze(0)
    return scalar_map
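
# Sketch for get_scale_map (hypothetical shapes): expand a scalar map (e.g. patch-wise
# feature norms) into a 3-channel grayscale image matching the input resolution.
#
#   norms = torch.rand(1, 32, 32, 1)
#   heat = get_scale_map(norms, (512, 512))  # -> (512, 512, 3) np.ndarray in [0, 1]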

def get_similarity_map(features: torch.Tensor, img_size=(224, 224)):
    """
    Compute the similarity map of the central patch to the rest of the image.
    """
    assert len(features.shape) == 4, "features should be (1, H, W, C)"
    H, W, C = features.shape[1:]
    center_patch_feature = features[0, H // 2, W // 2, :]
    center_patch_feature_normalized = center_patch_feature / center_patch_feature.norm()
    center_patch_feature_normalized = center_patch_feature_normalized.unsqueeze(1)
    # Reshape and normalize the entire feature tensor
    features_flat = features.view(-1, C)
    features_normalized = features_flat / features_flat.norm(dim=1, keepdim=True)
    similarity_map_flat = features_normalized @ center_patch_feature_normalized
    # Reshape the flat similarity map back to the spatial dimensions (H, W)
    similarity_map = similarity_map_flat.view(H, W)
    # Normalize the similarity map to [0, 1] for visualization
    similarity_map = (similarity_map - similarity_map.min()) / (
        similarity_map.max() - similarity_map.min()
    )
    # we don't want the center patch to be the most similar
    similarity_map[H // 2, W // 2] = -1.0
    similarity_map = (
        F.interpolate(
            similarity_map.unsqueeze(0).unsqueeze(0),
            size=img_size,
            mode="bilinear",
        )
        .squeeze(0)
        .squeeze(0)
    )
    similarity_map_np = similarity_map.cpu().numpy()
    negative_mask = similarity_map_np < 0
    colormap = plt.get_cmap("turbo")
    # Apply the colormap to the normalized similarity map, then paint the center patch red
    similarity_map_rgb = colormap(similarity_map_np)[..., :3]
    similarity_map_rgb[negative_mask] = [1.0, 0.0, 0.0]
    return similarity_map_rgb
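
# Sketch for get_similarity_map (hypothetical shapes): cosine similarity of every patch
# to the central patch, rendered with the "turbo" colormap (center forced to red).
#
#   feats = torch.randn(1, 32, 32, 1280)
#   sim_rgb = get_similarity_map(feats, img_size=(512, 512))  # -> (512, 512, 3)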

def get_cluster_map(
    feature_map: torch.Tensor,
    img_size,
    num_clusters=10,
) -> np.ndarray:
    kmeans = KMeans(n_clusters=num_clusters, distance=CosineSimilarity, verbose=False)
    if feature_map.shape[0] != 1:
        # make it (1, h, w, C)
        feature_map = feature_map[None]
    labels = kmeans.fit_predict(
        feature_map.reshape(1, -1, feature_map.shape[-1])
    ).float()
    labels = (
        F.interpolate(
            labels.reshape(1, *feature_map.shape[:-1]), size=img_size, mode="nearest"
        )
        .squeeze()
        .cpu()
        .numpy()
    ).astype(int)
    cmap = plt.get_cmap("rainbow", num_clusters)
    cluster_map = cmap(labels)[..., :3]
    return cluster_map.reshape(img_size[0], img_size[1], 3)
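
# Sketch for get_cluster_map (hypothetical shapes): k-means over patch features with a
# cosine-similarity distance (torch-kmeans), rendered as a color-coded label map.
#
#   feats = torch.randn(1, 32, 32, 1280)
#   clusters = get_cluster_map(feats, img_size=(512, 512), num_clusters=8)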

if __name__ == '__main__':
    rank = 0
    world_size = 1

    # if 'WORLD_SIZE' in os.environ:
    #     dist.init_process_group(backend='nccl')
    #     rank = dist.get_rank()
    #     world_size = dist.get_world_size()

    main(rank, world_size)
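
# Example invocation (assumes the companion `common.py` helpers from the RADIO examples
# are importable and a CUDA device is available; the script name here is illustrative):
#   python visualize_features.py -v radio_v2 -d imagenet-1k --split validation -n 10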