Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torchvision.transforms as transforms | |
| import torchvision.utils as vutils | |
| import streamlit as st | |
# Define the Generator
class Generator(nn.Module):
    """Transposed-convolution generator: latent vector -> image.

    Maps a noise tensor of shape ``(N, nz, 1, 1)`` to images of shape
    ``(N, nc, 32, 32)`` (spatial size 1 -> 4 -> 8 -> 16 -> 32). The final
    ``Tanh`` bounds every output value to ``[-1, 1]``.

    Args:
        nz: Length of the latent vector, i.e. input channels of the first
            layer. Must match the noise dimension used when sampling.
        ngf: Base width of the feature maps; the hidden layers use
            ``4 * ngf``, ``2 * ngf`` and ``ngf`` channels.
        nc: Number of channels in the generated image.

    The defaults ``(nz=100, ngf=64, nc=1)`` reproduce the original
    hard-coded layer sizes, and the order of modules inside ``self.main``
    is unchanged, so ``state_dict`` keys (``main.0.weight`` ...) stay
    compatible with existing checkpoints.
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        super().__init__()
        self.main = nn.Sequential(
            # (N, nz, 1, 1) -> (N, 4*ngf, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (N, 2*ngf, 8, 8)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (N, ngf, 16, 16)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (N, nc, 32, 32), squashed into [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Return the images generated from the latent batch ``input``."""
        return self.main(input)
# Define the Discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator producing one score per image.

    Maps images of shape ``(N, nc, 32, 32)`` to scores of shape
    ``(N, 1, 1, 1)`` (spatial size 32 -> 16 -> 8 -> 4 -> 1). The final
    ``Sigmoid`` puts every score in ``[0, 1]``.

    Args:
        nc: Number of channels in the input images.
        ndf: Base width of the feature maps; the hidden layers use
            ``ndf``, ``2 * ndf`` and ``4 * ndf`` channels.

    The defaults ``(nc=1, ndf=64)`` reproduce the original hard-coded layer
    sizes, and the order of modules inside ``self.main`` is unchanged, so
    ``state_dict`` keys stay compatible with existing checkpoints.
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            # (N, nc, 32, 32) -> (N, ndf, 16, 16); no BatchNorm on this layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 2*ndf, 8, 8)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 4*ndf, 4, 4)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 kernel, stride 1, no padding collapses 4x4 -> 1x1
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return one score per image in the batch ``input``."""
        return self.main(input)
# Initialize the models, loss and optimizers.
# NOTE(review): nothing visible in this file trains these networks or loads a
# checkpoint, so netG produces images from randomly initialised weights.
# Streamlit also reruns the whole script on every interaction, so this setup
# runs again each time; consider st.cache_resource if weights are loaded here.

# Device: prefer CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the models to the target device *before* building their optimizers,
# as the torch.optim documentation recommends. nn.Module.to returns the module
# itself, so netG/netD are still the model objects.
netG = Generator().to(device)
netD = Discriminator().to(device)

# Loss function (binary cross-entropy on the discriminator's sigmoid output)
criterion = nn.BCELoss().to(device)

# Optimizers: Adam, lr=2e-4, betas=(0.5, 0.999) for both networks
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Function to generate images
def generate_images(num_images, noise_dim):
    """Sample ``num_images`` images from the module-level generator ``netG``.

    Args:
        num_images: Number of latent vectors to draw (the batch size).
        noise_dim: Length of each latent vector; must match the input
            channels of the generator's first layer.

    Returns:
        A tensor of shape ``(num_images, C, H, W)`` on ``device``, produced
        without autograd tracking (``requires_grad`` is False), so callers
        can move it to CPU and call ``.numpy()`` directly.
    """
    # Inference only: BatchNorm layers use their running statistics.
    netG.eval()
    noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
    # no_grad avoids building an autograd graph (wasted memory) and keeps the
    # returned tensor from requiring grad, which would make .numpy() raise.
    with torch.no_grad():
        fake_images = netG(noise)
    return fake_images
# Streamlit interface
st.title("Simple GAN with Streamlit")
st.write("Generate images using a simple GAN")

num_images = st.slider(
    "Number of images to generate", min_value=1, max_value=64, value=8
)
# Latent size passed to the generator; must match the input channels of the
# Generator's first ConvTranspose2d layer.
noise_dim = 100

if st.button("Generate Images"):
    with st.spinner("Generating images..."):
        fake_images = generate_images(num_images, noise_dim)
        # detach() guarantees the tensor does not require grad before it
        # reaches .numpy(), which raises on grad-tracking tensors.
        images_cpu = fake_images.detach().cpu()
        # normalize=True shifts the values into [0, 1] using their min/max.
        grid = vutils.make_grid(images_cpu, padding=2, normalize=True)
    # make_grid returns (C, H, W); st.image expects channels-last (H, W, C).
    st.image(grid.permute(1, 2, 0).numpy(), caption="Generated Images")