| import torch | |
| import numpy as np | |
| from typing import List, Callable | |
| import numpy as np | |
| import cv2 | |
class PerturbationConfidenceMetric:
    """Measure target scores after perturbing each input with its CAM.

    The wrapped ``perturbation`` callable is invoked once per batch item as
    ``perturbation(sample_on_cpu, cam_as_tensor)`` and must return a tensor.
    """

    def __init__(self, perturbation):
        self.perturbation = perturbation

    def __call__(self, input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module,
                 return_visualization=False,
                 return_diff=True):
        """Score the perturbed batch, optionally relative to the original.

        Args:
            input_tensor: Batch of inputs; dimension 0 is iterated item by
                item.
            cams: Array indexed by batch position; ``cams[i]`` is converted
                with ``torch.from_numpy`` and handed to the perturbation.
            targets: Callables paired (via ``zip``) with the per-item model
                outputs; each must return a tensor.
            model: Module evaluated under ``torch.no_grad()``.
            return_visualization: When true, also return the perturbed batch.
            return_diff: When true, return perturbed scores minus the scores
                on the unmodified input; otherwise the perturbed scores only.

        Returns:
            A float32 array of scores, or ``(scores, perturbed_batch)`` when
            ``return_visualization`` is set.
        """
        baseline_scores = None
        if return_diff:
            # The baseline is only needed when a difference is requested.
            with torch.no_grad():
                baseline_outputs = model(input_tensor)
                baseline_scores = np.float32(
                    [score_fn(out).cpu().numpy()
                     for score_fn, out in zip(targets, baseline_outputs)])

        device = input_tensor.device
        # The perturbation runs on CPU copies; results go back to the
        # original device before being stacked into a batch again.
        perturbed_items = [
            self.perturbation(
                input_tensor[index, ...].cpu(),
                torch.from_numpy(cams[index])).to(device)
            for index in range(input_tensor.size(0))
        ]
        perturbated_tensors = torch.stack(perturbed_items)

        with torch.no_grad():
            perturbed_outputs = model(perturbated_tensors)
        perturbed_scores = np.float32(
            [score_fn(out).cpu().numpy()
             for score_fn, out in zip(targets, perturbed_outputs)])

        result = perturbed_scores - baseline_scores if return_diff \
            else perturbed_scores

        if return_visualization:
            return result, perturbated_tensors
        return result
class RemoveMostRelevantFirst:
    """Keep only the low-valued part of a mask and hand it to an imputer.

    A binary float32 mask is built that is 1 where the input mask falls below
    a threshold and 0 elsewhere. The threshold is either a percentile of the
    mask values or, when ``percentile == 'auto'``, chosen by Otsu's method.
    """

    def __init__(self, percentile, imputer):
        """Store the threshold setting and the imputer callable.

        Args:
            percentile: Value passed to ``np.percentile`` as the percentile,
                or the string ``'auto'`` to use Otsu thresholding.
            imputer: Callable invoked as ``imputer(input_tensor, binary_mask)``.
        """
        self.percentile = percentile
        self.imputer = imputer

    def __call__(self, input_tensor, mask):
        """Threshold ``mask`` and return ``imputer(input_tensor, binary_mask)``.

        Args:
            input_tensor: Forwarded unchanged to the imputer.
            mask: Torch tensor of mask values. The ``'auto'`` branch scales
                it by 255 and casts to uint8, so it presumably expects values
                in [0, 1] (confirm against callers).

        Returns:
            Whatever the imputer returns.
        """
        # Threshold on a detached host-side copy so the NumPy/OpenCV calls
        # work regardless of the mask's device or autograd state.
        mask_np = mask.detach().cpu().numpy()

        if self.percentile != 'auto':
            threshold = np.percentile(mask_np, self.percentile)
            binary_mask = np.float32(mask_np < threshold)
        else:
            # THRESH_BINARY_INV with maxval=1 gives 1 at or below the Otsu
            # threshold and 0 above it, matching the polarity and {0, 1}
            # scale of the percentile branch.
            _, binary_mask = cv2.threshold(
                np.uint8(mask_np * 255), 0, 1,
                cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            binary_mask = np.float32(binary_mask)

        binary_mask = torch.from_numpy(binary_mask).to(mask.device)
        return self.imputer(input_tensor, binary_mask)
class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
    """Run the parent perturbation on the complement of the mask (``1 - mask``)."""

    def __init__(self, percentile, imputer):
        super().__init__(percentile, imputer)

    def __call__(self, input_tensor, mask):
        """Invert ``mask`` as ``1 - mask`` and delegate to the parent ``__call__``."""
        complement = 1 - mask
        return super().__call__(input_tensor, complement)
class AveragerAcrossThresholds:
    """Average a metric over several percentile thresholds.

    ``imputer`` is a factory: calling it with a percentile must return a
    callable accepting ``(input_tensor, cams, targets, model)`` and returning
    scores convertible to a float32 array.
    """

    def __init__(self, imputer, percentiles=None):
        """Store the metric factory and the percentiles to sweep.

        Args:
            imputer: Factory called as ``imputer(percentile)``.
            percentiles: Iterable of percentiles; defaults to
                10, 20, ..., 90 when omitted or ``None``.
        """
        self.imputer = imputer
        if percentiles is None:
            percentiles = range(10, 100, 10)
        # Each instance owns its own list, so mutating ``self.percentiles``
        # cannot leak into other instances or into the default.
        self.percentiles = list(percentiles)

    def __call__(self,
                 input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module):
        """Evaluate the metric at every percentile and return the mean.

        Returns:
            The mean over percentiles (axis 0) of the per-percentile scores,
            computed on a float32 array.
        """
        scores = [
            self.imputer(percentile)(input_tensor, cams, targets, model)
            for percentile in self.percentiles
        ]
        return np.mean(np.float32(scores), axis=0)