Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision.transforms as transforms | |
| import torchvision.datasets as datasets | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import streamlit as st | |
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to flattened images.

    Layers: Linear(input_dim -> 128) -> ReLU -> Linear(128 -> 256) -> ReLU
    -> Linear(256 -> output_dim) -> Tanh, so every output value lies in [-1, 1].

    Args:
        input_dim: Size of the last dimension of the noise tensor given to ``forward``.
        output_dim: Number of output features (e.g. the flattened pixel count).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Tanh(),
        ]
        # Stored as ``self.model`` in this exact order so state_dict keys
        # (model.0.*, model.2.*, model.4.*) remain checkpoint-compatible.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` (last dim = input_dim) to outputs whose last dim is output_dim, in [-1, 1]."""
        return self.model(x)
class Discriminator(nn.Module):
    """MLP discriminator that scores flattened images as real or fake.

    Layers: Linear(input_dim -> 256) -> LeakyReLU(0.2) -> Linear(256 -> 128)
    -> LeakyReLU(0.2) -> Linear(128 -> 1) -> Sigmoid, so ``forward`` yields a
    single probability in [0, 1] per input row, the input form ``nn.BCELoss``
    expects.

    Args:
        input_dim: Number of input features per sample (e.g. flattened pixel count).
    """

    def __init__(self, input_dim):
        super().__init__()
        slope = 0.2  # negative slope shared by both LeakyReLU activations
        # Attribute name and layer order kept so state_dict keys are unchanged.
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(slope),
            nn.Linear(256, 128),
            nn.LeakyReLU(slope),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return real/fake probabilities for ``x``; the last output dimension has size 1."""
        return self.model(x)
# Hyperparameters
latent_dim = 100      # length of each noise vector fed to the generator
image_dim = 28 ** 2   # MNIST images are 28x28 pixels, flattened to 784 features
lr = 2e-4             # Adam learning rate used for both networks
batch_size = 64       # samples per DataLoader batch
# Prepare the data
transform = transforms.Compose([
    transforms.ToTensor(),
    # Shift ToTensor's [0, 1] pixel range to [-1, 1] so real images share the
    # range of the generator's Tanh output.
    transforms.Normalize([0.5], [0.5]),
])


@st.cache_resource
def _load_mnist(root="mnist_data"):
    """Return the MNIST training split, downloading it into ``root`` if missing.

    Streamlit re-executes this entire script on every widget interaction
    (slider move, button press). Caching with ``st.cache_resource`` builds the
    dataset once per server process instead of on every rerun.
    """
    return datasets.MNIST(root=root, train=True, transform=transform, download=True)


dataset = _load_mnist()
# The DataLoader is cheap to build, so it is recreated on each rerun.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Initialize the models. Order matters for reproducibility: the generator is
# built (and its weights randomly drawn) before the discriminator.
generator = Generator(latent_dim, image_dim)
discriminator = Discriminator(image_dim)

# One Adam optimizer per network, both at the shared learning rate `lr`.
optimizer_G = optim.Adam(params=generator.parameters(), lr=lr)
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=lr)

# Binary cross-entropy applied to the discriminator's Sigmoid probabilities.
criterion = nn.BCELoss()
# Streamlit interface
st.title("GAN with PyTorch and Hugging Face")
st.write("Training a GAN to generate MNIST digits")

# Slider for epochs
epochs = st.slider("Number of Epochs", min_value=1, max_value=100, value=50)
train_gan = st.button("Train GAN")


def _train_one_epoch():
    """Run one pass over ``dataloader``, updating D and then G on every batch.

    Returns:
        (mean_d_loss, mean_g_loss): losses averaged over the epoch's batches.
        This is a steadier progress signal than the final batch's loss alone.
    """
    d_total = 0.0
    g_total = 0.0
    n_batches = 0
    for imgs, _ in dataloader:
        batch = imgs.size(0)
        # Flatten each image into a vector for the MLP discriminator.
        real_imgs = imgs.view(batch, -1)
        real_labels = torch.ones(batch, 1)
        fake_labels = torch.zeros(batch, 1)
        z = torch.randn(batch, latent_dim)
        fake_imgs = generator(z)

        # Train Discriminator; detach() keeps D's loss from backpropagating into G.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real_labels)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: push the (just-updated) D to label the fakes as real.
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        d_total += d_loss.item()
        g_total += g_loss.item()
        n_batches += 1
    n_batches = max(n_batches, 1)  # guard against an empty dataloader
    return d_total / n_batches, g_total / n_batches


def _show_samples(n_rows=4, n_cols=4):
    """Draw n_rows * n_cols generator samples and render them with st.pyplot."""
    # Inference only: no autograd graph is needed for the sample images.
    with torch.no_grad():
        z = torch.randn(n_rows * n_cols, latent_dim)
        # Assumes image_dim == 28 * 28 (MNIST), as set in the hyperparameters.
        generated_imgs = generator(z).view(-1, 28, 28).cpu().numpy()
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False
    )
    for img, ax in zip(generated_imgs, axes.flat):
        ax.imshow(img, cmap="gray")
        ax.axis('off')
    st.pyplot(fig)
    # Streamlit reruns the script on every interaction. Without closing, each
    # run leaves another Figure registered with pyplot (unbounded memory
    # growth and a "too many open figures" warning).
    plt.close(fig)


if train_gan:
    # Training loop
    for epoch in range(epochs):
        d_mean, g_mean = _train_one_epoch()
        st.write(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_mean:.4f} | G Loss: {g_mean:.4f}")
    st.write("Training completed")
    # Generate and display images
    _show_samples()
else:
    st.write("Use the slider to select the number of epochs and click the button to start training the GAN")