import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from torch.autograd import Function


def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
    return t if t.is_contiguous() else t.contiguous()


@triton.jit
def _dynamic_conv_fwd_kernel(
    X_ptr, K_ptr, Out_ptr,
    B, T, D,
    X_stride_b, X_stride_t, X_stride_d,
    K_stride_b, K_stride_t, K_stride_d, K_stride_w,
    Out_stride_b, Out_stride_t, Out_stride_d,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Computes the causal dynamic convolution.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - one program per (batch, time) pair and D-block.
    Out[b, t, d] = sum_{w=0}^{W-1} K[b, t, d, w] * X[b, t + w - W + 1, d]
    """
    pid_batch_time = tl.program_id(0)
    pid_d_block = tl.program_id(1)

    # Cast to int64 so pointer arithmetic does not overflow for large B * T * D.
    batch_idx = tl.cast(pid_batch_time // T, tl.int64)
    time_idx = pid_batch_time % T

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D

    accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    offs_w = tl.arange(0, W)

    # Load the per-position kernel taps: shape (BLOCK_SIZE_D, W).
    k_ptrs = K_ptr + (batch_idx * K_stride_b + time_idx * K_stride_t +
                      offs_d[:, None] * K_stride_d + offs_w[None, :] * K_stride_w)
    k_vals = tl.load(k_ptrs, mask=d_mask[:, None], other=0.0)

    # Gather the causal input window X[b, t - W + 1 : t + 1, d]; out-of-range
    # time positions are masked to zero.
    t_in_offs = time_idx + offs_w - W + 1
    t_in_mask = (t_in_offs >= 0) & (t_in_offs < T)
    x_ptrs = X_ptr + (batch_idx * X_stride_b + t_in_offs[None, :] * X_stride_t +
                      offs_d[:, None] * X_stride_d)
    x_load_mask = d_mask[:, None] & t_in_mask[None, :]
    x_vals = tl.load(x_ptrs, mask=x_load_mask, other=0.0)

    # Multiply-accumulate over the window dimension.
    product = k_vals * x_vals
    accumulator += tl.sum(product, axis=1)

    out_ptrs = Out_ptr + (batch_idx * Out_stride_b + time_idx * Out_stride_t +
                          offs_d * Out_stride_d)
    tl.store(out_ptrs, accumulator, mask=d_mask)


@triton.jit
def _dynamic_conv_bwd_dx_kernel(
    GradOut_ptr, K_ptr, GradX_ptr,
    B, T, D,
    GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
    K_stride_b, K_stride_t, K_stride_d, K_stride_w,
    GradX_stride_b, GradX_stride_t, GradX_stride_d,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Computes the gradient w.r.t. the input X.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering the GradX output.
    GradX[b, t_x, d] = sum_{w=0}^{W-1} GradOut[b, t, d] * K[b, t, d, w]
    where t = t_x + W - 1 - w
    """
    pid_batch_time_x = tl.program_id(0)
    pid_d_block = tl.program_id(1)

    # Cast to int64 (matching the forward kernel) so pointer arithmetic does not
    # overflow for large B * T * D.
    batch_idx = tl.cast(pid_batch_time_x // T, tl.int64)
    time_idx_x = pid_batch_time_x % T

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D

    accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    offs_w = tl.arange(0, W)

    # Output positions t that read X[b, t_x, d] through tap w: t = t_x + W - 1 - w.
    t_k_gradout_offs = time_idx_x + W - 1 - offs_w
    t_k_gradout_mask = (t_k_gradout_offs >= 0) & (t_k_gradout_offs < T)

    # Load GradOut[b, t, d] for each valid t: shape (BLOCK_SIZE_D, W).
    gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
                                  t_k_gradout_offs[None, :] * GradOut_stride_t +
                                  offs_d[:, None] * GradOut_stride_d)
    gradout_load_mask = d_mask[:, None] & t_k_gradout_mask[None, :]
    gradout_vals = tl.load(gradout_ptrs, mask=gradout_load_mask, other=0.0)

    # Load the matching kernel taps K[b, t, d, w].
    k_ptrs = K_ptr + (batch_idx * K_stride_b +
                      t_k_gradout_offs[None, :] * K_stride_t +
                      offs_d[:, None] * K_stride_d +
                      offs_w[None, :] * K_stride_w)
    k_load_mask = d_mask[:, None] & t_k_gradout_mask[None, :]
    k_vals = tl.load(k_ptrs, mask=k_load_mask, other=0.0)

    # Accumulate the contributions over the window dimension.
    product = gradout_vals * k_vals
    accumulator += tl.sum(product, axis=1)

    gradx_ptrs = GradX_ptr + (batch_idx * GradX_stride_b +
                              time_idx_x * GradX_stride_t +
                              offs_d * GradX_stride_d)
    tl.store(gradx_ptrs, accumulator, mask=d_mask)


@triton.jit
def _dynamic_conv_bwd_dk_kernel(
    GradOut_ptr, X_ptr, GradK_ptr,
    B, T, D,
    GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
    X_stride_b, X_stride_t, X_stride_d,
    GradK_stride_b, GradK_stride_t, GradK_stride_d, GradK_stride_w,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Computes the gradient w.r.t. the kernels K.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering the GradK output dims B, T, D.
    GradK[b, t, d, w] = GradOut[b, t, d] * X[b, t + w - W + 1, d]
    """
    pid_batch_time = tl.program_id(0)
    pid_d_block = tl.program_id(1)

    # Cast to int64 (matching the forward kernel) so pointer arithmetic does not
    # overflow for large B * T * D.
    batch_idx = tl.cast(pid_batch_time // T, tl.int64)
    time_idx = pid_batch_time % T

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D

    offs_w = tl.arange(0, W)

    # Load GradOut[b, t, d] once per channel block: shape (BLOCK_SIZE_D,).
    gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
                                  time_idx * GradOut_stride_t +
                                  offs_d * GradOut_stride_d)
    gradout_vals = tl.load(gradout_ptrs, mask=d_mask, other=0.0)

    # Gather the causal input window X[b, t + w - W + 1, d]; out-of-range
    # positions contribute zero gradient.
    t_in_offs = time_idx + offs_w - W + 1
    t_in_mask = (t_in_offs >= 0) & (t_in_offs < T)
    x_ptrs = X_ptr + (batch_idx * X_stride_b +
                      t_in_offs[None, :] * X_stride_t +
                      offs_d[:, None] * X_stride_d)
    x_load_mask = d_mask[:, None] & t_in_mask[None, :]
    x_vals = tl.load(x_ptrs, mask=x_load_mask, other=0.0)

    # Outer product over (d, w): one gradient entry per kernel tap.
    gradk_vals = gradout_vals[:, None] * x_vals

    gradk_ptrs = GradK_ptr + (batch_idx * GradK_stride_b +
                              time_idx * GradK_stride_t +
                              offs_d[:, None] * GradK_stride_d +
                              offs_w[None, :] * GradK_stride_w)
    tl.store(gradk_ptrs, gradk_vals, mask=d_mask[:, None])


class DynamicConvTritonFunc(Function):

    @staticmethod
    def forward(ctx, x, kernels):
        """
        Args:
            x: Input tensor [B, T, D]
            kernels: Kernels tensor [B, T, D, W]
        """
        x = ensure_contiguous(x)
        kernels = ensure_contiguous(kernels)

        B, T, D = x.shape
        W = kernels.shape[3]
        # W is a compile-time constant; tl.arange(0, W) additionally requires W to be a power of two.
        assert W <= 4, f"Kernel width W={W} > 4 is not supported by this version"

        out = torch.empty_like(x)

        grid = lambda meta: (B * T, triton.cdiv(D, meta['BLOCK_SIZE_D']))
        BLOCK_SIZE_D = 128

        _dynamic_conv_fwd_kernel[grid](
            x, kernels, out,
            B, T, D,
            x.stride(0), x.stride(1), x.stride(2),
            kernels.stride(0), kernels.stride(1), kernels.stride(2), kernels.stride(3),
            out.stride(0), out.stride(1), out.stride(2),
            W=W,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )

        ctx.save_for_backward(x, kernels)
        ctx.W = W
        ctx.BLOCK_SIZE_D = BLOCK_SIZE_D

        return out

    @staticmethod
    def backward(ctx, grad_out):
        """
        Args:
            grad_out: Gradient w.r.t. the output tensor [B, T, D]
        Returns:
            grad_x: Gradient w.r.t. input x [B, T, D]
            grad_kernels: Gradient w.r.t. kernels [B, T, D, W]
        """
        grad_out = ensure_contiguous(grad_out)
        x, kernels = ctx.saved_tensors
        W = ctx.W
        BLOCK_SIZE_D = ctx.BLOCK_SIZE_D

        B, T, D = x.shape

        grad_x = torch.zeros_like(x)
        grad_kernels = torch.empty_like(kernels)

        grid = lambda meta: (B * T, triton.cdiv(D, meta['BLOCK_SIZE_D']))

        _dynamic_conv_bwd_dx_kernel[grid](
            grad_out, kernels, grad_x,
            B, T, D,
            grad_out.stride(0), grad_out.stride(1), grad_out.stride(2),
            kernels.stride(0), kernels.stride(1), kernels.stride(2), kernels.stride(3),
            grad_x.stride(0), grad_x.stride(1), grad_x.stride(2),
            W=W,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )

        _dynamic_conv_bwd_dk_kernel[grid](
            grad_out, x, grad_kernels,
            B, T, D,
            grad_out.stride(0), grad_out.stride(1), grad_out.stride(2),
            x.stride(0), x.stride(1), x.stride(2),
            grad_kernels.stride(0), grad_kernels.stride(1), grad_kernels.stride(2), grad_kernels.stride(3),
            W=W,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )

        return grad_x, grad_kernels


def dynamic_conv_triton_autograd(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    """
    Fused causal dynamic convolution with autograd support using Triton kernels.
    Assumes W <= 4.

    Args:
        x: Input tensor of shape [B, T, D].
        kernels: Dynamic kernels of shape [B, T, D, W].

    Returns:
        Output tensor of shape [B, T, D].
    """
    return DynamicConvTritonFunc.apply(x, kernels)
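

# The helpers below are not part of the original module; they are a minimal
# pure-PyTorch sketch of the same causal dynamic convolution,
# out[b, t, d] = sum_w kernels[b, t, d, w] * x[b, t + w - W + 1, d],
# useful for sanity-checking the Triton path. The names `_dynamic_conv_reference`
# and `_check_against_reference` are illustrative (assumed, not part of the API),
# and the check assumes a CUDA device is available for the Triton kernels.
def _dynamic_conv_reference(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch reference for the fused Triton dynamic convolution."""
    B, T, D = x.shape
    W = kernels.shape[-1]
    # Left-pad the time dimension so every output position sees a full causal window.
    x_pad = F.pad(x, (0, 0, W - 1, 0))   # [B, T + W - 1, D]
    windows = x_pad.unfold(1, W, 1)      # [B, T, D, W]; windows[b, t, d, w] = x[b, t + w - W + 1, d]
    return (windows * kernels).sum(dim=-1)


def _check_against_reference(B: int = 2, T: int = 16, D: int = 64, W: int = 4) -> None:
    """Compare the Triton forward/backward against the PyTorch reference on random data."""
    x = torch.randn(B, T, D, device="cuda", requires_grad=True)
    k = torch.randn(B, T, D, W, device="cuda", requires_grad=True)

    out_triton = dynamic_conv_triton_autograd(x, k)
    out_ref = _dynamic_conv_reference(x, k)
    torch.testing.assert_close(out_triton, out_ref, rtol=1e-4, atol=1e-4)

    grad = torch.randn_like(out_ref)
    gx_triton, gk_triton = torch.autograd.grad(out_triton, (x, k), grad)
    gx_ref, gk_ref = torch.autograd.grad(out_ref, (x, k), grad)
    torch.testing.assert_close(gx_triton, gx_ref, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(gk_triton, gk_ref, rtol=1e-4, atol=1e-4)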