Jet-Nemotron-2B / dconv_fwdbwd.py
t1101675's picture
Upload folder using huggingface_hub
5e94a88 verified
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from einops import rearrange
import torch.nn.functional as F
from torch.autograd import Function
# Helper function to ensure tensors are contiguous for Triton
def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
    """Return ``t`` unchanged when it is already contiguous, else a contiguous copy.

    The Triton kernels in this file are launched on contiguous tensors; this
    helper avoids a copy whenever the input already satisfies that.
    """
    if t.is_contiguous():
        return t
    return t.contiguous()
# --- Forward Kernel: Out[b, t, d] = sum_w K[b, t, d, w] * X[b, t + w - W + 1, d] ---
@triton.jit
def _dynamic_conv_fwd_kernel(
    X_ptr, K_ptr, Out_ptr,
    B, T, D,
    X_stride_b, X_stride_t, X_stride_d,
    K_stride_b, K_stride_t, K_stride_d, K_stride_w,
    Out_stride_b, Out_stride_t, Out_stride_d,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Forward pass of the dynamic convolution.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - one program per (b, t) pair and D block
    Out[b, t, d] = sum_{w=0}^{W-1} K[b, t, d, w] * X[b, t + w - W + 1, d]
    Input time indices outside [0, T) are masked and contribute zero.
    """
    # Grid axis 0 enumerates flattened (batch, time) pairs; axis 1 the D blocks.
    flat_bt = tl.program_id(0)
    d_block = tl.program_id(1)
    # The batch index is widened to int64 before it is multiplied by a batch stride.
    b_idx = tl.cast(flat_bt // T, tl.int64)
    t_idx = flat_bt % T

    d_idx = d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_ok = d_idx < D  # guards the tail block when D is not a multiple of BLOCK_SIZE_D
    w_idx = tl.arange(0, W)

    # Products are summed into a float32 accumulator.
    acc = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)

    # Kernel tile K[b, t, d_block, :] -> [BLOCK_SIZE_D, W]
    k_offsets = (b_idx * K_stride_b + t_idx * K_stride_t +
                 d_idx[:, None] * K_stride_d + w_idx[None, :] * K_stride_w)
    k_tile = tl.load(K_ptr + k_offsets, mask=d_ok[:, None], other=0.0)

    # Input tile X[b, t + w - W + 1, d_block] -> [BLOCK_SIZE_D, W];
    # out-of-range time steps are masked and read as 0.0 (implicit padding).
    src_t = t_idx + w_idx - W + 1
    src_t_ok = (src_t >= 0) & (src_t < T)
    x_offsets = (b_idx * X_stride_b + src_t[None, :] * X_stride_t +
                 d_idx[:, None] * X_stride_d)
    x_ok = d_ok[:, None] & src_t_ok[None, :]
    x_tile = tl.load(X_ptr + x_offsets, mask=x_ok, other=0.0)

    # Reduce over the window dimension.
    acc += tl.sum(k_tile * x_tile, axis=1)

    # Write Out[b, t, d_block].
    out_offsets = (b_idx * Out_stride_b + t_idx * Out_stride_t +
                   d_idx * Out_stride_d)
    tl.store(Out_ptr + out_offsets, acc, mask=d_ok)
# --- Backward Kernel for Input Gradient (dX) ---
@triton.jit
def _dynamic_conv_bwd_dx_kernel(
    GradOut_ptr, K_ptr, GradX_ptr,  # Note: GradX is written directly (no accumulation across programs)
    B, T, D,
    GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
    K_stride_b, K_stride_t, K_stride_d, K_stride_w,
    GradX_stride_b, GradX_stride_t, GradX_stride_d,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Computes gradient w.r.t. input X.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering GradX output
    GradX[b, t_x, d] = sum_{w=0}^{W-1} GradOut[b, t, d] * K[b, t, d, w]
    where t = t_x + W - 1 - w
    Terms whose t falls outside [0, T) are masked and contribute zero.
    """
    pid_batch_time_x = tl.program_id(0)  # Covers B * T for output GradX
    pid_d_block = tl.program_id(1)

    # Widen the batch index to int64 (as the forward kernel does) so that
    # batch_idx * stride_b cannot overflow 32-bit arithmetic on large tensors.
    batch_idx = tl.cast(pid_batch_time_x // T, tl.int64)
    time_idx_x = pid_batch_time_x % T  # This is t_x

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D

    # float32 accumulator for the GradX elements of this program
    accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32)
    offs_w = tl.arange(0, W)  # [W]

    # All W taps are handled at once as a [BLOCK_SIZE_D, W] tile.
    # 't' index needed for GradOut and K: t = t_x + W - 1 - w
    t_k_gradout_offs = time_idx_x + W - 1 - offs_w  # Shape [W]
    # Mask for valid 't' indices [0, T)
    t_k_gradout_mask = (t_k_gradout_offs >= 0) & (t_k_gradout_offs < T)  # Shape [W]

    # Valid D and valid t; shared by the GradOut and K loads below.
    load_mask = d_mask[:, None] & t_k_gradout_mask[None, :]

    # --- Load GradOut[b, t, d] --- Shape: [BLOCK_SIZE_D, W]
    gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
                                  t_k_gradout_offs[None, :] * GradOut_stride_t +
                                  offs_d[:, None] * GradOut_stride_d)
    gradout_vals = tl.load(gradout_ptrs, mask=load_mask, other=0.0)

    # --- Load K[b, t, d, w] (indexed with both 't' and 'w') --- Shape: [BLOCK_SIZE_D, W]
    k_ptrs = K_ptr + (batch_idx * K_stride_b +
                      t_k_gradout_offs[None, :] * K_stride_t +
                      offs_d[:, None] * K_stride_d +
                      offs_w[None, :] * K_stride_w)
    k_vals = tl.load(k_ptrs, mask=load_mask, other=0.0)

    # --- Compute product and sum contributions over the W dimension ---
    product = gradout_vals * k_vals  # Shape: [BLOCK_SIZE_D, W]
    accumulator += tl.sum(product, axis=1)  # Shape: [BLOCK_SIZE_D]

    # --- Store gradients ---
    # Each program computes the *entire* value of GradX[b, t_x, d_block], and
    # every (b, t_x, d) is owned by exactly one program of the grid, so a
    # plain store is sufficient (no atomics needed).
    gradx_ptrs = GradX_ptr + (batch_idx * GradX_stride_b +
                              time_idx_x * GradX_stride_t +
                              offs_d * GradX_stride_d)
    tl.store(gradx_ptrs, accumulator, mask=d_mask)
