# Uploaded by rahul7star (migrated from GitHub; revision 0084610, verified).
import math
import torch
from torch import Tensor, IntTensor, BoolTensor
from torch.nn.attention.flex_attention import BlockMask, _mask_mod_signature
from einops import rearrange
def exist(item):
    """Report whether *item* is set, i.e. anything other than ``None``.

    Falsy values such as ``0``, ``""`` or an empty container still count as set.
    """
    return not (item is None)
def freeze(model):
    """Turn off gradient tracking for every parameter of *model*.

    The model is modified in place; the same object is returned so the call
    can be chained, e.g. ``encoder = freeze(build_encoder())``.
    """
    for param in model.parameters():
        param.requires_grad_(False)
    return model
@torch.autocast(device_type="cuda", enabled=False)
def get_freqs(dim, max_period=10000.0):
    """Return ``dim`` geometrically decaying float32 frequencies.

    Element ``i`` equals ``exp(-log(max_period) * i / dim)``, so the sequence
    starts at 1.0 and decays toward ``1 / max_period``. CUDA autocast is
    disabled for this function via the decorator.
    """
    neg_log_period = -math.log(max_period)
    steps = torch.arange(start=0, end=dim, dtype=torch.float32)
    # Same evaluation order as (-log(p) * i) / dim to keep results bit-identical.
    exponents = neg_log_period * steps / dim
    return torch.exp(exponents)
def fractal_flatten(x, rope, shape, block_mask=False, pixel_size=8):
    """Collapse the three leading grid axes of ``x`` and ``rope`` into one token axis.

    Args:
        x: tensor whose first three axes are presumably the
            (duration, height, width) grid given by ``shape`` — TODO confirm
            with callers; only the block_mask path actually uses ``shape``.
        rope: tensor with the same leading grid axes as ``x``.
        shape: ``(duration, height, width)`` of the grid.
        block_mask: if False, tokens are flattened in plain row-major order.
            If True, tokens are first regrouped with ``local_patching`` into
            ``(1, pixel_size, pixel_size)`` tiles, so every run of
            ``pixel_size ** 2`` consecutive tokens is one spatial tile.
        pixel_size: tile edge used when ``block_mask`` is True (default 8,
            the previously hard-coded value). ``height`` and ``width`` must be
            divisible by it for the reshape inside ``local_patching``.

    Returns:
        ``(x, rope)`` with the grid axes merged into a single leading axis.
    """
    if block_mask:
        group = (1, pixel_size, pixel_size)
        # local_patching yields (num_tiles, tile_tokens, ...); merge those two axes.
        x = local_patching(x, shape, group, dim=0).flatten(0, 1)
        rope = local_patching(rope, shape, group, dim=0).flatten(0, 1)
    else:
        x = x.flatten(0, 2)
        rope = rope.flatten(0, 2)
    return x, rope
def fractal_unflatten(x, shape, block_mask=False, pixel_size=8):
    """Inverse of ``fractal_flatten``: restore the (duration, height, width) grid axes.

    Args:
        x: tensor whose first axis holds the flattened tokens.
        shape: ``(duration, height, width)`` of the grid to restore.
        block_mask: must match the value used when flattening. If True, the
            token axis is split into ``(-1, pixel_size ** 2)`` tiles and
            re-assembled with ``local_merge``; otherwise a plain row-major
            reshape is applied.
        pixel_size: tile edge used when ``block_mask`` is True (default 8,
            the previously hard-coded value); must match the flatten call.

    Returns:
        Tensor of shape ``(*shape, *x.shape[1:])``.
    """
    if block_mask:
        group = (1, pixel_size, pixel_size)
        tiles = x.reshape(-1, pixel_size**2, *x.shape[1:])
        return local_merge(tiles, shape, group, dim=0)
    return x.reshape(*shape, *x.shape[1:])
