Spaces:
Sleeping
Sleeping
| import matplotlib | |
| from matplotlib import pyplot as plt | |
| from matplotlib.lines import Line2D | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torchvision.transforms import Compose, Normalize, ToTensor | |
| from typing import List, Dict | |
| import math | |
def preprocess_image(
        img: np.ndarray,
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5)) -> torch.Tensor:
    """ Turn an image array into a normalized tensor with a batch dimension.

    A copy of the image is passed through torchvision's ``ToTensor`` and
    ``Normalize`` transforms, then a leading dimension of size 1 is added.

    :param img: The image as a numpy array.
    :param mean: Per-channel means handed to ``Normalize``.
    :param std: Per-channel standard deviations handed to ``Normalize``.
    :returns: The transformed tensor with an extra leading dimension.
    """
    # The defaults are tuples rather than lists so that no mutable object
    # is shared between calls.
    preprocessing = Compose([
        ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    # The transforms run on a copy, never on the caller's array.
    return preprocessing(img.copy()).unsqueeze(0)
def deprocess_image(img):
    """ Standardize an array and map it to a displayable uint8 image.

    The input is shifted to zero mean, divided by its standard deviation
    (plus a small epsilon), rescaled to a standard deviation of roughly 0.1
    around 0.5, clipped to [0, 1] and converted to 8-bit values.

    see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65

    :param img: The array to convert.
    :returns: A np.uint8 array with values in [0, 255].
    """
    centered = img - np.mean(img)
    # The epsilon keeps the division finite for constant inputs.
    standardized = centered / (np.std(centered) + 1e-5)
    squeezed = np.clip(standardized * 0.1 + 0.5, 0, 1)
    return np.uint8(squeezed * 255)
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.5) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, with values in [0, 1].
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * mask.
    :returns: The image with the cam overlay, as np.uint8 in [0, 255].
    :raises ValueError: If 'img' has values above 1, or if 'image_weight'
        is outside [0, 1]. ValueError is a subclass of Exception, so callers
        catching Exception keep working.
    """
    # Validate the arguments before doing any image work.
    if np.max(img) > 1:
        raise ValueError(
            "The input image should np.float32 in the range [0, 1]")

    if image_weight < 0 or image_weight > 1:
        # Written on a single line so that no source indentation leaks
        # into the message.
        raise ValueError(
            f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    # Rescale so the brightest value becomes 1. An all-zero blend is left
    # as is to avoid dividing by zero.
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    return np.uint8(255 * cam)
def create_labels_legend(concept_scores: np.ndarray,
                         labels: Dict[int, str],
                         top_k=2):
    """ Build one legend string per concept from its best scoring categories.

    For every row of 'concept_scores' the 'top_k' highest scoring categories
    are listed, best first, one per line, as "<label>:<score>". The label
    keeps at most the first three comma separated parts of the category name
    and the score is printed with two decimals.

    :param concept_scores: A 2D array indexed as [concept, category].
    :param labels: A mapping from category index to category name.
    :param top_k: How many categories to list for every concept.
    :returns: A list with one multi-line string per concept.
    """
    # Ascending argsort, reversed along the category axis, gives the
    # highest scores first.
    ranked = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]

    def _entry(row, category):
        short_name = ','.join(labels[category].split(',')[:3])
        return f"{short_name}:{concept_scores[row, category]:.2f}"

    return ["\n".join(_entry(row, category) for category in top_categories)
            for row, top_categories in enumerate(ranked)]
def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: List[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: List = None) -> np.ndarray:
    """ Color code the different component heatmaps on top of the image.
    Every component color code will be magnified according to the heatmap intensity
    (by modifying the V channel in the HSV color space),
    and optionally create a legend that shows the labels.

    Since different factorization component heatmaps can overlap in principle,
    we need a strategy to decide how to deal with the overlaps.
    This keeps the component that has a higher value in its heatmap.

    The 'explanations' array passed by the caller is not modified.

    :param img: The base image RGB format.
    :param explanations: A tensor of shape num_components x height x width, with the component visualizations.
    :param colors: List of R, G, B colors to be used for the components.
                   If None, will use the gist_rainbow cmap as a default.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization.
    :param concept_labels: A list of strings for every component. If this is passed, a legend that shows
                           the labels and their colors will be added to the image.
    :returns: The visualized image.
    """
    n_components = explanations.shape[0]
    if colors is None:
        # taken from https://github.com/edocollins/DFF/blob/master/utils.py
        # plt.get_cmap is used because plt.cm.get_cmap is no longer
        # available in recent matplotlib releases.
        _cmap = plt.get_cmap('gist_rainbow')
        colors = [
            np.array(_cmap(i))
            for i in np.arange(0, 1, 1.0 / n_components)]

    # Index of the strongest component at every pixel.
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        # Work on a copy so that zeroing the pixels won by other components
        # does not alter the caller's array.
        explanation = explanations[i].copy()
        explanation[concept_per_pixel != i] = 0
        mask = np.uint8(mask * 255)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        # The heatmap value drives the brightness (V channel) of the color.
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)

    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)

    if concept_labels is not None:
        px = 1 / plt.rcParams['figure.dpi']  # pixel in inches
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        try:
            # The font size is passed to the legend directly instead of
            # being written into the global rcParams.
            fontsize = int(
                14 * result.shape[0] / 256 / max(1, n_components / 6))
            lw = 5 * result.shape[0] / 256
            lines = [Line2D([0], [0], color=colors[i], lw=lw)
                     for i in range(n_components)]
            plt.legend(lines,
                       concept_labels,
                       mode="expand",
                       fancybox=True,
                       shadow=True,
                       fontsize=fontsize)

            plt.tight_layout(pad=0, w_pad=0, h_pad=0)
            plt.axis('off')
            fig.canvas.draw()
            # buffer_rgba replaces the removed tostring_rgb. The alpha
            # channel is dropped and the pixels are copied out before the
            # figure is closed.
            data = np.ascontiguousarray(
                np.asarray(fig.canvas.buffer_rgba())[:, :, :3])
        finally:
            plt.close(fig=fig)
        data = cv2.resize(data, (result.shape[1], result.shape[0]))
        result = np.hstack((result, data))
    return result
def scale_cam_image(cam, target_size=None):
    """ Min-max scale every image in 'cam' and optionally resize it.

    Each image has its minimum subtracted and is then divided by its new
    maximum (plus a small epsilon).

    :param cam: An iterable of images, each scaled independently.
    :param target_size: If not None, every scaled image is passed to
                        cv2.resize with this size.
    :returns: A np.float32 array stacking the processed images.
    """
    def _rescale(single):
        shifted = single - np.min(single)
        # The epsilon keeps the division finite for constant images.
        unit = shifted / (1e-7 + np.max(shifted))
        if target_size is None:
            return unit
        return cv2.resize(unit, target_size)

    return np.float32([_rescale(single) for single in cam])
def scale_accross_batch_and_channels(tensor, target_size):
    """ Scale every (batch, channel) slice of 'tensor' with scale_cam_image.

    The first two dimensions are merged so that each slice is handled as a
    separate image, then the result is split back into batch and channel.

    :param tensor: An array whose first two dimensions are batch and channel.
    :param target_size: The size forwarded to scale_cam_image; presumably
                        (width, height), as cv2.resize expects.
    :returns: An array of shape
              (batch, channel, target_size[1], target_size[0]).
    """
    n_batch, n_channels = tensor.shape[:2]
    trailing_dims = tensor.shape[2:]
    flattened = tensor.reshape(n_batch * n_channels, *trailing_dims)
    scaled = scale_cam_image(flattened, target_size)
    # The output shape takes target_size in reversed order.
    rows = target_size[1]
    cols = target_size[0]
    return scaled.reshape(n_batch, n_channels, rows, cols)