# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
# Helper function to ensure tensors are contiguous for Triton
def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
# Ensure tensor is contiguous in memory.
return t if t.is_contiguous() else t.contiguous()
# Note: no activation is fused in this kernel; any activation is applied by the caller.
@triton.jit
def _causal_conv_step_kernel(
# --- Input/Output Pointers ---
X_ptr, # Pointer to current input x [B, D] (after squeeze)
Cache_ptr, # Pointer to cache [B, D, W], updated IN-PLACE
Kernels_ptr, # Pointer to generated kernels [B, D, W]
Out_ptr, # Pointer to output tensor [B, D]
# --- Tensor Dimensions ---
B, D, # Batch size, Feature dimension
# --- Tensor Strides ---
X_stride_b, X_stride_d,
Cache_stride_b, Cache_stride_d, Cache_stride_w,
Kernels_stride_b, Kernels_stride_d, Kernels_stride_w,
Out_stride_b, Out_stride_d,
# --- Kernel Meta-Parameters ---
W: tl.constexpr, # Kernel width (Cache size), passed as compile-time constant (1 < W <= 4)
BLOCK_SIZE_D: tl.constexpr, # Block size for D dimension (tuning parameter)
):
"""
Triton kernel for a single step (T=1) of causal dynamic convolution.
Updates the cache in-place and computes the output (without activation).
Optimized for small W (1 < W <= 4) by manually unrolling the W dimension.
Does NOT handle separate static bias.
Grid: (B, cdiv(D, BLOCK_SIZE_D))
Updates Cache[b, d, :] and computes Out[b, d].
"""
# 1. --- Get Program IDs and Calculate Indices ---
pid_b = tl.program_id(0) # Program ID for batch dimension
pid_d_block = tl.program_id(1) # Program ID for dimension block
offs_d = pid_d_block * BLOCK_SIZE_D + tl.arange(0, BLOCK_SIZE_D)
d_mask = offs_d < D # Shape: [BLOCK_SIZE_D]
# 2. --- Load Current Input X ---
x_ptrs = X_ptr + pid_b * X_stride_b + offs_d * X_stride_d
x_curr = tl.load(x_ptrs, mask=d_mask, other=0.0) # Shape: [BLOCK_SIZE_D]
    # 3. --- Initialize Accumulator ---
accumulator = tl.zeros((BLOCK_SIZE_D,), dtype=x_curr.dtype) # Use input dtype
    # 4. --- Manually Unroll Operations for W ---
    # Load the kernel and cache values step by step, accumulate the products,
    # and shift the cache left by one position as we go.
# --- Step w = 0 ---
# Compute: cache_val_1 * k_val_0 (part 1)
# Cache Update: store cache_val_1 at index 0
if tl.constexpr(W > 1):
# Load k_val_0
k_ptr_0 = Kernels_ptr + pid_b * Kernels_stride_b + offs_d * Kernels_stride_d + 0 * Kernels_stride_w
k_val_0 = tl.load(k_ptr_0, mask=d_mask, other=0.0)
# Load cache_val_1 (needed for computation and storing at index 0)
cache_ptr_1 = Cache_ptr + pid_b * Cache_stride_b + offs_d * Cache_stride_d + 1 * Cache_stride_w
cache_val_1 = tl.load(cache_ptr_1, mask=d_mask, other=0.0)
# Accumulate Part 1
accumulator += cache_val_1 * k_val_0
# Cache Update: Store cache_val_1 -> cache_ptr_0
cache_ptr_0 = Cache_ptr + pid_b * Cache_stride_b + offs_d * Cache_stride_d + 0 * Cache_stride_w
tl.store(cache_ptr_0, cache_val_1, mask=d_mask)
# --- Step w = 1 ---
# Compute: cache_val_2 * k_val_1 (part 1)
# Cache Update: store cache_val_2 at index 1
if tl.constexpr(W > 2):
# Load k_val_1
k_ptr_1 = Kernels_ptr + pid_b * Kernels_stride_b + offs_d * Kernels_stride_d + 1 * Kernels_stride_w
k_val_1 = tl.load(k_ptr_1, mask=d_mask, other=0.0)
# Load cache_val_2
cache_ptr_2 = Cache_ptr + pid_b * Cache_stride_b + offs_d * Cache_stride_d + 2 * Cache_stride_w
cache_val_2 = tl.load(cache_ptr_2, mask=d_mask, other=0.0)
# Accumulate Part 1
accumulator += cache_val_2 * k_val_1
# Cache Update: Store cache_val_2 -> cache_ptr_1
cache_ptr_1 = Cache_ptr + pid_b * Cache_stride_b + offs_d * Cache_stride_d + 1 * Cache_stride_w
tl.store(cache_ptr_1, cache_val_2, mask=d_mask)
# --- Step w = 2 ---
# Compute: cache_val_3 * k_val_2 (part 1)
# Cache Update: store cache_val_3 at index 2
if tl.constexpr(W > 3):
# Load k_val_2
k_ptr_2 = Kernels_ptr + pid_b * Kernels_stride_b + offs_d * Kernels_stride_d + 2 * Kernels_stride_w
k_val_2 = tl.load(k_ptr_2, mask=d_mask, other=0.0)
# Load cache_val_3
cache_ptr_3 = Cache_ptr + pid_b * Cache_stride_b + offs_d * Cache_stride_d + 3 * Cache_stride_w
cache_val_3 = tl.load(cache_ptr_3, mask=d_mask, other=0.0)
# Accumulate Part 1
accumulator += cache_val_3 * k_val_2
# Cache Update: Store cache_val_3 -> cache_ptr_2
cache_ptr_2 = Cache_ptr + pid_b * Cache_stride_b + offs_d * Cache_stride_d + 2 * Cache_stride_w
tl.store(cache_ptr_2, cache_val_3, mask=d_mask)
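    # Because W is a tl.constexpr, the W > 1 / W > 2 / W > 3 branches above are
    # resolved at compile time, so each specialization contains only the loads
    # and stores it actually needs.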
    # 5. --- Final Step (Part 2 and Final Cache Update) ---
# Compute: x_curr * k_val_{W-1} (part 2)
# Cache Update: store x_curr at index W-1
# Load k_val_{W-1}
k_ptr_last = Kernels_ptr + pid_b * Kernels_stride_b + offs_d * Kernels_stride_d + (W - 1) * Kernels_stride_w
k_val_last = tl.load(k_ptr_last, mask=d_mask, other=0.0)
# Accumulate Part 2
accumulator += x_curr * k_val_last
# Final Cache Update: Store x_curr -> cache_ptr_{W-1}
cache_ptr_last = Cache_ptr + pid_b * Cache_stride_b + offs_d * Cache_stride_d + (W - 1) * Cache_stride_w
tl.store(cache_ptr_last, x_curr, mask=d_mask)
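    # After this store the cache holds the shifted history [old cache[1:], x_curr],
    # which is the state the next decoding step combines with its own input.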
# 6. --- Store Output ---
out_ptrs = Out_ptr + pid_b * Out_stride_b + offs_d * Out_stride_d
tl.store(out_ptrs, accumulator, mask=d_mask) # Store result without activation
    # The cache update is handled entirely within the unrolled steps above.
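
# A minimal pure-PyTorch reference for the same step, added as an illustrative
# cross-check of the Triton kernel above; the name `causal_conv_step_reference`
# is not part of the original API. It assumes the same [B, 1, D] input and
# [B, D, W] cache/kernels layout and, like the Triton path, updates the cache
# in-place.
def causal_conv_step_reference(
    x: torch.Tensor,        # Input tensor [B, 1, D]
    cache: torch.Tensor,    # Cache tensor [B, D, W], updated in-place
    kernels: torch.Tensor,  # Kernels tensor [B, D, W]
) -> torch.Tensor:          # Output tensor [B, D]
    x_curr = x.squeeze(1)  # [B, D]
    # Shift the cache left by one slot and append the current input at index W-1.
    new_cache = torch.cat([cache[:, :, 1:], x_curr.unsqueeze(-1)], dim=-1)
    cache.copy_(new_cache)
    # Per-(b, d) dot product between the updated window and the dynamic kernel.
    return (cache * kernels).sum(dim=-1)
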
# --- Python Wrapper Function ---
def causal_conv_step_triton(
x: torch.Tensor, # Input tensor [B, 1, D]
cache: torch.Tensor, # Cache tensor [B, D, W], modified in-place
kernels: torch.Tensor, # Kernels tensor [B, D, W]
) -> torch.Tensor: # Returns output tensor [B, D] (before activation)
"""
Performs one step of causal dynamic convolution using Triton.
Updates the cache in-place. Does NOT fuse activation. Assumes 1 < W <= 4.
Uses manually unrolled kernel for W dimension.
Args:
x: Current input token tensor of shape [B, 1, D].
cache: Cache tensor of shape [B, D, W]. Will be updated in-place.
kernels: Dynamically generated kernels tensor of shape [B, D, W].
Returns:
Output tensor of shape [B, D] for the current step (before activation).
"""
# --- Input Validation and Preparation ---
assert x.dim() == 3 and x.shape[1] == 1, "Input x must have shape [B, 1, D]"
assert cache.dim() == 3, "Cache must have shape [B, D, W]"
assert kernels.dim() == 3, "Kernels must have shape [B, D, W]"
B, _, D = x.shape
W = cache.shape[2]
    # This unrolled kernel supports only 1 < W <= 4.
    assert 1 < W <= 4, f"Kernel width W={W}; this optimized version assumes 1 < W <= 4"
assert cache.shape[0] == B and cache.shape[1] == D, f"Cache shape mismatch: {cache.shape}"
assert kernels.shape == cache.shape, f"Kernels shape mismatch: {kernels.shape}"
assert x.is_cuda and cache.is_cuda and kernels.is_cuda, "Inputs must be CUDA tensors"
    # Mixed input dtypes are tolerated here; uncomment the check below to enforce
    # matching dtypes across x, cache, and kernels.
    # assert x.dtype == cache.dtype == kernels.dtype, "Input dtypes must match"
# Squeeze the time dimension from input x
x_squeezed = x.squeeze(1) # Shape [B, D]
# Ensure tensors are contiguous for correct stride calculations in Triton
x_squeezed = ensure_contiguous(x_squeezed)
# Cache MUST be contiguous for in-place updates and loads/stores to work reliably
cache = ensure_contiguous(cache)
kernels = ensure_contiguous(kernels)
# Create output tensor with the same dtype as input x
out = torch.empty_like(x_squeezed) # Shape [B, D]
# --- Triton Kernel Launch ---
grid = lambda meta: (B, triton.cdiv(D, meta['BLOCK_SIZE_D']))
    BLOCK_SIZE_D = 64  # Reasonable default; tune for the target GPU and feature dimension D
# Launch the kernel
_causal_conv_step_kernel[grid](
x_squeezed, cache, kernels, out, # Tensor pointers
B, D, # Dimensions
x_squeezed.stride(0), x_squeezed.stride(1), # x strides
cache.stride(0), cache.stride(1), cache.stride(2), # cache strides
kernels.stride(0), kernels.stride(1), kernels.stride(2), # kernels strides
out.stride(0), out.stride(1), # out strides
# --- Meta-parameters ---
W=W, # Pass W as constexpr
BLOCK_SIZE_D=BLOCK_SIZE_D, # Pass BLOCK_SIZE_D as constexpr
)
return out # Return the computed output [B, D] (before activation)
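
# A small self-check sketch (not part of the original module). It assumes a CUDA
# device is available and compares the Triton step against the pure-PyTorch
# reference above for each supported kernel width W.
if __name__ == "__main__":
    torch.manual_seed(0)
    for W in (2, 3, 4):
        B, D = 4, 256
        x = torch.randn(B, 1, D, device="cuda")
        kernels = torch.randn(B, D, W, device="cuda")
        cache = torch.randn(B, D, W, device="cuda")
        cache_ref = cache.clone()
        out_triton = causal_conv_step_triton(x, cache, kernels)
        out_ref = causal_conv_step_reference(x, cache_ref, kernels)
        torch.testing.assert_close(out_triton, out_ref, rtol=1e-4, atol=1e-4)
        torch.testing.assert_close(cache, cache_ref, rtol=1e-4, atol=1e-4)
        print(f"W={W}: Triton step matches reference")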