# Jet-Nemotron-2B / dconv_fwd_cache.py
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import triton
import triton.language as tl
from torch.autograd import Function


# Helper function to ensure tensors are contiguous for Triton.
def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
return t if t.is_contiguous() else t.contiguous()


# --- Forward Kernel (modified for optional cache) ---
@triton.jit
def _dynamic_conv_fwd_kernel(
X_ptr, K_ptr, Out_ptr,
Cache_ptr, # New: Pointer to cache tensor
B, T, D, T_CACHE: tl.constexpr, # New: T is shape of x, T_CACHE is shape of cache
X_stride_b, X_stride_t, X_stride_d,
K_stride_b, K_stride_t, K_stride_d, K_stride_w,
Out_stride_b, Out_stride_t, Out_stride_d,
Cache_stride_b, Cache_stride_t, Cache_stride_d, # New: Strides for cache tensor
W: tl.constexpr,
BLOCK_SIZE_D: tl.constexpr,
):
pid_batch_time = tl.program_id(0) # Covers B * T_out
pid_d_block = tl.program_id(1)
# T here is the time dimension of x and Out
batch_idx = tl.cast(pid_batch_time // T, tl.int64)
time_idx = pid_batch_time % T # Current output time step for x (0 to T-1)
offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
d_mask = offs_d < D
accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
offs_w = tl.arange(0, W) # Kernel window offsets [0, 1, ..., W-1]
# Load Kernels (kernels are aligned with x's T dimension)
# K_ptr is indexed by time_idx which is the output time relative to x's start
k_ptrs = K_ptr + (batch_idx * K_stride_b + time_idx * K_stride_t +
offs_d[:, None] * K_stride_d + offs_w[None, :] * K_stride_w)
k_vals = tl.load(k_ptrs, mask=d_mask[:, None], other=0.0) # Shape: [BLOCK_SIZE_D, W]
# --- Load Input from conceptual [Cache, X] tensor ---
# `time_idx` is the current output time step (0 to T-1, where T is x.shape[1])
# `offs_w` is [0, ..., W-1]
# Convolution input time indices relative to the *start of x*:
# e.g., for W=3, offs_w - W + 1 gives [-2, -1, 0]
# so input_time_indices_rel_to_x_start are [time_idx-2, time_idx-1, time_idx]
input_time_indices_rel_to_x_start = time_idx + offs_w - W + 1 # Shape: [W]
# Effective input time indices in the conceptual [Cache, X] sequence:
# These indices range from 0 (start of cache) to T_CACHE + T - 1 (end of x)
eff_t_indices = input_time_indices_rel_to_x_start + T_CACHE # Shape: [W]
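    # Worked example with hypothetical sizes: for W=3, T_CACHE=4, time_idx=0 we get
    # input_time_indices_rel_to_x_start = [-2, -1, 0] and eff_t_indices = [2, 3, 4],
    # i.e. the first two taps read the tail of the cache and the last tap reads x[:, 0].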
# Overall mask for valid time indices within the conceptual [Cache, X] tensor
# Total effective length is T_CACHE (for cache) + T (for x)
eff_t_valid_mask = (eff_t_indices >= 0) & (eff_t_indices < (T_CACHE + T)) # Shape: [W]
# --- Load from Cache ---
# Condition for loading from cache: index is valid AND index < T_CACHE
# (eff_t_indices are 0-indexed from the start of the cache)
cache_load_time_mask = eff_t_valid_mask & (eff_t_indices < T_CACHE) # Shape: [W]
cache_ptr_indices = eff_t_indices # Use directly if in cache range
cache_ptrs = Cache_ptr + (batch_idx * Cache_stride_b +
cache_ptr_indices[None, :] * Cache_stride_t +
offs_d[:, None] * Cache_stride_d)
cache_final_load_mask = d_mask[:, None] & cache_load_time_mask[None, :] # Shape: [BLOCK_SIZE_D, W]
vals_from_cache = tl.load(cache_ptrs, mask=cache_final_load_mask, other=0.0) # Shape: [BLOCK_SIZE_D, W]
# --- Load from X ---
# Condition for loading from X: index is valid AND index >= T_CACHE
x_load_time_mask = eff_t_valid_mask & (eff_t_indices >= T_CACHE) # Shape: [W]
# Adjust indices for X_ptr: X_ptr expects indices from 0 to T-1 (relative to start of x)
x_ptr_indices = eff_t_indices - T_CACHE # Shape: [W]
x_ptrs = X_ptr + (batch_idx * X_stride_b +
x_ptr_indices[None, :] * X_stride_t +
offs_d[:, None] * X_stride_d)
x_final_load_mask = d_mask[:, None] & x_load_time_mask[None, :] # Shape: [BLOCK_SIZE_D, W]
vals_from_x = tl.load(x_ptrs, mask=x_final_load_mask, other=0.0) # Shape: [BLOCK_SIZE_D, W]
# Combine values. Masks ensure only one source contributes non-zero per element.
# If T_CACHE == 0, cache_load_time_mask is all False, so vals_from_cache is 0.0.
x_input_vals = vals_from_cache + vals_from_x # Shape: [BLOCK_SIZE_D, W]
# Compute and Accumulate
product = k_vals * x_input_vals # Element-wise product
accumulator += tl.sum(product, axis=1) # Sum over W dimension
# Store Result
out_ptrs = Out_ptr + (batch_idx * Out_stride_b + time_idx * Out_stride_t +
offs_d * Out_stride_d)
tl.store(out_ptrs, accumulator, mask=d_mask)


# --- Backward Kernel for Input Gradient (dX) ---
@triton.jit
def _dynamic_conv_bwd_dx_kernel(
    GradOut_ptr, K_ptr, GradX_ptr,  # Note: GradX is written directly; accumulation happens in registers
B, T, D,
GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
K_stride_b, K_stride_t, K_stride_d, K_stride_w,
GradX_stride_b, GradX_stride_t, GradX_stride_d,
W: tl.constexpr,
BLOCK_SIZE_D: tl.constexpr,
):
"""
Computes gradient w.r.t. input X.
Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering GradX output
GradX[b, t_x, d] = sum_{w=0}^{W-1} GradOut[b, t, d] * K[b, t, d, w]
where t = t_x + W - 1 - w
"""
pid_batch_time_x = tl.program_id(0) # Covers B * T for output GradX
pid_d_block = tl.program_id(1)
    batch_idx = tl.cast(pid_batch_time_x // T, tl.int64)  # int64 to avoid overflow in pointer arithmetic
time_idx_x = pid_batch_time_x % T # This is t_x
offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
d_mask = offs_d < D
# Accumulator for GradX elements
accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
offs_w = tl.arange(0, W) # [W]
# Loop over W to accumulate contributions
# Calculate the 't' index needed for GradOut and K based on t_x and w
# t = t_x + W - 1 - w
t_k_gradout_offs = time_idx_x + W - 1 - offs_w # Shape [W]
# Mask for valid 't' indices [0, T)
t_k_gradout_mask = (t_k_gradout_offs >= 0) & (t_k_gradout_offs < T) # Shape [W]
# --- Load GradOut ---
# Pointers shape: [BLOCK_SIZE_D, W]
gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
t_k_gradout_offs[None, :] * GradOut_stride_t +
offs_d[:, None] * GradOut_stride_d)
# Combined mask for loading GradOut (valid D and valid t)
gradout_load_mask = d_mask[:, None] & t_k_gradout_mask[None, :]
# Shape: [BLOCK_SIZE_D, W]
gradout_vals = tl.load(gradout_ptrs, mask=gradout_load_mask, other=0.0)
# --- Load Kernels ---
# Pointers shape: [BLOCK_SIZE_D, W]
k_ptrs = K_ptr + (batch_idx * K_stride_b +
t_k_gradout_offs[None, :] * K_stride_t +
offs_d[:, None] * K_stride_d +
offs_w[None, :] * K_stride_w) # Index K with 't' and 'w'
# Combined mask for loading K (valid D and valid t)
k_load_mask = d_mask[:, None] & t_k_gradout_mask[None, :]
# Shape: [BLOCK_SIZE_D, W]
k_vals = tl.load(k_ptrs, mask=k_load_mask, other=0.0)
# --- Compute product and accumulate ---
# Shape: [BLOCK_SIZE_D, W]
product = gradout_vals * k_vals
# Sum contributions over the W dimension
accumulator += tl.sum(product, axis=1) # Shape: [BLOCK_SIZE_D]
# --- Store accumulated gradients ---
gradx_ptrs = GradX_ptr + (batch_idx * GradX_stride_b +
time_idx_x * GradX_stride_t +
offs_d * GradX_stride_d)
tl.store(gradx_ptrs, accumulator, mask=d_mask)


# --- Backward Kernel for Kernel Gradient (dK) ---
@triton.jit
def _dynamic_conv_bwd_dk_kernel(
GradOut_ptr, X_ptr, GradK_ptr, # Note: GradK is written directly
B, T, D,
GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
X_stride_b, X_stride_t, X_stride_d,
GradK_stride_b, GradK_stride_t, GradK_stride_d, GradK_stride_w,
W: tl.constexpr,
BLOCK_SIZE_D: tl.constexpr,
):
"""
Computes gradient w.r.t. kernels K.
Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering GradK output dims B, T, D
GradK[b, t, d, w] = GradOut[b, t, d] * X[b, t + w - W + 1, d]
"""
pid_batch_time = tl.program_id(0) # Covers B * T for output GradK
pid_d_block = tl.program_id(1)
    batch_idx = tl.cast(pid_batch_time // T, tl.int64)  # int64 to avoid overflow in pointer arithmetic
time_idx = pid_batch_time % T # This is 't' for GradK and GradOut
offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
d_mask = offs_d < D
offs_w = tl.arange(0, W) # [W]
# --- Load GradOut ---
# Pointers shape: [BLOCK_SIZE_D] (only depends on b, t, d)
gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
time_idx * GradOut_stride_t +
offs_d * GradOut_stride_d)
# Shape: [BLOCK_SIZE_D]
gradout_vals = tl.load(gradout_ptrs, mask=d_mask, other=0.0)
# --- Load Input X with implicit padding ---
# Calculate X's time index: t_x = t + w - W + 1
t_in_offs = time_idx + offs_w - W + 1 # Shape [W]
# Mask for valid t_x index [0, T)
t_in_mask = (t_in_offs >= 0) & (t_in_offs < T) # Shape [W]
# Pointers shape: [BLOCK_SIZE_D, W]
x_ptrs = X_ptr + (batch_idx * X_stride_b +
t_in_offs[None, :] * X_stride_t +
offs_d[:, None] * X_stride_d)
# Combined mask for loading X (valid D and valid t_x)
x_load_mask = d_mask[:, None] & t_in_mask[None, :] # Shape [BLOCK_SIZE_D, W]
# Shape: [BLOCK_SIZE_D, W]
x_vals = tl.load(x_ptrs, mask=x_load_mask, other=0.0)
# --- Compute GradK = GradOut * X ---
# Broadcast gradout_vals: [BLOCK_SIZE_D, 1] * [BLOCK_SIZE_D, W] -> [BLOCK_SIZE_D, W]
gradk_vals = gradout_vals[:, None] * x_vals # Shape [BLOCK_SIZE_D, W]
# --- Store gradients for Kernels ---
# Pointers shape: [BLOCK_SIZE_D, W]
gradk_ptrs = GradK_ptr + (batch_idx * GradK_stride_b +
time_idx * GradK_stride_t +
offs_d[:, None] * GradK_stride_d +
offs_w[None, :] * GradK_stride_w)
# Mask only needed for D dimension (W is fully computed)
# Store computed gradient values.
tl.store(gradk_ptrs, gradk_vals, mask=d_mask[:, None])


# --- Autograd Function ---
class DynamicConvTritonFunc(Function):
@staticmethod
def forward(ctx, x, kernels, cache=None): # Added cache argument
"""
Args:
x: Input tensor [B, T, D]
kernels: Kernels tensor [B, T, D, W]
cache: Optional past context tensor [B, T_cache, D]
"""
x = ensure_contiguous(x)
kernels = ensure_contiguous(kernels)
B, T, D = x.shape # T is the time dimension of the current input x
_B_k, _T_k, _D_k, W = kernels.shape # Kernels are [B, T_x, D, W]
assert B == _B_k and T == _T_k and D == _D_k, \
f"Shape mismatch between x ({x.shape}) and kernels ({kernels.shape}) on B, T, or D dims"
        assert W <= 4, f"Kernel width W={W} > 4 is not supported by this version"
out = torch.empty_like(x) # Output shape [B, T, D], corresponds to x
T_cache_val = 0
        # Use x's data pointer and zero strides as placeholders if cache is None.
        # The kernel never reads them when T_CACHE is 0, because the cache load mask is all False.
cache_ptr_val = x
cache_s_b, cache_s_t, cache_s_d = 0, 0, 0
if cache is not None:
cache = ensure_contiguous(cache)
B_c, T_c, D_c = cache.shape
assert B_c == B, f"Batch size mismatch: x ({B}) vs cache ({B_c})"
assert D_c == D, f"Dimension mismatch: x ({D}) vs cache ({D_c})"
T_cache_val = T_c
cache_ptr_val = cache
cache_s_b, cache_s_t, cache_s_d = cache.stride(0), cache.stride(1), cache.stride(2)
grid = lambda meta: (B * T, triton.cdiv(D, meta['BLOCK_SIZE_D']))
BLOCK_SIZE_D = 128 # Consider tuning
_dynamic_conv_fwd_kernel[grid](
x, kernels, out, # X, K, Out pointers
cache_ptr_val, # Cache pointer
B, T, D, T_cache_val, # Shapes: B, T_x, D, T_cache
x.stride(0), x.stride(1), x.stride(2), # X strides
kernels.stride(0), kernels.stride(1), kernels.stride(2), kernels.stride(3), # K strides
out.stride(0), out.stride(1), out.stride(2), # Out strides
cache_s_b, cache_s_t, cache_s_d, # Cache strides
W=W,
BLOCK_SIZE_D=BLOCK_SIZE_D,
)
# Save tensors needed for backward (cache is not needed for current backward)
ctx.save_for_backward(x, kernels)
ctx.W = W
ctx.BLOCK_SIZE_D = BLOCK_SIZE_D
# ctx.T_cache = T_cache_val # Not needed for current backward
return out

    @staticmethod
    def backward(ctx, grad_out):
        raise NotImplementedError("Backward is not implemented for the cached dynamic convolution forward")


# --- User-facing function ---
def dynamic_conv_triton_cache(x: torch.Tensor, kernels: torch.Tensor, cache: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Fused dynamic convolution using Triton kernels (forward pass only; backward is not implemented).
    Assumes W <= 4.
Args:
x: Input tensor of shape [B, T, D].
kernels: Dynamic kernels of shape [B, T, D, W].
cache: Optional past context tensor of shape [B, T_cache, D].
If provided, treated as concatenated before x for convolution input.
Returns:
Output tensor of shape [B, T, D].
"""
return DynamicConvTritonFunc.apply(x, kernels, cache)
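

# A minimal, un-fused PyTorch sketch of what the forward kernel computes, handy for
# sanity-checking the Triton path. `dynamic_conv_reference` and the shapes used in the
# `__main__` block below are illustrative additions (not part of the original module)
# and assume the same causal, per-timestep-kernel semantics implemented above.
def dynamic_conv_reference(x: torch.Tensor, kernels: torch.Tensor,
                           cache: Optional[torch.Tensor] = None) -> torch.Tensor:
    B, T, D = x.shape
    W = kernels.shape[-1]
    # Left context: zeros when there is no cache, otherwise the last W-1 cached steps.
    pad = x.new_zeros(B, W - 1, D)
    if cache is not None and W > 1:
        n = min(W - 1, cache.shape[1])
        pad[:, W - 1 - n:, :] = cache[:, cache.shape[1] - n:, :]
    xp = torch.cat([pad, x], dim=1)     # [B, T + W - 1, D]
    windows = xp.unfold(1, W, 1)        # [B, T, D, W]; windows[b, t, d, w] = xp[b, t + w, d]
    # Out[b, t, d] = sum_w kernels[b, t, d, w] * xp[b, t + w, d]
    return (windows * kernels).sum(-1)


if __name__ == "__main__":
    # Smoke test with small, arbitrary shapes (the Triton path requires a CUDA device).
    if torch.cuda.is_available():
        B, T, T_cache, D, W = 2, 16, 8, 64, 4
        x = torch.randn(B, T, D, device="cuda", dtype=torch.float32)
        kernels = torch.randn(B, T, D, W, device="cuda", dtype=torch.float32)
        cache = torch.randn(B, T_cache, D, device="cuda", dtype=torch.float32)
        out_triton = dynamic_conv_triton_cache(x, kernels, cache)
        out_ref = dynamic_conv_reference(x, kernels, cache)
        print("max abs diff:", (out_triton - out_ref).abs().max().item())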