import numpy as np
import torch
import tqdm
from typing import Callable, List
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
from pytorch_grad_cam.ablation_layer import AblationLayer
""" Implementation of AblationCAM
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf

Ablate individual activations, and then measure the drop in the target scores.

In the current implementation, the target layer activations are cached, so they won't be re-computed.
However, layers before it, if any, will not be cached.
This means that if the target layer is a large block, for example model.features (in VGG), there will
be a large saving in run time.

Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster.
The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
(to be improved). The default 1.0 value means that all channels will be ablated.
"""
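# For reference, the per-channel weight computed in get_cam_weights below is the
# relative drop in the target score when that channel of the target layer is
# ablated (zeroed out):
#
#     weight_k = (score_original - score_ablated_k) / score_original
#
# This is just a restatement of the (original_scores - weights) / original_scores
# step further down, written out here for readability.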
class AblationCAM(BaseCAM):
    def __init__(self,
                 model: torch.nn.Module,
                 target_layers: List[torch.nn.Module],
                 reshape_transform: Callable = None,
                 ablation_layer: torch.nn.Module = AblationLayer(),
                 batch_size: int = 32,
                 ratio_channels_to_ablate: float = 1.0) -> None:
        super(AblationCAM, self).__init__(model,
                                          target_layers,
                                          reshape_transform,
                                          uses_gradients=False)
        self.batch_size = batch_size
        self.ablation_layer = ablation_layer
        self.ratio_channels_to_ablate = ratio_channels_to_ablate

    def save_activation(self, module, input, output) -> None:
        """ Helper function to save the raw activations from the target layer """
        self.activations = output

    def assemble_ablation_scores(self,
                                 new_scores: list,
                                 original_score: float,
                                 ablated_channels: np.ndarray,
                                 number_of_channels: int) -> np.ndarray:
        """ Take the value from the channels that were ablated,
            and just set the original score for the channels that were skipped """
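        # Illustrative example (hypothetical values): with number_of_channels=4,
        # ablated_channels=[0, 2], new_scores=[0.1, 0.3] and original_score=0.9,
        # the assembled result is [0.1, 0.9, 0.3, 0.9].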
        index = 0
        result = []
        sorted_indices = np.argsort(ablated_channels)
        ablated_channels = ablated_channels[sorted_indices]
        new_scores = np.float32(new_scores)[sorted_indices]

        for i in range(number_of_channels):
            if index < len(ablated_channels) and ablated_channels[index] == i:
                weight = new_scores[index]
                index = index + 1
            else:
                weight = original_score
            result.append(weight)
        # Return an array so the result matches the declared np.ndarray return type.
        return np.float32(result)

    def get_cam_weights(self,
                        input_tensor: torch.Tensor,
                        target_layer: torch.nn.Module,
                        targets: List[Callable],
                        activations: torch.Tensor,
                        grads: torch.Tensor) -> np.ndarray:
        # Do a forward pass, compute the target scores, and cache the
        # activations.
        handle = target_layer.register_forward_hook(self.save_activation)
        with torch.no_grad():
            outputs = self.model(input_tensor)
            handle.remove()
            original_scores = np.float32(
                [target(output).cpu().item() for target, output in zip(targets, outputs)])

        # Replace the layer with the ablation layer.
        # When we finish, we will replace it back, so the
        # original model is unchanged.
        ablation_layer = self.ablation_layer
        replace_layer_recursive(self.model, target_layer, ablation_layer)

        number_of_channels = activations.shape[1]
        weights = []
        # This is a "gradient free" method, so we don't need gradients here.
        with torch.no_grad():
            # Loop over each of the batch images and ablate activations for it.
            for batch_index, (target, tensor) in enumerate(
                    zip(targets, input_tensor)):
                new_scores = []
                batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1)

                # Check which channels should be ablated. Normally this will be
                # all channels, but we can also try to speed this up by using a
                # low ratio_channels_to_ablate.
                channels_to_ablate = ablation_layer.activations_to_be_ablated(
                    activations[batch_index, :], self.ratio_channels_to_ablate)
                number_channels_to_ablate = len(channels_to_ablate)

                for i in tqdm.tqdm(
                        range(0, number_channels_to_ablate, self.batch_size)):
                    # The last batch can be smaller than batch_size.
                    if i + self.batch_size > number_channels_to_ablate:
                        batch_tensor = batch_tensor[:(number_channels_to_ablate - i)]

                    # Change the state of the ablation layer so it ablates the next channels.
                    # TBD: Move this into the ablation layer forward pass.
                    ablation_layer.set_next_batch(
                        input_batch_index=batch_index,
                        activations=self.activations,
                        num_channels_to_ablate=batch_tensor.size(0))
                    score = [target(o).cpu().item()
                             for o in self.model(batch_tensor)]
                    new_scores.extend(score)
                    # Advance past the indices that were just ablated, so the
                    # next iteration ablates the following channels.
                    ablation_layer.indices = ablation_layer.indices[batch_tensor.size(0):]

                new_scores = self.assemble_ablation_scores(
                    new_scores,
                    original_scores[batch_index],
                    channels_to_ablate,
                    number_of_channels)
                weights.extend(new_scores)

        weights = np.float32(weights)
        weights = weights.reshape(activations.shape[:2])
        original_scores = original_scores[:, None]
        # The Ablation-CAM weight: the relative drop in the target score when
        # each channel is ablated.
        weights = (original_scores - weights) / original_scores

        # Replace the model back to the original state.
        replace_layer_recursive(self.model, ablation_layer, target_layer)
        return weights
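
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes
    # torchvision is installed and that this pytorch_grad_cam version exposes
    # ClassifierOutputTarget; the model is randomly initialized and class index
    # 281 is just an illustrative ImageNet label, so the resulting CAM only
    # demonstrates the API.
    import torchvision
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    model = torchvision.models.resnet18().eval()
    cam = AblationCAM(model=model,
                      target_layers=[model.layer4[-1]],
                      batch_size=32)

    input_tensor = torch.randn(1, 3, 224, 224)  # stand-in for a normalized image batch
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=[ClassifierOutputTarget(281)])
    print(grayscale_cam.shape)  # expected: (1, 224, 224), one CAM per input image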