File size: 7,272 Bytes
3f8fc3b 9d7ec24 30b5c8e 9d7ec24 3f8fc3b 2beb7c3 9d7ec24 2beb7c3 30b5c8e 2beb7c3 30b5c8e 2beb7c3 30b5c8e 2beb7c3 30b5c8e 5655e4a 30b5c8e 2beb7c3 30b5c8e 2beb7c3 9d7ec24 2beb7c3 9d7ec24 2beb7c3 9d7ec24 2beb7c3 9d7ec24 2beb7c3 9d7ec24 2beb7c3 9d7ec24 2beb7c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
"""
# Welcome to Streamlit!
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, check out our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).
In the meantime, below is an example of what you can do with just a few lines of code:
"""
import os
# Point Streamlit and Matplotlib at locations under /tmp, presumably because the
# default (home-directory) locations are not writable on the deployment host —
# confirm. These are set before `import streamlit` and `import matplotlib` so the
# variables already exist whenever those packages look them up.
# NOTE(review): MPLCONFIGDIR is a documented Matplotlib variable; confirm that the
# installed Streamlit version actually reads STREAMLIT_CONFIG_DIR.
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
os.environ["MPLCONFIGDIR"] = "/tmp"
import streamlit as st
# Turn off usage-statistics collection and the source-file watcher (the intent
# is to avoid Streamlit writing config/state files).
# `st._config` is an underscore-prefixed, i.e. private, Streamlit module and may
# change between releases.
# NOTE(review): these calls run inside the script, after the server has already
# started; server-level options are normally set in config.toml, STREAMLIT_*
# environment variables or CLI flags — confirm they actually take effect here.
st._config.set_option("browser.gatherUsageStats", False)
st._config.set_option("server.fileWatcherType", "none")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image as Img
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from lime.lime_image import LimeImageExplainer
from skimage.segmentation import mark_boundaries
import shap
from shap import GradientExplainer
device = "cpu"
if torch.cuda.is_available():
    # Prefer the GPU whenever torch can see one.
    device = "cuda"
# Number of classes the classifier head predicts (one per label_dict entry).
num_classes = 4
# Model input resolution in pixels. MyModel's classifier expects the 3x3
# feature map that six 2x2 max-pools leave from a 224x224 input.
image_size = (224, 224)
# Define CNN Model
class MyModel(nn.Module):
    """Six-stage convolutional classifier.

    Every stage is Conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool, so a 224x224
    RGB input shrinks 224 -> 112 -> 56 -> 28 -> 14 -> 7 -> 3 and reaches the
    fully connected head as a 512x3x3 feature map.

    The module order inside ``features`` and ``classifier`` defines the
    ``features.N`` / ``classifier.N`` state-dict keys, so it must stay exactly
    as built below for saved checkpoints to load.

    Args:
        num_classes: number of output logits produced by the final layer.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        # (in_channels, out_channels) of the six convolutional stages, in order.
        stage_channels = (
            (3, 64),
            (64, 128),
            (128, 128),
            (128, 256),
            (256, 256),
            (256, 512),
        )
        conv_layers = []
        for in_ch, out_ch in stage_channels:
            conv_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
        self.features = nn.Sequential(*conv_layers)

        # Flatten, two hidden Linear+ReLU+Dropout blocks, then the logit layer.
        head_layers = [nn.Flatten()]
        for fan_in, fan_out in ((512 * 3 * 3, 1024), (1024, 512)):
            head_layers += [
                nn.Linear(fan_in, fan_out),
                nn.ReLU(inplace=True),
                nn.Dropout(0.25),
            ]
        head_layers.append(nn.Linear(512, num_classes))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw class logits for a batch of images ``x``."""
        return self.classifier(self.features(x))
# Load model
@st.cache_resource(show_spinner=False)
def _load_state_dict(path):
    """Read the saved state dict at *path*, cached across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction, which
    previously re-read the checkpoint from disk each time. Only the tensors
    are cached: every run still builds its own ``MyModel`` and copies the
    weights in with ``load_state_dict`` (which never mutates the cached
    dict), so anything attached to the model during a run — e.g. Grad-CAM
    hooks — is never shared between sessions.

    Raises:
        FileNotFoundError: if *path* does not exist (exceptions are not cached).
    """
    # weights_only=True restricts unpickling to tensors and plain containers
    # (PyTorch's default from 2.6 on); a state dict needs nothing more.
    return torch.load(path, map_location=torch.device("cpu"), weights_only=True)


