Spaces:
Runtime error
Runtime error
| from typing import Union, List | |
| import gradio as gr | |
| import matplotlib | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from pytorch_lightning.utilities.types import EPOCH_OUTPUT | |
| matplotlib.use('Agg') | |
| import numpy as np | |
| from PIL import Image | |
| import albumentations as A | |
| import albumentations.pytorch as al_pytorch | |
| import torchvision | |
| from pl_bolts.models.gans import Pix2Pix | |
| """ Class """ | |
class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix (pl_bolts) with a validation loop and image-grid logging.

    Each validation batch logs the PatchGAN (discriminator) loss and the
    generator loss. At the end of the validation epoch, the first example of
    the first validation batch is rendered as a (sketch, target, generated)
    grid and sent to the logger.

    NOTE(review): ``validation_epoch_end`` and ``EPOCH_OUTPUT`` belong to the
    PyTorch Lightning 1.x API (removed in 2.0); pin Lightning accordingly.
    """

    def validation_step(self, batch, batch_idx):
        """Log validation losses and return the images needed for the grid.

        Args:
            batch: ``(real, condition)`` pair, unpacked in the same order as
                the parent ``Pix2Pix.training_step``. ``real`` is the target
                image and ``condition`` is the generator input.
            batch_idx: Index of the batch (unused).

        Returns:
            dict with ``'sketch'`` (the generator input, i.e. ``condition``)
            and ``'colour'`` (the ground-truth target, i.e. ``real``).
        """
        real, condition = batch
        with torch.no_grad():
            loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", loss)
            loss = self._gen_step(real, condition)
            self.log("val_generator_loss", loss)
        # The parent's generator maps ``condition`` -> ``real`` (its _gen_step
        # calls ``self.gen(conditioned_images)``), and predict() feeds it a
        # sketch. So the sketch is ``condition`` and the colour target is
        # ``real``. validation_epoch_end relies on this mapping when it calls
        # ``self.gen(sketch)``.
        return {
            'sketch': condition,
            'colour': real,
        }

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Log a (sketch, target, generated) image grid for the current epoch.

        Args:
            outputs: The dicts returned by :meth:`validation_step`, one per
                batch. Only the first example of the first batch is used.
                NOTE(review): with several validation dataloaders Lightning
                passes a list of lists here, which this method does not handle.
        """
        # Nothing to visualise when no validation batch ran, and nowhere to
        # send the image when the trainer has no logger (e.g. logger=False).
        if not outputs or self.logger is None:
            return
        sketch = outputs[0]['sketch']
        colour = outputs[0]['colour']
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        # NOTE(review): assumes a TensorBoard logger, whose ``experiment`` is
        # a SummaryWriter exposing ``add_image``.
        self.logger.experiment.add_image(
            f'Image Grid {self.current_epoch}', grid_image, self.current_epoch
        )
| """ Load the model """ | |
# Lightning checkpoint of the trained Pix2Pix model served by this app.
model_checkpoint_path = "model/lightning_bolts_model/epoch=99-step=89000.ckpt"

# Read the checkpoint exactly once, mapping every tensor to the CPU. This way a
# checkpoint saved on a GPU still loads on CPU-only hardware. (The file used
# to be loaded a second time with torch.load into an unused variable, which
# doubled start-up time and memory.)
model = OverpoweredPix2Pix.load_from_checkpoint(
    model_checkpoint_path,
    map_location=torch.device('cpu'),
)
# Inference only: switch layers such as dropout / batch norm to eval behaviour.
model.eval()
def greet(name):
    """Build a greeting string such as ``"Hello Ada!!"`` for *name*.

    Not used by the Gradio interface (which calls ``predict``); kept as a
    trivial helper.
    """
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
# Inference-time preprocessing. Resize to 256x256, then map pixel values from
# [0, 255] to [-1, 1] ((x - 0.5 * 255) / (0.5 * 255)), and finally convert the
# HWC numpy array to a CHW torch tensor. Built once at import rather than on
# every request.
_INFERENCE_TRANSFORM = A.Compose([
    A.Resize(width=256, height=256),
    A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
    al_pytorch.ToTensorV2(),
])


def predict(img: Image.Image):
    """Colour a sketch with the loaded Pix2Pix generator.

    Args:
        img: The uploaded sketch as a PIL image (Gradio ``type="pil"``
            input). Any PIL mode is accepted; it is converted to RGB first.

    Returns:
        The path of the PNG file the generated image was written to.

    Raises:
        ValueError: If no image was supplied (``img`` is ``None``).
    """
    if img is None:
        raise ValueError("No image was provided to colour.")
    # The Normalize step expects exactly three channels. PNG uploads are
    # often RGBA or greyscale, which would otherwise fail to broadcast.
    # NOTE(review): converting RGBA drops alpha rather than compositing onto
    # white, so transparent areas take their stored RGB values.
    image = np.asarray(img.convert("RGB"))
    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    inference_img = _INFERENCE_TRANSFORM(image=image)['image'].unsqueeze(0)
    with torch.no_grad():
        result = model.gen(inference_img.to(model.device))
    # normalize=True min-max rescales the generator output to [0, 1] before
    # it is saved as 8-bit PNG.
    # NOTE(review): every request writes to the same file in the working
    # directory, so concurrent requests can overwrite each other's output.
    torchvision.utils.save_image(result, "inference_image.png", normalize=True)
    return "inference_image.png"
# Gradio UI: one uploaded sketch in, one coloured image out. Uses the
# top-level ``gr.Image`` component. The ``gr.inputs`` / ``gr.outputs``
# namespaces were deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=predict,
    # predict() receives the upload as a PIL image.
    inputs=gr.Image(type="pil"),
    examples=[
        "examples/thesis_test.png",
        "examples/thesis_test2.png",
        "examples/thesis1.png",
        "examples/thesis4.png",
        "examples/thesis5.png",
        "examples/thesis6.png",
    ],
    # predict() returns the path of the PNG it wrote.
    outputs=gr.Image(type="filepath"),
    title="Colour your sketches!",
    description=" Upload a sketch and the conditional gan will colour it for you!",
    article="WIP repo lives here - https://github.com/nmud19/thesisGAN ",
)
iface.launch()