import os

# Install the bundled glide_text2im package at startup so its imports resolve.
os.system('pip install -e .')

import gradio as gr
import base64
from io import BytesIO
# from fastapi import FastAPI
from PIL import Image
import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler,
)
| """ | |
| credit: follows the gradio glide example by valhalla https://huggingface.co/spaces/valhalla/glide-text2im | |
| """ | |
| # print("Loading models...") | |
| # app = FastAPI() | |
| # This notebook supports both CPU and GPU. | |
| # On CPU, generating one sample may take on the order of 20 minutes. | |
| # On a GPU, it should be under a minute. | |
| has_cuda = th.cuda.is_available() | |
| device = th.device('cpu' if not has_cuda else 'cuda') | |

# Create the base model.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))

# Create the upsampler model.
options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27'  # use 27 diffusion steps for very fast sampling
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))
print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))

def get_images(batch: th.Tensor):
    """Convert a batch of [-1, 1] image tensors into one tiled PIL image."""
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    return Image.fromarray(reshaped.numpy())
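
# For example, a (batch, 3, 256, 256) input is permuted to (256, batch, 256, 3)
# and then reshaped to (256, batch*256, 3), laying the samples out side by side
# as a single image.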

# Create a classifier-free guidance sampling function.
guidance_scale = 3.0

def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)
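
# Classifier-free guidance, as implemented above: both halves of the batch share
# the same noised image, but the first half is conditioned on the text tokens and
# the second half on empty tokens. The guided noise prediction is
#   uncond_eps + guidance_scale * (cond_eps - uncond_eps),
# so guidance_scale = 1.0 would reduce to plain conditional sampling.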
| # @app.get("/") | |
| def read_root(): | |
| return {"glide!"} | |
| # @app.get("/{generate}") | |
| def sample(prompt): | |
| # Sampling parameters | |
| batch_size = 1 | |
| # Tune this parameter to control the sharpness of 256x256 images. | |
| # A value of 1.0 is sharper, but sometimes results in grainy artifacts. | |
| upsample_temp = 0.997 | |

    ##############################
    # Sample from the base model #
    ##############################

    # Create the text tokens to feed to the model.
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(
        tokens, options['text_ctx']
    )

    # Create the classifier-free guidance tokens (empty).
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
        [], options['text_ctx']
    )

    # Pack the tokens together into model kwargs.
    model_kwargs = dict(
        tokens=th.tensor(
            [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size + [uncond_mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )
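
    # The token tensor is laid out to match model_fn: the first batch_size rows
    # hold the prompt tokens, the next batch_size rows the empty uncond tokens.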

    # Sample from the base model.
    model.del_cache()
    samples = diffusion.p_sample_loop(
        model_fn,
        (full_batch_size, 3, options["image_size"], options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    model.del_cache()
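
    # p_sample_loop returns the full guided batch; slicing to [:batch_size]
    # keeps only the conditional half. The base model generates 64x64 images,
    # so samples has shape (batch_size, 3, 64, 64).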

    ##############################
    # Upsample the 64x64 samples #
    ##############################

    tokens = model_up.tokenizer.encode(prompt)
    tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
        tokens, options_up['text_ctx']
    )

    # Create the model conditioning dict.
    model_kwargs = dict(
        # Low-res image to upsample.
        low_res=((samples + 1) * 127.5).round() / 127.5 - 1,
        # Text tokens.
        tokens=th.tensor(
            [tokens] * batch_size, device=device
        ),
        mask=th.tensor(
            [mask] * batch_size,
            dtype=th.bool,
            device=device,
        ),
    )
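
    # The low_res expression rounds the samples to the 256 discrete uint8
    # levels and rescales back to [-1, 1], mimicking the quantization a real
    # 8-bit image would undergo before the upsampler sees it.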

    # Sample from the upsampler model.
    model_up.del_cache()
    up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
    up_samples = diffusion_up.ddim_sample_loop(
        model_up,
        up_shape,
        noise=th.randn(up_shape, device=device) * upsample_temp,
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    model_up.del_cache()
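
    # The upsampler runs 27 DDIM steps ('fast27') and starts from noise scaled
    # by upsample_temp; a temperature slightly below 1.0 trades a little sample
    # diversity for fewer grainy artifacts, per the comment above.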

    # Show the output.
    image = get_images(up_samples)
    # image = to_base64(image)
    # return {"image": image}
    return image

def to_base64(pil_image):
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue())
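
# Note: to_base64 is only used by the commented-out FastAPI path. It returns
# bytes; call .decode('utf-8') on the result if a JSON-safe string is needed.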

title = "glide test"
description = (
    "Text-conditioned image generation demo using OpenAI's GLIDE, a text-guided "
    "diffusion model (https://arxiv.org/abs/2112.10741, "
    "https://github.com/openai/glide-text2im/). A sample should take ~500s to "
    "generate. Credit to valhalla for the Gradio template: "
    "https://huggingface.co/spaces/valhalla/."
)

# Note: gr.inputs.Textbox / gr.outputs.Image and launch(enable_queue=True) are
# the deprecated pre-3.x Gradio API; current Gradio uses the top-level
# components and .queue() instead.
iface = gr.Interface(
    fn=sample,
    inputs=gr.Textbox(label='enter text'),
    outputs=gr.Image(type="pil", label="..."),
    title=title,
    description=description,
)
iface.queue().launch(debug=True)