def local_patching(x, shape, group_size, dim=0):
    """Regroup a (duration, height, width) grid into non-overlapping tiles.

    The three grid axes starting at ``dim`` are split into
    ``(blocks, group)`` pairs, the block indices are moved in front of the
    group indices, and the result is flattened. The output has shape
    ``(*x.shape[:dim], num_tiles, tile_size, *x.shape[dim + 3:])``, where
    ``num_tiles = prod(shape[i] // group_size[i])`` and
    ``tile_size = prod(group_size)``. Each grid extent must be divisible by
    its group size.
    """
    duration, height, width = shape
    g_t, g_h, g_w = group_size
    leading = x.shape[:dim]
    trailing = x.shape[dim + 3 :]
    split = x.reshape(
        *leading,
        duration // g_t, g_t,
        height // g_h, g_h,
        width // g_w, g_w,
        *trailing,
    )
    # Put all block axes first, then all intra-tile axes.
    order = [
        *range(len(split.shape[:dim])),
        dim, dim + 2, dim + 4,
        dim + 1, dim + 3, dim + 5,
        *range(dim + 6, len(split.shape)),
    ]
    grouped = split.permute(*order)
    tiles_first = grouped.flatten(dim, dim + 2)
    return tiles_first.flatten(dim + 1, dim + 3)
def local_merge(x, shape, group_size, dim=0):
    """Inverse of ``local_patching``: reassemble tiles into a grid.

    Expects the two axes starting at ``dim`` to be ``(num_tiles, tile_size)``
    as produced by ``local_patching`` with the same ``shape`` and
    ``group_size``. Returns a tensor with those two axes replaced by
    ``(duration, height, width)``.
    """
    duration, height, width = shape
    g_t, g_h, g_w = group_size
    unpacked = x.reshape(
        *x.shape[:dim],
        duration // g_t, height // g_h, width // g_w,
        g_t, g_h, g_w,
        *x.shape[dim + 2 :],
    )
    head = range(len(unpacked.shape[:dim]))
    tail = range(dim + 6, len(unpacked.shape))
    # Interleave each block axis with its matching intra-tile axis.
    merged = unpacked.permute(
        *head, dim, dim + 3, dim + 1, dim + 4, dim + 2, dim + 5, *tail
    )
    # Fuse the (block, group) pairs for duration, height and width in turn.
    for axis in range(dim, dim + 3):
        merged = merged.flatten(axis, axis + 1)
    return merged
def fast_sta_nabla(
    T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
) -> Tensor:
    """Build a dense sliding-tile attention mask over a ``T x H x W`` grid.

    Position ``(t1, h1, w1)`` may attend to ``(t2, h2, w2)`` iff
    ``|t1 - t2| <= wT // 2``, ``|h1 - h2| <= wH // 2`` and
    ``|w1 - w2| <= wW // 2``. Positions are indexed row-major, i.e.
    ``index = (t * H + h) * W + w``.

    Returns:
        Bool tensor of shape ``(T*H*W, T*H*W)`` on ``device``.
    """
    # Plain Python max: no throwaway float CPU tensor, and arange gets an int bound.
    extent = max(T, H, W)
    # int32 avoids the overflow an int16 range would hit for extents > 32767.
    r = torch.arange(extent, dtype=torch.int32, device=device)
    dist = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
    near_t = dist[:T, :T].flatten() <= wT // 2
    near_h = dist[:H, :H].flatten() <= wH // 2
    near_w = dist[:W, :W].flatten() <= wW // 2
    # Outer product indexed (h1, h2, w1, w2) -> reorder to (h1, w1, h2, w2).
    near_hw = (
        (near_h.unsqueeze(1) & near_w.unsqueeze(0))
        .reshape(H, H, W, W)
        .transpose(1, 2)
        .flatten()
    )
    # Outer product indexed (t1, t2, s1, s2) -> reorder to (t1, s1, t2, s2).
    sta = (
        (near_t.unsqueeze(1) & near_hw.unsqueeze(0))
        .reshape(T, T, H * W, H * W)
        .transpose(1, 2)
    )
    return sta.reshape(T * H * W, T * H * W)
def nablaT_v2(
q: Tensor,
k: Tensor,
sta: Tensor,
thr: float = 0.9,
) -> BlockMask:
# Map estimation
B, h, S, D = q.shape
s1 = S // 64
qa = q.reshape(B, h, s1, 64, D).mean(-2)
ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
map = qa @ ka
map = torch.softmax(map / math.sqrt(D), dim=-1)
# Map binarization
vals, inds = map.sort(-1)
cvals = vals.cumsum_(-1)
mask = (cvals >= 1 - thr).int()
mask = mask.gather(-1, inds.argsort(-1))
mask = torch.logical_or(mask, sta)
# BlockMask creation
kv_nb = mask.sum(-1).to(torch.int32)
kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
return BlockMask.from_kv_blocks(
torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
)