# Source: file uploaded by WithAnyone ("Upload 29 files"), commit 4910a8a (verified), 1.6 kB.
import torch
from einops import rearrange
from torch import Tensor
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
from torch import Tensor
from matplotlib.colors import LinearSegmentedColormap
from dataclasses import dataclass
@dataclass
class AttentionReturnQAndMAP:
    """Container bundling an attention result with some of its internals.

    Nothing in this file constructs it; presumably it is returned by a
    map-producing attention variant defined elsewhere -- TODO confirm.
    """

    # The attention output.
    result: Tensor
    # The attention map/weights (shape not established in this file).
    attention_map: Tensor
    # The query tensor used for the attention call.
    Q: Tensor
def attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    pe: Tensor,
    mask=None,
    token_aug_idx=-1,
    text_length=None,
    image_length=None,
    return_map=False,
) -> Tensor:
    """Apply rotary embeddings to q/k, run scaled dot-product attention, merge heads.

    Args:
        q, k, v: query/key/value tensors, presumably laid out as
            (B, H, L, D) -- the final rearrange requires the attention
            output to be 4-D; TODO confirm layout against callers.
        pe: rotary embedding tensor handed to ``apply_rope`` together with q, k.
        mask: optional ``attn_mask`` forwarded to
            ``torch.nn.functional.scaled_dot_product_attention``.
        token_aug_idx, text_length, image_length, return_map: accepted for
            call-site compatibility but not used here. In particular,
            ``return_map=True`` does NOT make this function return an
            attention map; the plain tensor below is always returned.

    Returns:
        The attention output with the head axis folded into the feature
        axis: (B, H, L, D) -> (B, L, H * D).
    """
    q_rot, k_rot = apply_rope(q, k, pe)
    attended = torch.nn.functional.scaled_dot_product_attention(
        q_rot, k_rot, v, attn_mask=mask
    )
    # Merge heads: each token's per-head features are concatenated.
    return rearrange(attended, "B H L D -> B L (H D)")
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build rotary-position-embedding (RoPE) 2x2 rotation matrices.

    Args:
        pos: position values; the last axis indexes tokens. Any number of
            leading axes is accepted (the original required exactly one).
        dim: rotary width covered by these positions; must be even.
            ``dim // 2`` frequencies are generated.
        theta: base of the geometric frequency progression.

    Returns:
        float32 tensor of shape ``(*pos.shape, dim // 2, 2, 2)``. Each trailing
        2x2 matrix is ``[[cos a, -sin a], [sin a, cos a]]`` with
        ``a = pos * theta ** (-2i / dim)`` for frequency index ``i``.

    Raises:
        AssertionError: if ``dim`` is odd (only checked when asserts are enabled).
    """
    assert dim % 2 == 0
    # Frequencies are computed in float64 for precision, except on Apple MPS,
    # which has no float64 support and would otherwise fail here.
    freq_dtype = torch.float32 if pos.device.type == "mps" else torch.float64
    scale = torch.arange(0, dim, 2, dtype=freq_dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    # Outer product: one rotation angle per (position, frequency) pair.
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos, sin = torch.cos(angles), torch.sin(angles)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)
    # Split the flat 4-vector into a row-major 2x2 matrix; unlike a fixed
    # "b n d" pattern this works for any number of leading axes.
    out = out.reshape(*out.shape[:-1], 2, 2)
    return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query and key features with precomputed 2x2 RoPE matrices.

    The last axis of each input is split into consecutive pairs
    ``(x0, x1)``. Each pair is multiplied by the matching 2x2 matrix in
    ``freqs_cis``, i.e. ``out = M[:, 0] * x0 + M[:, 1] * x1``. ``freqs_cis``
    must end in ``(..., 2, 2)`` and broadcast against the paired input; how
    its leading axes line up with the inputs is left to the caller.

    Math is done in float32. Each result is cast back to its input's dtype
    and shape.

    Returns:
        ``(rotated_xq, rotated_xk)``.
    """

    def _rotate(x: Tensor) -> Tensor:
        # (..., D) -> (..., D/2, 1, 2): the singleton axis lets each pair
        # broadcast against the matrix's two rows.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        first_col = freqs_cis[..., 0]
        second_col = freqs_cis[..., 1]
        rotated = first_col * pairs[..., 0] + second_col * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)