# --- Backward Kernel for Kernel Gradient (dK) ---
@triton.jit
def _dynamic_conv_bwd_dk_kernel(
    GradOut_ptr, X_ptr, GradK_ptr,  # Note: GradK is written directly
    B, T, D,
    GradOut_stride_b, GradOut_stride_t, GradOut_stride_d,
    X_stride_b, X_stride_t, X_stride_d,
    GradK_stride_b, GradK_stride_t, GradK_stride_d, GradK_stride_w,
    W: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """
    Computes gradient w.r.t. kernels K.
    Grid: (B * T, cdiv(D, BLOCK_SIZE_D)) - covering GradK output dims B, T, D
    GradK[b, t, d, w] = GradOut[b, t, d] * X[b, t + w - W + 1, d]
    Input time indices outside [0, T) are masked, so those taps get gradient 0.
    """
    pid_batch_time = tl.program_id(0)  # Covers B * T for output GradK
    pid_d_block = tl.program_id(1)

    # Widen the batch index to int64 (as the forward kernel does) so that
    # batch_idx * stride_b cannot overflow 32-bit arithmetic; GradK is the
    # largest tensor here ([B, T, D, W]) and hits that limit first.
    batch_idx = tl.cast(pid_batch_time // T, tl.int64)
    time_idx = pid_batch_time % T  # This is 't' for GradK and GradOut

    offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
    d_mask = offs_d < D
    offs_w = tl.arange(0, W)  # [W]

    # --- Load GradOut[b, t, d] --- Shape: [BLOCK_SIZE_D] (only depends on b, t, d)
    gradout_ptrs = GradOut_ptr + (batch_idx * GradOut_stride_b +
                                  time_idx * GradOut_stride_t +
                                  offs_d * GradOut_stride_d)
    gradout_vals = tl.load(gradout_ptrs, mask=d_mask, other=0.0)

    # --- Load Input X with implicit padding ---
    # X's time index: t_x = t + w - W + 1
    t_in_offs = time_idx + offs_w - W + 1  # Shape [W]
    # Mask for valid t_x index [0, T)
    t_in_mask = (t_in_offs >= 0) & (t_in_offs < T)  # Shape [W]
    x_ptrs = X_ptr + (batch_idx * X_stride_b +
                      t_in_offs[None, :] * X_stride_t +
                      offs_d[:, None] * X_stride_d)  # Shape [BLOCK_SIZE_D, W]
    # Valid D and valid t_x; masked-out positions read as 0.0
    x_load_mask = d_mask[:, None] & t_in_mask[None, :]
    x_vals = tl.load(x_ptrs, mask=x_load_mask, other=0.0)

    # --- Compute GradK = GradOut * X ---
    # Broadcast: [BLOCK_SIZE_D, 1] * [BLOCK_SIZE_D, W] -> [BLOCK_SIZE_D, W]
    gradk_vals = gradout_vals[:, None] * x_vals

    # --- Store gradients for Kernels ---
    gradk_ptrs = GradK_ptr + (batch_idx * GradK_stride_b +
                              time_idx * GradK_stride_t +
                              offs_d[:, None] * GradK_stride_d +
                              offs_w[None, :] * GradK_stride_w)
    # Mask only needed for the D dimension: every w in [0, W) is written
    # (padded taps are written as zeros), so the output needs no pre-zeroing.
    tl.store(gradk_ptrs, gradk_vals, mask=d_mask[:, None])
# --- Autograd Function ---
class DynamicConvTritonFunc(Function):
    """Autograd function wiring the Triton dynamic-convolution kernels together.

    Forward computes
        out[b, t, d] = sum_w kernels[b, t, d, w] * x[b, t + w - W + 1, d]
    and backward produces the gradients w.r.t. ``x`` and ``kernels``.
    """

    @staticmethod
    def forward(ctx, x, kernels):
        """
        Args:
            x: Input tensor [B, T, D]
            kernels: Kernels tensor [B, T, D, W]
        Returns:
            Output tensor [B, T, D], allocated like ``x``.
        Raises:
            ValueError: if ``x`` is not 3-D or ``kernels`` is not
                [B, T, D, W] with leading dims matching ``x``.
            AssertionError: if W > 4.
        """
        # Validate shapes before touching the GPU: the kernels index `kernels`
        # with x's B/T/D extents, so a mismatch would read out of bounds.
        if x.dim() != 3:
            raise ValueError(
                f"x must have shape [B, T, D], got {tuple(x.shape)}"
            )
        if kernels.dim() != 4 or kernels.shape[:3] != x.shape:
            raise ValueError(
                "kernels must have shape [B, T, D, W] matching x "
                f"{tuple(x.shape)}, got {tuple(kernels.shape)}"
            )
        W = kernels.shape[3]
        # NOTE(review): the kernels build tl.arange(0, W), which Triton
        # presumably requires to have power-of-two length - confirm whether
        # W == 3 is meant to be supported.
        assert W <= 4, "Kernel W > 4 not expected for this version"

        x = ensure_contiguous(x)
        kernels = ensure_contiguous(kernels)
        B, T, D = x.shape

        out = torch.empty_like(x)  # Output shape [B, T, D]; fully written by the kernel

        BLOCK_SIZE_D = 128  # Consider tuning
        # One program per (b, t) pair and per BLOCK_SIZE_D-wide slice of D.
        grid = (B * T, triton.cdiv(D, BLOCK_SIZE_D))

        _dynamic_conv_fwd_kernel[grid](
            x, kernels, out,
            B, T, D,
            x.stride(0), x.stride(1), x.stride(2),
            kernels.stride(0), kernels.stride(1), kernels.stride(2), kernels.stride(3),
            out.stride(0), out.stride(1), out.stride(2),
            W=W,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )

        # Save tensors needed for backward: x for dK, kernels for dX.
        ctx.save_for_backward(x, kernels)
        # Store W and BLOCK_SIZE_D needed for backward kernel calls
        ctx.W = W
        ctx.BLOCK_SIZE_D = BLOCK_SIZE_D
        return out

    @staticmethod
    def backward(ctx, grad_out):
        """
        Args:
            grad_out: Gradient w.r.t. the output tensor [B, T, D]
        Returns:
            grad_x: Gradient w.r.t. input x [B, T, D], or None if x does not
                require grad.
            grad_kernels: Gradient w.r.t. kernels [B, T, D, W], or None if
                kernels does not require grad.
        """
        grad_out = ensure_contiguous(grad_out)
        x, kernels = ctx.saved_tensors
        W = ctx.W
        BLOCK_SIZE_D = ctx.BLOCK_SIZE_D
        B, T, D = x.shape

        # Only launch the kernels whose result autograd will actually use.
        needs_grad_x, needs_grad_kernels = ctx.needs_input_grad[:2]

        # Same launch grid as the forward pass.
        grid = (B * T, triton.cdiv(D, BLOCK_SIZE_D))

        grad_x = None
        if needs_grad_x:
            # The dX kernel stores the complete value of every GradX[b, t, d]
            # (one program per element block), so no zero-initialisation is needed.
            grad_x = torch.empty_like(x)
            _dynamic_conv_bwd_dx_kernel[grid](
                grad_out, kernels, grad_x,
                B, T, D,
                grad_out.stride(0), grad_out.stride(1), grad_out.stride(2),
                kernels.stride(0), kernels.stride(1), kernels.stride(2), kernels.stride(3),
                grad_x.stride(0), grad_x.stride(1), grad_x.stride(2),
                W=W,
                BLOCK_SIZE_D=BLOCK_SIZE_D,
            )

        grad_kernels = None
        if needs_grad_kernels:
            # The dK kernel writes every (b, t, d, w) entry, emitting zeros for
            # taps that fall into the implicit padding.
            grad_kernels = torch.empty_like(kernels)
            _dynamic_conv_bwd_dk_kernel[grid](
                grad_out, x, grad_kernels,
                B, T, D,
                grad_out.stride(0), grad_out.stride(1), grad_out.stride(2),
                x.stride(0), x.stride(1), x.stride(2),
                grad_kernels.stride(0), grad_kernels.stride(1), grad_kernels.stride(2), grad_kernels.stride(3),
                W=W,
                BLOCK_SIZE_D=BLOCK_SIZE_D,
            )

        # Return gradients in the order inputs were received by forward
        return grad_x, grad_kernels
# --- User-facing function ---
def dynamic_conv_triton_autograd(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    """
    Differentiable fused dynamic convolution backed by Triton kernels.

    Thin entry point that dispatches to ``DynamicConvTritonFunc`` so that both
    the forward result and the gradients come from the custom kernels.
    Assumes W <= 4.

    Args:
        x: Input tensor of shape [B, T, D].
        kernels: Dynamic kernels of shape [B, T, D, W].
    Returns:
        Output tensor of shape [B, T, D].
    """
    output = DynamicConvTritonFunc.apply(x, kernels)
    return output