Spaces:
Running
Running
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from torch.utils.data import DataLoader | |
| from sklearn.metrics import confusion_matrix | |
| import numpy as np | |
# Device configuration: run on the GPU when PyTorch reports CUDA support,
# otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Streamlit interface: page title followed by a short description of the app.
st.title("CNN for Image Classification using CIFAR-10")
st.write(
    "\n"
    "This application demonstrates how to build and train a Convolutional "
    "Neural Network (CNN) for image classification using the CIFAR-10 "
    "dataset. You can adjust hyperparameters, visualize sample images, "
    "and see the model's performance.\n"
)
# Hyperparameters (sidebar widgets; Streamlit reruns the script on change)
num_epochs = st.sidebar.slider("Number of epochs", 1, 20, 10)
batch_size = st.sidebar.slider("Batch size", 10, 200, 100, step=10)
# Streamlit's default float slider format shows two decimals, which would
# display every learning rate in this range as 0.00 or 0.01. Show four
# decimals instead; `format` affects display only, not the returned value.
learning_rate = st.sidebar.slider(
    "Learning rate", 0.0001, 0.01, 0.001, step=0.0001, format="%.4f"
)
# CIFAR-10 dataset
# ToTensor maps pixels to [0, 1]; Normalize with mean = std = 0.5 for each
# RGB channel then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
# Display some sample images
st.write("## Sample Images from CIFAR-10 Dataset")
sample_images, sample_labels = next(iter(train_loader))
fig, axes = plt.subplots(1, 6, figsize=(15, 5))
# zip stops after the 6 axes, so only the first 6 images of the batch are drawn.
for ax, image, label in zip(axes, sample_images, sample_labels):
    # Undo Normalize(mean=0.5, std=0.5): tensors are in [-1, 1], while imshow
    # expects floats in [0, 1] and would otherwise clip and distort colors.
    image = (image * 0.5 + 0.5).clamp(0, 1)
    ax.imshow(np.transpose(image.numpy(), (1, 2, 0)))  # CHW -> HWC
    ax.set_title(f'Label: {label.item()}')
    ax.axis('off')
st.pyplot(fig)
# Streamlit reruns the whole script on every interaction; close the figure so
# pyplot does not keep accumulating open figures.
plt.close(fig)
# Class distribution
st.write("## Class Distribution in CIFAR-10 Dataset")
class_names = train_dataset.classes
# Count labels over the whole training split, which is what the heading
# promises, rather than over one shuffled batch. minlength guarantees one
# count per class so the bars always line up with class_names.
class_counts = np.bincount(train_dataset.targets, minlength=len(class_names))
fig, ax = plt.subplots()
sns.barplot(x=class_names, y=class_counts, ax=ax)
ax.set_ylabel('Count')
ax.set_title('Class Distribution')
ax.tick_params(axis='x', labelrotation=45)  # keep the 10 class names legible
st.pyplot(fig)
# Release the figure; otherwise pyplot keeps it alive across Streamlit reruns.
plt.close(fig)
# Define a Convolutional Neural Network
class CNN(nn.Module):
    """Two-stage convolutional classifier for 3x32x32 images.

    Layout: [conv3x3(pad=1) -> BN -> ReLU -> maxpool2] ->
    [conv3x3 -> BN -> ReLU -> maxpool2] -> flatten ->
    Linear(600) -> ReLU -> Dropout(0.25) -> Linear(100) -> ReLU ->
    Linear(num_classes).

    Args:
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2))
        # Determine the flattened feature size with a dummy 32x32 forward pass.
        # In train mode the probe would fold its input into the BatchNorm
        # running statistics (and bump num_batches_tracked). Run it in eval
        # mode without autograd, then restore the previous mode.
        self._to_linear = None
        was_training = self.training
        self.eval()
        with torch.no_grad():
            self.convs(torch.zeros(1, 3, 32, 32))
        self.train(was_training)
        self.fc1 = nn.Linear(self._to_linear, 600)
        # Element-wise dropout for the flat (batch, 600) activations.
        # Dropout2d is intended for (N, C, H, W) feature maps; 2-D input to it
        # is deprecated in PyTorch.
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(600, 100)
        self.fc3 = nn.Linear(100, num_classes)

    def convs(self, x):
        """Apply both conv stages; on the first call, record the flattened size."""
        x = self.layer1(x)
        x = self.layer2(x)
        if self._to_linear is None:
            # Features per sample after flattening everything but the batch dim.
            self._to_linear = x.view(x.size(0), -1).shape[1]
        return x

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        The spatial size of x must match the 32x32 probe used in __init__,
        because fc1's input width is fixed there.
        """
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        # Nonlinearities between the fully connected layers; without them
        # fc1 -> fc2 -> fc3 compose into a single linear map.
        x = torch.relu(self.fc1(x))
        x = self.drop(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Build the network on the chosen device, then the classification loss and an
# Adam optimizer over every model parameter at the sidebar learning rate.
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Button to start training
if st.button("Start Training"):
    # Per-epoch history used for the plots below
    train_losses = []
    test_losses = []
    test_accuracies = []
    # Progress bar
    progress_bar = st.progress(0)
    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        # Set train mode explicitly: the evaluation phase below switches to eval.
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= total_step  # mean of per-batch losses
        train_losses.append(train_loss)
        st.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
        # Test the model (the predictions kept from the last epoch feed the
        # confusion matrix below)
        model.eval()
        with torch.no_grad():
            test_loss = 0.0
            correct = 0
            total = 0
            all_labels = []
            all_predictions = []
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                all_labels.extend(labels.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())
            test_loss /= len(test_loader)
            test_losses.append(test_loss)
            accuracy = 100 * correct / total
            test_accuracies.append(accuracy)
            st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
        # Update progress bar
        progress_bar.progress((epoch + 1) / num_epochs)
    # Plotting the loss and accuracy
    epoch_axis = range(1, num_epochs + 1)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    ax1.plot(epoch_axis, train_losses, label='Train Loss')
    ax1.plot(epoch_axis, test_losses, label='Test Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Test Loss')
    ax1.legend()
    ax2.plot(epoch_axis, test_accuracies, label='Test Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Test Accuracy')
    ax2.legend()
    st.pyplot(fig)
    # Close each figure after rendering so reruns do not accumulate open figures.
    plt.close(fig)
    # Confusion Matrix: pin the label set so the matrix is always
    # len(class_names) x len(class_names) and matches the tick labels.
    cm = confusion_matrix(all_labels, all_predictions,
                          labels=list(range(len(class_names))))
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names,
                yticklabels=class_names, cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    st.pyplot(fig)
    plt.close(fig)
    # Save the model checkpoint
    torch.save(model.state_dict(), 'cnn_model.pth')
    st.write("Model training completed and saved as 'cnn_model.pth'")