from __future__ import annotations

from collections import OrderedDict
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from transformers.activations import ACT2FN

from .dconv_fwdbwd import dynamic_conv_triton_autograd
from .dconv_fwd_cache import dynamic_conv_triton_cache
from .dconv_step import causal_conv_step_triton
|
|
class DynamicShortConvolution(nn.Module):
    """
    Short causal convolution whose per-channel, per-position kernels are
    generated dynamically from the input by a small two-layer MLP.
    Accepts channel-last inputs of shape `[B, T, D]`.
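
    Example (illustrative sketch; a freshly initialized module acts as a
    static convolution because the kernel generator is zero-initialized,
    see `_init_kernel_generator`):
        >>> conv = DynamicShortConvolution(hidden_size=64, kernel_size=4)
        >>> x = torch.randn(2, 16, 64)
        >>> y, _ = conv(x)
        >>> tuple(y.shape)
        (2, 16, 64)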
""" |

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int,
        generator_input_size: Optional[int] = None,
        generator_reduction: Optional[int] = None,
        generator_activation: str = 'silu',
        activation: Optional[str] = 'silu',
        static_conv_init: Optional[Callable] = None,
        use_fast_conv1d: bool = True,
        implementation: str = "naive",
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.generator_input_size = hidden_size if generator_input_size is None else generator_input_size
        self.generator_hidden_size = hidden_size if generator_reduction is None else (hidden_size // generator_reduction)
        self.kernel_size = kernel_size
        self.activation = None
        self.use_fast_conv1d = use_fast_conv1d  # currently unused by this module
        self.implementation = implementation

        if activation is not None:
            assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
            self.activation = activation

        self.static_conv_init = static_conv_init

        # Two-layer MLP that maps the generator input to a flat `[D * W]`
        # vector of kernel weights (per channel, per position).
        self.kernel_generator = nn.Sequential(
            OrderedDict([
                ("w1", nn.Linear(self.generator_input_size, self.generator_hidden_size, bias=False)),
                ("act", ACT2FN[generator_activation]),
                ("w2", nn.Linear(self.generator_hidden_size, self.hidden_size * self.kernel_size, bias=True)),
            ])
        )
        self._init_kernel_generator()

    def _init_kernel_generator(self):
        """
        Initialize the kernel generator to all zeros, so a fresh module starts
        out as a static convolution parameterized by `w2.bias` alone.
        """
        for layer in self.kernel_generator:
            if isinstance(layer, nn.Linear):
                layer.weight.data.zero_()
                if layer.bias is not None:
                    layer.bias.data.zero_()

        if self.static_conv_init is not None:
            # Seed the bias so the zero-initialized generator emits this static kernel.
            self.static_conv_init(self.kernel_generator.w2.bias)

    def get_kernel(self, x: torch.Tensor) -> torch.Tensor:
        flat_kernels = self.kernel_generator(x)
        if flat_kernels.dim() == 3:
            # Sequence input: [B, T, D * W] -> [B, T, D, W]
            kernels = rearrange(flat_kernels, 'b t (d w) -> b t d w', w=self.kernel_size)
        elif flat_kernels.dim() == 2:
            # Single-step input: [B, D * W] -> [B, D, W]
            kernels = rearrange(flat_kernels, 'b (d w) -> b d w', w=self.kernel_size)
        else:
            raise ValueError(f"Invalid kernel shape: {flat_kernels.shape}")
        return kernels

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        cu_seqlens: Optional[torch.LongTensor] = None,
        generator_input: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x (`torch.Tensor`):
                Tensor of shape `[B, T, D]`.
            mask (`Optional[torch.Tensor]`):
                Attention mask dealing with padded positions.
            cache (`Optional[torch.Tensor]`):
                Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size.
                If provided, the cache is updated **inplace**.
            output_final_state (`Optional[bool]`):
                Whether to output the final state of shape `[N, D, W]`. Default: `False`.
            cu_seqlens (`Optional[torch.LongTensor]`):
                Cumulative sequence lengths of shape `[B + 1]` for variable-length
                inputs. Not supported yet. Default: `None`.
            generator_input (`Optional[torch.Tensor]`):
                Tensor the kernel generator consumes instead of `x`, if provided.

        Returns:
            Tuple of the output tensor of shape `[B, T, D]` and the final cache of
            shape `[N, D, W]` (`None` unless a cache is passed in or requested).

        See the demo at the bottom of this module for a prefill-then-decode sketch.
        """
|
assert cu_seqlens is None, "cu_seqlens not supported yet." |

        B, T, D, W = *x.shape, self.kernel_size
        N = B

        input_dtype = x.dtype

        if mask is not None:
            # Zero out padded positions in place before convolving.
            x = x.mul_(mask.unsqueeze(-1))

        implementation = self.implementation
        if implementation == "triton" and not self.training:
            # The pure triton forward has no cache support, so fall back to the
            # chunked, cache-aware kernel at inference time.
            implementation = "triton_cache"
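
        # Kernels used by each mode (prefill / single-token step):
        #   naive:           PyTorch unfold   / PyTorch roll-and-dot (reference)
        #   triton:          fused autograd   / fused step (training; eval remaps above)
        #   triton_cache:    chunked fused forward with cache / fused step
        #   triton_decoding: PyTorch unfold   / fused step
        #   triton_training: fused autograd   / PyTorch roll-and-dot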

        if cache is not None and B * T == N:
            # With `N == B`, this reduces to a single decoding step (`T == 1`).
            assert T == 1
            if implementation in ["naive", "triton_training"]:
                x, cache = self._step_naive(x, cache, cu_seqlens, generator_input=generator_input)
            elif implementation in ["triton", "triton_cache", "triton_decoding"]:
                x, cache = self._step_triton(x, cache, cu_seqlens, generator_input=generator_input)
            else:
                raise ValueError(f"Unknown implementation: {implementation}")
            return x, cache

        if output_final_state:
            # The cache holds the last `min(W, T)` raw inputs, channel-major: [N, D, W].
            new_cache = rearrange(x[..., -min(W, T):, :], 'n w d -> n d w')
        else:
            new_cache = None

        if implementation in ["naive", "triton_decoding"]:
            x = self._forward_naive(x, generator_input=generator_input)
        elif implementation in ["triton", "triton_training"]:
            assert cache is None, "Cache not supported in pure triton mode. Please set model.eval() or use triton_cache mode."
            x = self._forward_triton(x, generator_input=generator_input)
        elif implementation == "triton_cache":
            x = self._forward_triton_cache(x, generator_input=generator_input, cache=cache)
        else:
            raise ValueError(f"Unknown implementation: {implementation}")

        if self.activation is not None:
            x = ACT2FN[self.activation](x)

        x = x.to(input_dtype)
        if output_final_state:
            if cache is None:
                cache = x.new_zeros(N, D, W)
            # Write the last `min(W, T)` raw inputs into the tail of the cache.
            cache[:, :, -min(W, T):].copy_(new_cache)

        return x, cache

    def _forward_naive(self, x: torch.Tensor, generator_input: Optional[torch.Tensor] = None) -> torch.Tensor:
        W = self.kernel_size
        generator_input = x if generator_input is None else generator_input
        kernels = self.get_kernel(generator_input)  # [B, T, D, W]
        # Left-pad so the convolution is causal, then gather sliding windows.
        x = F.pad(x.transpose(1, 2), (W - 1, 0))  # [B, D, T + W - 1]
        x = x.unfold(dimension=2, size=W, step=1)  # [B, D, T, W]
        x = x.permute(0, 2, 1, 3)  # [B, T, D, W]
        # Per-position, per-channel dot product with the generated kernels.
        x = (x * kernels).sum(dim=-1)  # [B, T, D]
        return x

    def _forward_triton(self, x: torch.Tensor, generator_input: Optional[torch.Tensor] = None) -> torch.Tensor:
        generator_input = x if generator_input is None else generator_input
        kernels = self.get_kernel(generator_input)
        output_triton = dynamic_conv_triton_autograd(x, kernels)
        return output_triton

    @torch.no_grad()
    def _forward_triton_cache(self, x: torch.Tensor, generator_input: Optional[torch.Tensor] = None, cache: Optional[torch.Tensor] = None) -> torch.Tensor:
        generator_input = x if generator_input is None else generator_input
        assert not self.training, "Triton implementation is only available in eval mode."

        # Process the sequence in fixed-size chunks, threading the last `W`
        # inputs of each chunk into the next one as its cache.
        CHUNK_SIZE = 2048
        n_chunk = (x.shape[1] + CHUNK_SIZE - 1) // CHUNK_SIZE
        output_triton = torch.zeros_like(x)
        if cache is not None:
            # [B, D, W] -> [B, W, D] to match the time-major chunk slices below.
            cache = rearrange(cache, "b d t -> b t d")
        for i in range(n_chunk):
            start = i * CHUNK_SIZE
            end = min((i + 1) * CHUNK_SIZE, x.shape[1])
            kernels = self.get_kernel(generator_input[:, start:end])
            out = dynamic_conv_triton_cache(x[:, start:end], kernels, cache=cache)
            output_triton[:, start:end, :] = out
            cache = x[:, end - self.kernel_size:end, :]
        return output_triton

    def _step_naive(
        self,
        x: torch.Tensor,
        cache: torch.Tensor,
        cu_seqlens: Optional[torch.LongTensor] = None,
        generator_input: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert x.shape[1] == 1, "x must be of shape [B, 1, D]"
        shape = x.shape
        generator_input = x if generator_input is None else generator_input
        x = x.squeeze(1)
        generator_input = generator_input.squeeze(1)
        B, D, W = *x.shape, self.kernel_size

        # Shift the window left by one step and append the new token.
        cache.copy_(cache.roll(shifts=-1, dims=-1))
        cache[:, :, -1] = x

        # One output per channel: dot the window with its generated kernel.
        kernels = self.get_kernel(generator_input)  # [B, D, W]
        x = torch.sum(cache * kernels, dim=-1)

        if self.activation is not None:
            x = ACT2FN[self.activation](x)

        return x.view(shape), cache

    def _step_triton(
        self,
        x: torch.Tensor,
        cache: torch.Tensor,
        cu_seqlens: Optional[torch.LongTensor] = None,
        generator_input: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert x.shape[1] == 1, "x must be of shape [B, 1, D]"
        shape = x.shape
        generator_input = x if generator_input is None else generator_input

        kernels_triton = self.get_kernel(generator_input.squeeze(1))  # [B, D, W]

        # Fused single-token step; `cache` is updated in place
        # (see `_step_naive` for the reference semantics).
        x_out_triton = causal_conv_step_triton(
            x,
            cache,
            kernels_triton,
        )

        if self.activation is not None:
            x_out_triton = ACT2FN[self.activation](x_out_triton)

        return x_out_triton.view(shape), cache
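

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the module API): exercises the
    # pure-PyTorch "naive" path with a prefill followed by one cached decode
    # step. Assumes the triton helper modules imported above are importable.
    torch.manual_seed(0)
    conv = DynamicShortConvolution(hidden_size=32, kernel_size=4, implementation="naive")

    # Prefill: convolve a full sequence and request the final cache state.
    x = torch.randn(1, 8, 32)
    y, cache = conv(x, output_final_state=True)
    print(y.shape, cache.shape)  # torch.Size([1, 8, 32]) torch.Size([1, 32, 4])

    # Decode: feed one new token, reusing (and updating) the cache in place.
    x_t = torch.randn(1, 1, 32)
    y_t, cache = conv(x_t, cache=cache)
    print(y_t.shape)  # torch.Size([1, 1, 32])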