import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange

def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


def swish(x):
    return x * torch.sigmoid(x)

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, activation_fn="relu"):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        if self.in_channels != self.out_channels:
            # 1x1 projection so the residual addition has matching channel counts
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.activation_fn = activation_fn
        if activation_fn == "relu":
            self.actn = nn.ReLU()

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        if self.activation_fn == "relu":
            x = self.actn(x)
        elif self.activation_fn == "swish":
            x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        if self.activation_fn == "relu":
            x = self.actn(x)
        elif self.activation_fn == "swish":
            x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)
        return x + x_in

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.filters = 128
        self.num_res_blocks = 2
        self.ch_mult = [1, 1, 2, 2, 4]
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = 32
        self.conv_downsample = False
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False)
        blocks = []
        for i in range(len(self.ch_mult)):
            block_in_ch = self.filters * self.in_ch_mult[i]
            block_out_ch = self.filters * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
        # two extra residual blocks at the bottleneck resolution
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
        self.norm1 = normalize(block_in_ch)
        self.conv2 = nn.Conv2d(block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0)
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        x = self.conv1(x)
        for i in range(len(self.ch_mult)):
            for j in range(self.num_res_blocks):
                x = self.blocks[i * self.num_res_blocks + j](x)
            # 2x average-pool downsample after every stage except the last
            if i < len(self.ch_mult) - 1:
                x = F.avg_pool2d(x, (2, 2), (2, 2))
        x = self.blocks[-2](x)
        x = self.blocks[-1](x)
        x = self.norm1(x)
        x = swish(x)
        x = self.conv2(x)
        return x

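# Hedged shape sketch (illustration only, not part of the original file): with
# ch_mult = [1, 1, 2, 2, 4] the encoder downsamples four times, so a 256x256 RGB image
# (an assumed input size) becomes a 32-channel 16x16 latent, and a 224x224 image a 14x14 one.
def _encoder_shape_sketch():
    enc = Encoder()
    latent = enc(torch.randn(1, 3, 256, 256))
    assert latent.shape == (1, 32, 16, 16)
    return latent.shape
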
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size=8192, emb_dim=32, beta=None):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of codebook entries
        self.emb_dim = emb_dim  # dimension of each codebook vector
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
        self.beta = 0.0 if beta is None else beta  # commitment-loss weight
        self.z_dim = emb_dim

    def forward(self, z):
        # flatten (b, c, h, w) -> (b*h*w, c) so each spatial position is matched to its nearest code
        b, c, h, w = z.size()
        flatten = z.permute(0, 2, 3, 1).reshape(-1, c)
        codebook = self.embedding.weight
        with torch.no_grad():
            tokens = torch.cdist(flatten, codebook).argmin(dim=1)
        quantized = F.embedding(tokens, codebook).view(b, h, w, c).permute(0, 3, 1, 2)
        # codebook loss pulls embeddings toward encoder outputs; commitment loss does the reverse
        codebook_loss = F.mse_loss(quantized, z.detach())
        commitment_loss = F.mse_loss(quantized.detach(), z)
        loss = codebook_loss + self.beta * commitment_loss
        # perplexity: effective number of codebook entries in use for this batch
        counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype)
        # dist.all_reduce(counts)
        p = counts / counts.sum()
        perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))
        # straight-through estimator: copy gradients from the quantized output back to z
        tokens = tokens.view(b, h, w)
        quantized = z + (quantized - z).detach()
        # quantized_2 = self.get_codebook_feat(tokens, (b, h, w, c))
        return quantized, tokens, loss, perplexity

    def get_codebook_feat(self, indices, shape=None):
        # indices: (batch * token_num,) -> one-hot (batch * token_num, codebook_size)
        # shape: (batch, height, width, channel)
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # look up the quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
        if shape is not None:  # reshape back to match the original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
        return z_q

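# Hedged usage sketch (illustration only, not from the original file): quantize a random
# encoder-style feature map with the default codebook (emb_dim=32) and check that the
# straight-through output keeps gradients flowing to the input.
def _quantizer_sketch():
    vq = VectorQuantizer()
    z = torch.randn(1, 32, 16, 16, requires_grad=True)
    quantized, tokens, loss, perplexity = vq(z)
    assert quantized.shape == z.shape and tokens.shape == (1, 16, 16)
    quantized.sum().backward()  # straight-through estimator passes gradients to z
    assert z.grad is not None
    return loss, perplexity
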
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.filters = 128
        self.num_res_blocks = 2
        self.ch_mult = [1, 1, 2, 2, 4]
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = 32
        self.out_channels = 3
        self.in_channels = self.embedding_dim
        self.conv_downsample = False
        self.conv1 = nn.Conv2d(32, 512, kernel_size=3, stride=1, padding=1)
        blocks = []
        block_in_ch = self.filters * self.ch_mult[-1]
        block_out_ch = self.filters * self.ch_mult[-1]
        # blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
        upsample_conv_layers = []
        for i in reversed(range(len(self.ch_mult))):
            block_out_ch = self.filters * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
            if i > 0:
                # 3x3 conv expands channels 4x; the Rearrange below turns them into a 2x spatial upsample
                upsample_conv_layers.append(nn.Conv2d(block_in_ch, block_out_ch * 4, kernel_size=3, stride=1, padding=1))
        self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2)
        self.norm1 = normalize(block_in_ch)
        self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.blocks = nn.ModuleList(blocks)
        self.up_convs = nn.ModuleList(upsample_conv_layers)

    def forward(self, x):
        x = self.conv1(x)
        # two residual blocks at the bottleneck resolution
        x = self.blocks[0](x)
        x = self.blocks[1](x)
        for i in range(len(self.ch_mult)):
            for j in range(self.num_res_blocks):
                x = self.blocks[2 + i * self.num_res_blocks + j](x)
            if i < len(self.ch_mult) - 1:
                # channel-to-space upsampling: conv to 4x channels, then rearrange into a 2x grid
                x = self.up_convs[i](x)
                x = x.permute(0, 2, 3, 1)
                x = self.upsample(x)
                x = x.permute(0, 3, 1, 2)
        x = self.norm1(x)
        x = swish(x)
        x = self.conv6(x)
        return x

class VQVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        quant, tokens, loss, perplexity = self.quantizer(x)
        x = self.decoder(quant)
        return x

    def tokenize(self, x):
        # accept arbitrary leading batch dims, e.g. (b, t, 3, h, w)
        batch_shape = x.shape[:-3]
        x = x.reshape(-1, *x.shape[-3:])
        x = self.encoder(x)
        quant, tokens, loss, perplexity = self.quantizer(x)
        return tokens.reshape(*batch_shape, *tokens.shape[1:])

    def decode(self, tokens):
        tokens = einops.rearrange(tokens, 'b ... -> b (...)')
        b = tokens.shape[0]
        # 256 tokens -> 16x16 grid (256px input), 196 tokens -> 14x14 grid (224px input)
        if tokens.shape[-1] == 256:
            hw = 16
        elif tokens.shape[-1] == 196:
            hw = 14
        else:
            raise ValueError("Invalid tokens shape")
        quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, 32))
        x = self.decoder(quant)
        return x

class VAEDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x):
        # x: flattened token indices for a single 14x14 grid
        quant = self.quantizer.get_codebook_feat(x, (1, 14, 14, 32))
        x = self.decoder(quant)
        return x

def get_tokenizer():
    checkpoint_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth"
    )
    torch_state_dict = torch.load(checkpoint_path)
    net = VQVAE()
    net.load_state_dict(torch_state_dict)
    return net

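# Hedged end-to-end sketch (illustration only): round-trip a random 256x256 image through an
# untrained VQVAE. get_tokenizer() is not called here because it expects the "xh_ckpt.pth"
# checkpoint to sit next to this file.
if __name__ == "__main__":
    model = VQVAE().eval()
    with torch.no_grad():
        img = torch.randn(1, 3, 256, 256)
        recon = model(img)            # (1, 3, 256, 256) reconstruction
        tokens = model.tokenize(img)  # (1, 16, 16) codebook indices
        decoded = model.decode(tokens)  # (1, 3, 256, 256) decoded from tokens
    print(recon.shape, tokens.shape, decoded.shape)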