# Jet-Nemotron-2B / dynamic_conv.py
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections import OrderedDict
from typing import Optional, Tuple, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from .dconv_fwdbwd import dynamic_conv_triton_autograd
from .dconv_fwd_cache import dynamic_conv_triton_cache
from .dconv_step import causal_conv_step_triton
class DynamicShortConvolution(nn.Module):
"""
Simple wrapper around `nn.Conv1d` that accepts dimension last.
"""
def __init__(
self,
hidden_size: int,
kernel_size: int,
generator_input_size: Optional[int] = None,
generator_reduction: Optional[int] = None,
generator_activation: str = 'silu',
activation: Optional[str] = 'silu',
        static_conv_init: Optional[Callable] = None,
use_fast_conv1d: bool = True,
implementation: str = "naive",
    ) -> None:
super().__init__()
self.hidden_size = hidden_size
self.generator_input_size = hidden_size if generator_input_size is None else generator_input_size
self.generator_hidden_size = hidden_size if generator_reduction is None else (hidden_size // generator_reduction)
self.kernel_size = kernel_size
self.activation = None
self.use_fast_conv1d = use_fast_conv1d
self.implementation = implementation
if activation is not None:
assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
self.activation = activation
self.static_conv_init = static_conv_init
self.kernel_generator = nn.Sequential(
OrderedDict([
("w1", nn.Linear(self.generator_input_size, self.generator_hidden_size, bias=False)),
("act", ACT2FN[generator_activation]),
("w2", nn.Linear(self.generator_hidden_size, self.hidden_size * self.kernel_size, bias=True)),
])
)
self._init_kernel_generator()
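
    # Construction sketch (hypothetical sizes, not from the original file): a
    # 2048-dim model with 4-tap kernels and an 8x generator bottleneck,
    #   conv = DynamicShortConvolution(hidden_size=2048, kernel_size=4, generator_reduction=8)
    # gives a generator mapping 2048 -> 256 -> 2048 * 4 features per token, i.e.
    # one length-4 kernel per channel.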
def _init_kernel_generator(self):
"""
Initialize the kernel generator.
"""
for layer in self.kernel_generator:
if isinstance(layer, nn.Linear):
layer.weight.data.zero_()
if layer.bias is not None:
layer.bias.data.zero_()
if self.static_conv_init is not None:
# init for static_bias
self.static_conv_init(self.kernel_generator.w2.bias)
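
    # Note: with the generator weights zeroed above, `w1` outputs zeros and (for
    # activations with f(0) = 0, e.g. silu) `w2` then emits exactly its bias, so
    # the layer starts out as a *static* short convolution (the bias acting as the
    # static kernel) and only becomes input-dependent through training.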
def get_kernel(self, x: torch.Tensor) -> torch.Tensor:
flat_kernels = self.kernel_generator(x)
if flat_kernels.dim() == 3:
kernels = rearrange(flat_kernels, 'b t (d w) -> b t d w', w=self.kernel_size)
elif flat_kernels.dim() == 2:
kernels = rearrange(flat_kernels, 'b (d w) -> b d w', w=self.kernel_size)
else:
raise ValueError(f"Invalid kernel shape: {flat_kernels.shape}")
return kernels
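
    # Shape sketch: for x of shape [B, T, generator_input_size] the generator
    # emits [B, T, D * W] and the result is [B, T, D, W]; during single-token
    # decoding x is [B, generator_input_size] and the result is [B, D, W].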
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
cache: Optional[torch.Tensor] = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
generator_input: Optional[torch.Tensor] = None,
**kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x (`torch.Tensor`):
Tensor of shape `[B, T, D]`.
If `seq_idx` is provided, `B` must be 1.
mask (`Optional[torch.Tensor]`):
Attention mask dealing with padded positions.
cache (`Optional[torch.Tensor]`):
Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size.
If provided, the cache is updated **inplace**.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, D, W]`. Default: `False`.
cu_seqlens (Optional[torch.LongTensor]):
Cumulative sequence lengths for each batch. Used for varlen. Default: `None`.
Shape: [B+1]
Returns:
Tensor of shape `[B, T, D]`.
"""
"""
x: [B, T, D]
return: [B, T, D]
"""
assert cu_seqlens is None, "cu_seqlens not supported yet."
B, T, D, W = *x.shape, self.kernel_size
N = B
input_dtype = x.dtype
if mask is not None:
x = x.mul_(mask.unsqueeze(-1))
implementation = self.implementation
if implementation == "triton" and not self.training:
implementation = "triton_cache"
# during the decoding phase, we assume the batch is composed of sequences of length 1
if cache is not None and B * T == N:
assert T == 1
if implementation in ["naive", "triton_training"]:
x, cache = self._step_naive(x, cache, cu_seqlens, generator_input=generator_input)
elif implementation in ["triton", "triton_cache", "triton_decoding"]:
x, cache = self._step_triton(x, cache, cu_seqlens, generator_input=generator_input)
else:
raise ValueError(f"Unknown implementation: {implementation}")
return x, cache
if output_final_state:
new_cache = rearrange(x[..., -min(W, T):, :], 'n w d -> n d w')
else:
new_cache = None
if implementation in ["naive", "triton_decoding"]:
x = self._forward_naive(x, generator_input=generator_input) # [B, T, D]
elif implementation in ["triton", "triton_training"]:
assert cache is None, "Cache not supported in pure triton mode. Please set model.eval() or use triton_cache mode."
x = self._forward_triton(x, generator_input=generator_input)
elif implementation == "triton_cache":
x = self._forward_triton_cache(x, generator_input=generator_input, cache=cache)
else:
raise ValueError(f"Unknown implementation: {implementation}")
if self.activation is not None:
x = ACT2FN[self.activation](x)
x = x.to(input_dtype)
if output_final_state:
if cache is None:
cache = x.new_zeros(N, D, W)
cache[:, :, -min(W, T):].copy_(new_cache)
return x, cache
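
    # Usage sketch (shapes only; `conv` and `x_t` are illustrative names):
    #   y, cache = conv(x, output_final_state=True)  # prefill: x [B, T, D] -> y [B, T, D], cache [B, D, W]
    #   y_t, cache = conv(x_t, cache=cache)          # decode:  x_t [B, 1, D], cache updated in-place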
def _forward_naive(self, x: torch.Tensor, generator_input: Optional[torch.Tensor] = None) -> torch.Tensor:
W = self.kernel_size
generator_input = x if generator_input is None else generator_input
kernels = self.get_kernel(generator_input)
x = F.pad(x.transpose(1, 2), (W - 1, 0)) # [B, D, T+W-1]
x = x.unfold(dimension=2, size=W, step=1) # [B, D, T, W]
x = x.permute(0, 2, 1, 3) # [B, T, D, W]
x = (x * kernels).sum(dim=-1) # [B, T, D]
return x
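
    # The pad/unfold trick above realizes, per batch b, token t and channel d,
    #   y[b, t, d] = sum_{w=0}^{W-1} kernels[b, t, d, w] * x[b, t - (W - 1) + w, d]
    # (with out-of-range positions zero-padded), i.e. a depthwise causal conv
    # whose taps are regenerated at every position.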
def _forward_triton(self, x: torch.Tensor, generator_input: Optional[torch.Tensor] = None) -> torch.Tensor:
generator_input = x if generator_input is None else generator_input
kernels = self.get_kernel(generator_input)
output_triton = dynamic_conv_triton_autograd(x, kernels)
return output_triton
    @torch.no_grad()
def _forward_triton_cache(self, x: torch.Tensor, generator_input: Optional[torch.Tensor] = None, cache: Optional[torch.Tensor] = None) -> torch.Tensor:
generator_input = x if generator_input is None else generator_input
assert not self.training, "Triton implementation is only available in eval mode."
        # cache holds the last W input tokens, channels-first: [B, D, W]
CHUNK_SIZE = 2048
n_chunk = (x.shape[1] + CHUNK_SIZE - 1) // CHUNK_SIZE
output_triton = torch.zeros_like(x)
if cache is not None:
cache = rearrange(cache, "b d t -> b t d") # [B, T(W), D]
for i in range(n_chunk):
start = i * CHUNK_SIZE
end = min((i + 1) * CHUNK_SIZE, x.shape[1])
kernels = self.get_kernel(generator_input[:, start:end])
out = dynamic_conv_triton_cache(x[:, start:end], kernels, cache=cache)
            output_triton[:, start:end, :] = out
cache = x[:, end-self.kernel_size:end, :]
return output_triton
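
    # Chunking note: processing in CHUNK_SIZE slices keeps the generated-kernel
    # tensor [B, chunk, D, W] bounded in memory; the last W tokens of each chunk
    # seed the cache for the next one.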
def _step_naive(
self,
x: torch.Tensor,
cache: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
generator_input: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.shape[1] == 1, "x must be of shape [B, 1, D]"
shape = x.shape
generator_input = x if generator_input is None else generator_input
x = x.squeeze(1)
generator_input = generator_input.squeeze(1) # Shape [B, D]
B, D, W = *x.shape, self.kernel_size
        # fast decode path: update the cache in-place by shifting it one step to
        # the left and writing the new token into the last slot
        cache.copy_(cache.roll(shifts=-1, dims=-1))
        cache[:, :, -1] = x  # cache: [B, D, W]
kernels = self.get_kernel(generator_input) # [B, D, W]
x = torch.sum(cache * kernels, dim=-1)
if self.activation is not None:
x = ACT2FN[self.activation](x)
return x.view(shape), cache
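
    # Cache semantics example with W = 4: after the roll-and-write above the cache
    # holds the last four inputs [x_{t-3}, x_{t-2}, x_{t-1}, x_t] per channel, so
    # the dot product with the generated kernel is the causal conv output at step t.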
def _step_triton(
self,
x: torch.Tensor,
cache: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
generator_input: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# --- Triton Implementation ---
assert x.shape[1] == 1, "x must be of shape [B, 1, D]"
shape = x.shape # Keep original shape [B, 1, D] for return
generator_input = x if generator_input is None else generator_input
# 1. Generate kernels
kernels_triton = self.get_kernel(generator_input.squeeze(1)) # [B, D, W]
# 2. Call Triton kernel without activation
x_out_triton = causal_conv_step_triton(
x,
cache,
kernels_triton,
)
# Apply activation (if any) after kernel execution
if self.activation is not None:
x_out_triton = ACT2FN[self.activation](x_out_triton)
# 3. Return reshaped output and the *same cache tensor* (it was updated in-place)
return x_out_triton.view(shape), cache
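

# Minimal smoke-test sketch (an assumption, not part of the original file): the
# "naive" path below is pure PyTorch, but importing this module still requires
# the sibling triton kernel modules to be importable.
if __name__ == "__main__":
    torch.manual_seed(0)
    conv = DynamicShortConvolution(hidden_size=16, kernel_size=4, implementation="naive")
    x = torch.randn(2, 8, 16)
    y, cache = conv(x, output_final_state=True)
    print(y.shape, cache.shape)  # torch.Size([2, 8, 16]) torch.Size([2, 16, 4])
    y_step, cache = conv(torch.randn(2, 1, 16), cache=cache)
    print(y_step.shape)  # torch.Size([2, 1, 16])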