Spaces:
Runtime error
Runtime error
| from tqdm import tqdm | |
| import numpy as np | |
| from pathlib import Path | |
| import json | |
| # torch | |
| import torch | |
| from einops import repeat | |
| # vision imports | |
| from PIL import Image | |
| # dalle related classes and utils | |
| from dalle_pytorch import VQGanVAE, DALLE | |
| from dalle_pytorch.tokenizer import tokenizer | |
| from io import BytesIO | |
| import gradio as gr | |
| # load DALL-E | |
def exists(val):
    """Report whether *val* holds a value.

    Returns False only for ``None``; every other object, including falsy
    ones such as ``0`` or ``""``, counts as existing.
    """
    if val is None:
        return False
    return True
# Load every DALL-E checkpoint listed in model_paths.json (a JSON object
# mapping model name -> checkpoint path). All models share one VQGanVAE.
with open("model_paths.json") as model_paths_file:
    models = json.load(model_paths_file)
vae = VQGanVAE(None, None)
dalles = {}
for name, model_path in models.items():
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not Path(model_path).exists():
        raise FileNotFoundError(f'trained DALL-E {model_path} must exist')
    # NOTE(review): torch.load unpickles the file, so only list trusted
    # checkpoints in model_paths.json. Loading onto the CPU avoids depending
    # on the device the checkpoint was saved from; load_state_dict below
    # copies the weights onto the CUDA model.
    load_obj = torch.load(model_path, map_location='cpu')
    dalle_params = load_obj['hparams']
    weights = load_obj['weights']
    # The shared `vae` is passed explicitly below; a serialized 'vae' entry in
    # the hparams would otherwise be a duplicate keyword argument.
    dalle_params.pop('vae', None)
    dalle = DALLE(vae=vae, **dalle_params).cuda()
    dalle.load_state_dict(weights)
    dalles[name] = dalle

# Sampling settings read by generate().
batch_size = 4  # rows of the repeated prompt batch sent per generate_images call
# Passed to DALLE.generate_images as `filter_thres`: a float threshold, not an
# integer k, despite the name (name kept for compatibility).
top_k = 0.9

# generate images
image_size = vae.image_size  # not referenced elsewhere in the visible code
def _tensor_to_pil(image):
    """Convert one channel-first float image tensor to a PIL image."""
    # assumes image is (C, H, W) with values in [0, 1] — TODO confirm against
    # DALLE.generate_images
    np_image = np.moveaxis(image.detach().cpu().numpy(), 0, -1)
    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around.
    formatted = (np.clip(np_image, 0.0, 1.0) * 255).astype('uint8')
    return Image.fromarray(formatted)


def generate(text, num_images=4, dalle_name="weird_car"):
    """Sample images for a text prompt with one of the loaded DALL-E models.

    Args:
        text: The prompt string (Gradio passes the textbox contents here).
        num_images: How many images to sample for the prompt. Defaults to 4,
            the value this function previously hard-coded.
        dalle_name: Key into the module-level ``dalles`` dict selecting the
            model. Defaults to ``"weird_car"``.

    Returns:
        A list of PIL images, one per sampled image.

    Raises:
        ValueError: If ``num_images`` is less than 1.
        KeyError: If ``dalle_name`` is not a loaded model.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    dalle = dalles[dalle_name]

    # Tokenize the prompt once, then repeat the token row num_images times so
    # every sample is conditioned on the same prompt.
    tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()
    tokens = repeat(tokens, '() n -> b n', b=num_images)

    outputs = []
    # Label progress with the prompt string, not the token tensor.
    progress = tqdm(tokens.split(batch_size), desc=f'generating images for - {text}')
    for token_chunk in progress:
        outputs.append(dalle.generate_images(token_chunk, filter_thres=top_k))
    outputs = torch.cat(outputs)

    return [_tensor_to_pil(image) for image in tqdm(outputs, desc='saving images')]
# Build the web UI: a text box feeding generate(), whose list of PIL images is
# shown as a gallery. Gradio 3+ provides gr.Gallery and deprecates (and later
# removes) the legacy gr.outputs namespace; fall back to the old Carousel
# output only on Gradio releases that have no Gallery component.
if hasattr(gr, "Gallery"):
    image_output = gr.Gallery()
else:
    image_output = gr.outputs.Carousel("image")
iface = gr.Interface(fn=generate, inputs="text", outputs=image_output)
iface.launch(share=True)