# Hugging Face Space (status: Running)
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyperparameters
z_dim = 64  # length of the latent noise vector fed to the generator
image_dim = 28 * 28  # flattened size of one 28x28 MNIST image
batch_size = 32
lr = 3e-4

# Data: MNIST digits as tensors, shifted from [0, 1] to [-1, 1] by
# (x - 0.5) / 0.5 so real images share the range of the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
dataset = torchvision.datasets.MNIST(
    root='dataset/',
    transform=transform,
    download=True,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Generator
class Generator(nn.Module):
    """Fully-connected generator mapping latent noise to flattened images.

    Architecture: z_dim -> 256 -> 512 -> 1024 -> img_dim, with ReLU between
    hidden layers and a final Tanh, so every output value lies in [-1, 1].

    Args:
        z_dim: Size of the last dimension of the noise input.
        img_dim: Size of the last dimension of the produced output.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_widths = [z_dim, 256, 512, 1024]
        modules = []
        # Each consecutive width pair becomes a Linear + ReLU stage.
        for width_in, width_out in zip(hidden_widths, hidden_widths[1:]):
            modules += [nn.Linear(width_in, width_out), nn.ReLU()]
        # Output projection squashed into [-1, 1].
        modules += [nn.Linear(hidden_widths[-1], img_dim), nn.Tanh()]
        self.gen = nn.Sequential(*modules)

    def forward(self, x):
        """Map noise ``x`` (last dim ``z_dim``) to outputs with last dim ``img_dim``."""
        images = self.gen(x)
        return images
# Discriminator
class Discriminator(nn.Module):
    """Fully-connected discriminator scoring flattened images.

    Architecture: img_dim -> 1024 -> 512 -> 256 -> 1, with ReLU between
    hidden layers and a final Sigmoid, so each score lies in [0, 1].

    Args:
        img_dim: Size of the last dimension of the input.
    """

    def __init__(self, img_dim):
        super().__init__()
        stack = []
        for n_in, n_out in ((img_dim, 1024), (1024, 512), (512, 256)):
            stack.extend((nn.Linear(n_in, n_out), nn.ReLU()))
        # Single-logit head turned into a probability.
        stack.extend((nn.Linear(256, 1), nn.Sigmoid()))
        self.disc = nn.Sequential(*stack)

    def forward(self, x):
        """Score ``x`` (last dim ``img_dim``); output has last dim 1."""
        scores = self.disc(x)
        return scores
# Initialize generator and discriminator on the selected device.
gen = Generator(z_dim=z_dim, img_dim=image_dim).to(device)
disc = Discriminator(img_dim=image_dim).to(device)

# One Adam optimizer per network, both using the shared learning rate `lr`.
opt_gen, opt_disc = [optim.Adam(net.parameters(), lr=lr) for net in (gen, disc)]

# Binary cross-entropy on probabilities, matching the Discriminator's Sigmoid head.
criterion = nn.BCELoss(reduction="mean")
# Function to train the model
def train_gan(epochs):
    """Train the module-level ``gen`` and ``disc`` for ``epochs`` passes over ``dataloader``.

    Each step first updates the discriminator to tell real images from
    generated ones, then updates the generator so the discriminator labels
    its output as real. After every epoch the losses of that epoch's last
    batch are written to the Streamlit page.

    Args:
        epochs: Number of full passes over ``dataloader``; must be >= 1.

    Returns:
        The last generated batch, detached from the autograd graph, shaped
        ``(batch, image_dim)`` on ``device`` with values in [-1, 1] (Tanh).

    Raises:
        ValueError: If ``epochs`` < 1 or ``dataloader`` yields no batches,
            since there would be no generated images to return.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    fake = None
    for epoch in range(epochs):
        for real, _ in dataloader:
            # Flatten image batches into (N, image_dim) vectors for the MLPs.
            real = real.view(-1, image_dim).to(device)
            cur_batch = real.shape[0]  # a final partial batch may be smaller

            # Train Discriminator: real -> label 1, generated -> label 0.
            noise = torch.randn(cur_batch, z_dim, device=device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach(): this loss must not reach the generator, which also
            # means the generator graph needs no retain_graph for lossG below.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            # Train Generator: push the (just updated) discriminator to call fakes real.
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            lossG.backward()
            opt_gen.step()

        if fake is None:
            raise ValueError("dataloader yielded no batches")
        st.write(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}"
        )
    return fake.detach()
# Streamlit interface
st.title("Simple GAN with Epoch Slider")
epochs = st.slider("Number of Epochs", 1, 100, 1)
if st.button("Train GAN"):
    fake_images = train_gan(epochs)
    # Reshape flat vectors into (N, 1, 28, 28) images and tile them 8 per row;
    # normalize=True min-max scales the pixel values into [0, 1] for display.
    fake_images = fake_images.detach().view(-1, 1, 28, 28)
    grid = make_grid(fake_images, nrow=8, normalize=True)
    # Draw on a dedicated figure: pyplot's implicit current figure outlives
    # Streamlit reruns, and st.pyplot does not clear a passed figure by default,
    # so reusing plt.gcf() stacks each new grid onto the previous axes.
    fig, ax = plt.subplots()
    ax.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap='gray')
    ax.axis("off")
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps every open figure alive until closed