model = MyModel(num_classes=num_classes).to(device)
try:
    model.load_state_dict(_load_state_dict("src/brainCNNpytorch_model"))
except FileNotFoundError:
    st.error("Model file 'brainCNNpytorch_model' not found. Please upload the file correctly.")
    st.stop()
model.eval()
# Label dictionary: model output index -> tumour class name shown in the UI.
label_dict = dict(enumerate(("Meningioma", "Glioma", "No Tumor", "Pituitary")))
# Preprocessing
# Built once at import time instead of on every call: LIME alone calls
# preprocess_image once per perturbed sample.
# The mean/std are the usual torchvision ImageNet statistics; they must match
# whatever normalisation the checkpoint was trained with.
_PREPROCESS = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalised, batched model input.

    Non-RGB images (grayscale, RGBA, palette) are converted to RGB first: the
    model's first convolution takes three channels, and the three-channel
    in-place ``Normalize`` raises on a 1- or 4-channel tensor.

    Args:
        image: PIL image of any size and mode.

    Returns:
        Float tensor of shape (1, 3, *image_size) on ``device``.
    """
    if isinstance(image, Img.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0).to(device)
# Grad-CAM
def visualize_grad_cam(image, model, target_layer, label):
    """Overlay a Grad-CAM heatmap for the model's top-scoring class on *image*.

    Args:
        image: PIL image to explain; converted to RGB here.
        model: classifier to explain.
        target_layer: module inside *model* whose activations and gradients
            Grad-CAM uses.
        label: unused; kept so existing callers keep working.

    Returns:
        The RGB overlay produced by ``show_cam_on_image``, at ``image_size``.
    """
    height, width = image_size
    rgb = image.convert("RGB")
    # show_cam_on_image expects a float RGB image scaled to [0, 1].
    img_np = np.asarray(rgb, dtype=np.float32) / 255.0
    img_np = cv2.resize(img_np, (width, height))  # cv2 takes (width, height)
    img_tensor = preprocess_image(rgb)
    with torch.no_grad():
        target_index = int(model(img_tensor).argmax(dim=1).item())
    # The context manager removes the hooks GradCAM registers on target_layer
    # deterministically instead of relying on garbage collection.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(
            input_tensor=img_tensor,
            targets=[ClassifierOutputTarget(target_index)],
        )[0]
    # Match the CAM to the backdrop's resolution before blending.
    grayscale_cam = cv2.resize(grayscale_cam, (width, height))
    return show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
# LIME
def model_predict(images):
    """Classifier callback for LIME: map a batch of images to class probabilities.

    Args:
        images: sequence of image arrays accepted by ``PIL.Image.fromarray``.
            ``visualize_lime`` hands this function to
            ``LimeImageExplainer.explain_instance``, which supplies perturbed
            copies of the uploaded image.

    Returns:
        NumPy array of softmax probabilities, one row per image and one
        column per model output.
    """
    batch = torch.cat(
        [preprocess_image(Img.fromarray(arr)) for arr in images]
    ).to(device)
    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)
    return probabilities.cpu().numpy()
def visualize_lime(image, num_samples=100, num_features=10, random_state=None):
    """Return the uploaded image with LIME's supporting superpixels outlined.

    LIME perturbs superpixels of the image, scores the perturbations with
    ``model_predict`` and fits a local surrogate. The superpixels that push
    the highest-ranked label up are outlined with skimage's
    ``mark_boundaries``.

    Args:
        image: RGB image (PIL or uint8 array with values 0-255).
        num_samples: number of perturbed images LIME evaluates. 100 keeps
            the UI responsive; larger values give more stable explanations.
        num_features: maximum number of superpixels to outline.
        random_state: passed to ``LimeImageExplainer``; an int makes the
            segmentation and sampling, and hence the explanation, reproducible.

    Returns:
        Float RGB image in [0, 1] with the selected regions' boundaries drawn.
    """
    explainer = LimeImageExplainer(random_state=random_state)
    original_image = np.array(image)
    explanation = explainer.explain_instance(
        original_image,
        model_predict,
        top_labels=3,
        hide_color=0,
        num_samples=num_samples,
    )
    top_label = explanation.top_labels[0]
    temp, mask = explanation.get_image_and_mask(
        label=top_label,
        positive_only=True,
        num_features=num_features,
        hide_rest=False,
    )
    # temp is the 0-255 input image; rescale to [0, 1] for mark_boundaries.
    return mark_boundaries(temp / 255.0, mask)
# SHAP
def _class_attributions(shap_values, target_index, input_ndim):
    """Pick the attributions for output *target_index* from a GradientExplainer result."""
    if isinstance(shap_values, list):
        # Older SHAP releases: one array per model output, each shaped like the input.
        return np.asarray(shap_values[target_index])
    values = np.asarray(shap_values)
    if values.ndim == input_ndim + 1:
        # Newer SHAP releases: all model outputs stacked on a trailing axis.
        values = values[..., target_index]
    return values


def visualize_shap(image, target_index=None):
    """Return a matplotlib figure overlaying SHAP attributions on *image*.

    Attributions come from ``shap.GradientExplainer`` (expected gradients)
    measured against an all-black reference image. They are summed over the
    colour channels and scaled to [-1, 1] by their largest magnitude.

    Args:
        image: PIL image to explain (passed through ``preprocess_image``).
        target_index: class index to explain. Defaults to the model's
            top-scoring class, i.e. the prediction shown in the UI.

    Returns:
        A matplotlib ``Figure``; the caller displays (and may close) it.
    """
    img_tensor = preprocess_image(image)
    if target_index is None:
        with torch.no_grad():
            target_index = int(model(img_tensor).argmax(dim=1).item())

    # Expected gradients weight each gradient by (input - background). A
    # background made of copies of the input itself makes that difference,
    # and therefore every attribution, exactly zero. Use a black image pushed
    # through the same preprocessing as the reference instead.
    background = preprocess_image(Img.new("RGB", image.size))
    explainer = shap.GradientExplainer(model, background)
    shap_values = explainer.shap_values(img_tensor)

    # (1, C, H, W) -> (H, W): one signed score per pixel. A 2-D map is needed
    # for `cmap` to apply; an (H, W, 3) array would be drawn as raw RGB.
    attributions = _class_attributions(shap_values, target_index, img_tensor.dim())
    heatmap = attributions[0].sum(axis=0)
    peak = np.abs(heatmap).max()
    if peak > 0:
        heatmap = heatmap / peak

    # Draw the un-normalised upload as the backdrop. The normalised tensor
    # lies outside [0, 1], and imshow would clip it into a garbled image.
    height, width = heatmap.shape
    backdrop = np.asarray(image.convert("RGB").resize((width, height))) / 255.0

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(backdrop)
    # Diverging colormap with symmetric limits: red raises the explained
    # class's score, blue lowers it, white is neutral.
    overlay = ax.imshow(heatmap, cmap="bwr", vmin=-1.0, vmax=1.0, alpha=0.5)
    fig.colorbar(overlay, ax=ax, fraction=0.046, pad=0.04)
    ax.axis("off")
    fig.tight_layout()
    return fig
# Streamlit UI
st.title("Brain Tumor Classification with Grad-CAM, LIME, and SHAP")
uploaded_file = st.file_uploader("Upload an MRI Image", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    try:
        # Force 3-channel RGB: the model's first Conv2d takes three channels,
        # so grayscale/RGBA uploads must be converted.
        image = Img.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL's UnidentifiedImageError (undecodable file) and truncated-image
        # errors are both OSError subclasses.
        st.error("Could not read the uploaded file as an image.")
        st.stop()
    st.image(image, caption="Uploaded Image", use_container_width=True)
    if st.button("Classify & Visualize"):
        image_tensor = preprocess_image(image)
        with torch.no_grad():
            output = model(image_tensor)
        _, predicted = torch.max(output, 1)
        label = label_dict[predicted.item()]
        st.write(f"### Prediction: {label}")
        # Grad-CAM should hook the last Conv2d of the feature extractor. Find it
        # by type: the previously hard-coded features[16] is the fifth of six
        # convolutions, not the last one.
        target_layer = [
            layer for layer in model.features if isinstance(layer, nn.Conv2d)
        ][-1]
        with st.spinner("Computing Grad-CAM and LIME explanations..."):
            grad_cam_img = visualize_grad_cam(image, model, target_layer, label)
            lime_img = visualize_lime(image)
        col1, col2, col3 = st.columns(3)
        with col1:
            st.subheader("Grad-CAM")
            st.image(grad_cam_img, caption="Grad-CAM", use_container_width=True)
        with col2:
            st.subheader("LIME")
            st.image(lime_img, caption="LIME Explanation", use_container_width=True)
        with col3:
            # SHAP is rendered as a matplotlib figure directly in this column.
            st.subheader("SHAP")
            with st.spinner("Computing SHAP explanation..."):
                fig = visualize_shap(image)
            st.pyplot(fig)
            # Close the figure once rendered so pyplot does not keep one open
            # figure per rerun of this script.
            plt.close(fig)