from einops import rearrange
import torch
from torch import Tensor
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


def _upad_input(query_layer, key_layer, value_layer, query_mask, key_mask, query_length):
    # Unpad (q, k, v) into the packed layout expected by flash_attn_varlen_func:
    # valid tokens are gathered into a single (total_tokens, heads, head_dim) tensor,
    # with cumulative sequence lengths marking the per-sample boundaries.
    def _get_unpad_data(attention_mask):
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        return (
            indices,
            cu_seqlens,
            max_seqlen_in_batch,
        )

    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(key_mask)
    _, q_seq_len, num_query_heads, _ = query_layer.shape
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    # Gather only the valid key/value tokens.
    key_layer = index_first_axis(
        key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
        indices_k,
    )
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
        indices_k,
    )
    if query_length == kv_seq_len and key_mask is None:
        # Reuse the key unpadding for the queries.
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, num_query_heads, head_dim),
            indices_k,
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # Decoding-style case: a single query token per sample.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        query_mask = query_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, _ = unpad_input(query_layer, query_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
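
# Illustrative example (comments only) of what _get_unpad_data above produces for a toy
# right-padded mask; the numbers below are assumptions for illustration, not model values.
#
#   mask = torch.tensor([[1, 1, 1, 0],
#                        [1, 1, 0, 0]])                          # two sequences, valid lengths 3 and 2
#   seqlens_in_batch    -> tensor([3, 2], dtype=torch.int32)
#   indices             -> tensor([0, 1, 2, 4, 5])               # flat positions of the valid tokens
#   cu_seqlens          -> tensor([0, 3, 5], dtype=torch.int32)  # offsets into the packed token dim
#   max_seqlen_in_batch -> 3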


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None, drop_mask: Tensor | None = None) -> Tensor:
    q, k = apply_rope(q, k, pe)

    # flash_attn expects (B, L, H, D); q/k/v arrive as (B, H, L, D).
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    B, L, H, D = q.shape

    if drop_mask is None:  # todo: remove drop mask
        drop_mask = attn_mask
    (
        query_states,
        key_states,
        value_states,
        indices_q,
        cu_seq_lens,
        max_seq_lens,
    ) = _upad_input(q, k, v, attn_mask, drop_mask, L)
    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=0.0,
        causal=False,
    )
    # Scatter the packed output back to (B, L, H, D), then merge the head dims.
    x = pad_input(attn_output_unpad, indices_q, B, L)
    x = rearrange(x, "B L H D -> B L (H D)")
    return x
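
# Hedged usage sketch (comments only): tensor shapes this attention() assumes. The sizes,
# device, and dtype below are placeholders, not values taken from the model.
#
#   B, H, L, D = 2, 24, 512, 128
#   q, k, v   : (B, H, L, D) CUDA tensors (fp16/bf16 for flash-attn)
#   pe        : rope positional tensor broadcastable over H, e.g. rope(pos, D, theta)[:, None]
#               with shape (B, 1, L, D // 2, 2, 2)
#   attn_mask : (B, L) 0/1 mask of valid tokens
#   attention(q, k, v, pe, attn_mask=attn_mask) -> (B, L, H * D)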


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()
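
# The trailing (2, 2) dims of the rope() output hold, for each position n and frequency
# omega_d, the rotation matrix [[cos(n * omega_d), -sin(n * omega_d)],
#                               [sin(n * omega_d),  cos(n * omega_d)]],
# which apply_rope() below applies to the (even, odd) channel pairs of q and k.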


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # Pair up the (even, odd) channels and contract them against the per-position
    # rotation matrices in freqs_cis.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
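

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the Space): checks that
    # apply_rope() is a pure rotation, i.e. it preserves the per-token norm of q and k.
    # The shapes below are arbitrary placeholders; pe gets an extra singleton head dim so
    # it broadcasts over H, matching how attention() consumes it.
    B, H, L, D = 1, 2, 4, 8
    q = torch.randn(B, H, L, D)
    k = torch.randn(B, H, L, D)
    pos = torch.arange(L, dtype=torch.float64)[None].repeat(B, 1)  # (B, L)
    pe = rope(pos, D, 10_000)[:, None]                             # (B, 1, L, D // 2, 2, 2)
    q_rot, k_rot = apply_rope(q, k, pe)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)
    assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-4)
    print("rope/apply_rope smoke test passed")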