import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .nn import timestep_embedding
from .unet import UNetModel
from .xf import LayerNorm, Transformer, convert_module_to_f16


class Text2ImUNet(UNetModel):
    """
    A UNetModel that conditions on text with an encoding transformer.

    Expects an extra kwarg `tokens` containing the text token indices.

    :param text_ctx: number of text tokens to expect.
    :param xf_width: width of the transformer.
    :param xf_layers: depth of the transformer.
    :param xf_heads: heads in the transformer.
    :param xf_final_ln: use a LayerNorm after the output layer.
    :param tokenizer: the text tokenizer, used for sampling and vocab size.
    """

    def __init__(
        self,
        text_ctx,
        xf_width,
        xf_layers,
        xf_heads,
        xf_final_ln,
        tokenizer,
        *args,
        cache_text_emb=False,
        xf_ar=0.0,
        xf_padding=False,
        share_unemb=False,
        **kwargs,
    ):
        self.text_ctx = text_ctx
        self.xf_width = xf_width
        self.xf_ar = xf_ar
        self.xf_padding = xf_padding
        self.tokenizer = tokenizer

        if not xf_width:
            super().__init__(*args, **kwargs, encoder_channels=None)
        else:
            super().__init__(*args, **kwargs, encoder_channels=xf_width)

        if self.xf_width:
            self.transformer = Transformer(
                text_ctx,
                xf_width,
                xf_layers,
                xf_heads,
            )
            if xf_final_ln:
                self.final_ln = LayerNorm(xf_width)
            else:
                self.final_ln = None

            self.token_embedding = nn.Embedding(self.tokenizer.n_vocab, xf_width)
            self.positional_embedding = nn.Parameter(
                th.empty(text_ctx, xf_width, dtype=th.float32)
            )
            self.transformer_proj = nn.Linear(xf_width, self.model_channels * 4)

            if self.xf_padding:
                self.padding_embedding = nn.Parameter(
                    th.empty(text_ctx, xf_width, dtype=th.float32)
                )
            if self.xf_ar:
                self.unemb = nn.Linear(xf_width, self.tokenizer.n_vocab)
                if share_unemb:
                    self.unemb.weight = self.token_embedding.weight

        self.cache_text_emb = cache_text_emb
        self.cache = None

    def convert_to_fp16(self):
        super().convert_to_fp16()
        if self.xf_width:
            self.transformer.apply(convert_module_to_f16)
            self.transformer_proj.to(th.float16)
            self.token_embedding.to(th.float16)
            self.positional_embedding.to(th.float16)
            if self.xf_padding:
                self.padding_embedding.to(th.float16)
            if self.xf_ar:
                self.unemb.to(th.float16)

    def get_text_emb(self, tokens, mask):
        assert tokens is not None

        if self.cache_text_emb and self.cache is not None:
            assert (
                tokens == self.cache["tokens"]
            ).all(), f"Tokens {tokens.cpu().numpy().tolist()} do not match cache {self.cache['tokens'].cpu().numpy().tolist()}"
            return self.cache

        xf_in = self.token_embedding(tokens.long())
        xf_in = xf_in + self.positional_embedding[None]
        if self.xf_padding:
            assert mask is not None
            xf_in = th.where(mask[..., None], xf_in, self.padding_embedding[None])
        xf_out = self.transformer(xf_in.to(self.dtype))
        if self.final_ln is not None:
            xf_out = self.final_ln(xf_out)
        xf_proj = self.transformer_proj(xf_out[:, -1])
        xf_out = xf_out.permute(0, 2, 1)  # NLC -> NCL

        outputs = dict(xf_proj=xf_proj, xf_out=xf_out)

        if self.cache_text_emb:
            self.cache = dict(
                tokens=tokens,
                xf_proj=xf_proj.detach(),
                xf_out=xf_out.detach() if xf_out is not None else None,
            )

        return outputs

    def del_cache(self):
        self.cache = None
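
    # When `cache_text_emb` is enabled, `get_text_emb` returns the detached
    # cached features on subsequent calls with identical `tokens` (the assert
    # above enforces the match); call `del_cache()` before conditioning on a
    # different prompt.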

    def forward(self, x, timesteps, tokens=None, mask=None):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.xf_width:
            text_outputs = self.get_text_emb(tokens, mask)
            xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
            emb = emb + xf_proj.to(emb)
        else:
            xf_out = None
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, xf_out)
            hs.append(h)
        h = self.middle_block(h, emb, xf_out)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, xf_out)
        h = h.type(x.dtype)
        h = self.out(h)
        return h
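
# The text conditioning reaches the UNet through two signals: `xf_proj`, a
# pooled vector taken from the final token position and added to the timestep
# embedding, and `xf_out`, the full token-feature sequence (NCL) passed to
# every input/middle/output block, presumably consumed by attention layers
# sized by the `encoder_channels=xf_width` argument above.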


class SuperResText2ImUNet(Text2ImUNet):
    """
    A text2im model that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, *args, **kwargs):
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * 2
        else:
            # Curse you, Python. Or really, just curse positional arguments :|.
            args = list(args)
            args[1] = args[1] * 2
        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
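
# `low_res` is bilinearly upsampled to the resolution of `x` and concatenated
# along the channel dimension, which is why `in_channels` is doubled in the
# constructor above.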


class InpaintText2ImUNet(Text2ImUNet):
    """
    A text2im model which can perform inpainting.
    """

    def __init__(self, *args, **kwargs):
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * 2 + 1
        else:
            # Curse you, Python. Or really, just curse positional arguments :|.
            args = list(args)
            args[1] = args[1] * 2 + 1
        super().__init__(*args, **kwargs)

    def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs):
        if inpaint_image is None:
            inpaint_image = th.zeros_like(x)
        if inpaint_mask is None:
            inpaint_mask = th.zeros_like(x[:, :1])
        return super().forward(
            th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1),
            timesteps,
            **kwargs,
        )
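
# Inpainting channel layout: [x, inpaint_image * inpaint_mask, inpaint_mask]
# gives 2 * C + 1 input channels (7 for RGB). When the conditioning inputs are
# omitted they default to zeros, i.e. no pixels are treated as known.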


class SuperResInpaintText2ImUnet(Text2ImUNet):
    """
    A text2im model which can perform both upsampling and inpainting.
    """

    def __init__(self, *args, **kwargs):
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * 3 + 1
        else:
            # Curse you, Python. Or really, just curse positional arguments :|.
            args = list(args)
            args[1] = args[1] * 3 + 1
        super().__init__(*args, **kwargs)

    def forward(
        self,
        x,
        timesteps,
        inpaint_image=None,
        inpaint_mask=None,
        low_res=None,
        **kwargs,
    ):
        if inpaint_image is None:
            inpaint_image = th.zeros_like(x)
        if inpaint_mask is None:
            inpaint_mask = th.zeros_like(x[:, :1])
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        return super().forward(
            th.cat([x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1),
            timesteps,
            **kwargs,
        )
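

# Hedged construction sketch: the combined model stacks [x, masked image,
# mask, upsampled low_res], so a C-channel image yields 3 * C + 1 UNet input
# channels (10 for RGB). The keyword names below come from this file; the
# values and the elided UNetModel kwargs are illustrative assumptions only.
#
#   model = SuperResInpaintText2ImUnet(
#       text_ctx=128,          # number of text tokens (illustrative value)
#       xf_width=512,          # transformer width (illustrative value)
#       xf_layers=16,
#       xf_heads=8,
#       xf_final_ln=True,
#       tokenizer=tokenizer,   # any tokenizer exposing `n_vocab`
#       in_channels=3,         # multiplied internally to 3 * 3 + 1 = 10
#       ...,                   # remaining UNetModel kwargs (elided)
#   )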