import math

import torch
from torch import Tensor, IntTensor, BoolTensor
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature
from einops import rearrange


def exist(item):
    return item is not None


def freeze(model):
    # Disable gradients for every parameter so the model is used as a frozen module.
    for p in model.parameters():
        p.requires_grad = False
    return model


@torch.autocast(device_type="cuda", enabled=False)
def get_freqs(dim, max_period=10000.0):
    # Geometrically spaced frequencies from 1 down to 1/max_period, computed in
    # float32 with autocast disabled to keep positional phases precise.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=dim, dtype=torch.float32)
        / dim
    )
    return freqs


def fractal_flatten(x, rope, shape, block_mask=False):
    # Flatten a (T, H, W, ...) token grid (and its rope embeddings) into a sequence.
    # With block_mask=True, tokens are first grouped into 1x8x8 patches so that each
    # run of 64 consecutive tokens corresponds to one attention block.
    if block_mask:
        pixel_size = 8
        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
        x = x.flatten(0, 1)
        rope = rope.flatten(0, 1)
    else:
        x = x.flatten(0, 2)
        rope = rope.flatten(0, 2)
    return x, rope


def fractal_unflatten(x, shape, block_mask=False):
    # Inverse of fractal_flatten: restore the (T, H, W, ...) grid from a sequence.
    if block_mask:
        pixel_size = 8
        x = x.reshape(-1, pixel_size**2, *x.shape[1:])
        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
    else:
        x = x.reshape(*shape, *x.shape[1:])
    return x


def local_patching(x, shape, group_size, dim=0):
    # Split the (T, H, W) axes starting at `dim` into (blocks, within-block) pairs
    # and flatten to (num_blocks, tokens_per_block, ...).
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        g1,
        height // g2,
        g2,
        width // g3,
        g3,
        *x.shape[dim + 3 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 2,
        dim + 4,
        dim + 1,
        dim + 3,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
    return x


def local_merge(x, shape, group_size, dim=0):
    # Inverse of local_patching: merge (num_blocks, tokens_per_block, ...) back
    # into the original (T, H, W, ...) layout.
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        height // g2,
        width // g3,
        g1,
        g2,
        g3,
        *x.shape[dim + 2 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 3,
        dim + 1,
        dim + 4,
        dim + 2,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
    return x


def fast_sta_nabla(
    T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
) -> Tensor:
    # Boolean sliding-tile-attention (STA) mask of shape (T*H*W, T*H*W):
    # position (t1, h1, w1) may attend to (t2, h2, w2) iff each coordinate
    # differs by at most window // 2.
    l = torch.Tensor([T, H, W]).amax()
    r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
    mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
    sta_t, sta_h, sta_w = (
        mat[:T, :T].flatten(),
        mat[:H, :H].flatten(),
        mat[:W, :W].flatten(),
    )
    sta_t = sta_t <= wT // 2
    sta_h = sta_h <= wH // 2
    sta_w = sta_w <= wW // 2
    sta_hw = (
        (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
        .reshape(H, H, W, W)
        .transpose(1, 2)
        .flatten()
    )
    sta = (
        (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
        .reshape(T, T, H * W, H * W)
        .transpose(1, 2)
    )
    return sta.reshape(T * H * W, T * H * W)


def nablaT_v2(
    q: Tensor,
    k: Tensor,
    sta: Tensor,
    thr: float = 0.9,
) -> BlockMask:
    # Map estimation: approximate the attention map at block level by averaging
    # q and k over 64-token blocks.
    B, h, S, D = q.shape
    s1 = S // 64
    qa = q.reshape(B, h, s1, 64, D).mean(-2)
    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
    map = qa @ ka
    map = torch.softmax(map / math.sqrt(D), dim=-1)

    # Map binarization: keep the smallest set of key blocks covering `thr` of the
    # probability mass per query block, then add the STA pattern.
    vals, inds = map.sort(-1)
    cvals = vals.cumsum_(-1)
    mask = (cvals >= 1 - thr).int()
    mask = mask.gather(-1, inds.argsort(-1))
    mask = torch.logical_or(mask, sta)

    # BlockMask creation: every selected block is registered as a "full" block,
    # so no mask_mod is needed inside it.
    kv_nb = mask.sum(-1).to(torch.int32)
    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
    return BlockMask.from_kv_blocks(
        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
    )
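

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows how the
# block-level STA pattern from `fast_sta_nabla` and the adaptive block mask
# from `nablaT_v2` could feed `flex_attention`. The grid size, head count,
# head dim, and dtype below are assumed values chosen so the sequence length
# is a multiple of the 64-token block size; adjust them to your model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.nn.attention.flex_attention import flex_attention

    device = "cuda"  # assumes a CUDA GPU is available
    T, H, W = 4, 32, 32    # latent (frames, height, width); 1x8x8 patches -> 4x4x4 blocks
    heads, head_dim = 8, 64
    S = T * H * W          # 4096 tokens = 64 blocks of 64 tokens

    q = torch.randn(1, heads, S, head_dim, device=device, dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # STA mask on the block grid: one entry per 1x8x8 patch, ordered the same
    # way `local_patching` orders blocks (t-major, then h, then w), so it
    # broadcasts against the (B, h, S // 64, S // 64) map inside `nablaT_v2`.
    sta = fast_sta_nabla(T, H // 8, W // 8, wT=3, wH=3, wW=3, device=device)  # (64, 64)

    block_mask = nablaT_v2(q, k, sta, thr=0.9)

    # flex_attention is normally compiled so the BlockMask maps to sparse kernels.
    attention = torch.compile(flex_attention, dynamic=False)
    out = attention(q, k, v, block_mask=block_mask)
    print(out.shape)  # torch.Size([1, 8, 4096, 64])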