File size: 1,051 Bytes
e8e4be6
 
 
 
 
 
 
 
 
 
 
 
9b13459
e8e4be6
 
 
 
9b13459
 
 
 
 
 
e8e4be6
9b13459
 
 
 
 
 
 
 
e8e4be6
 
 
 
9b13459
 
 
e8e4be6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import torch.nn.functional as F
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def torch_causal_conv1d(input_tensor, weight, bias):
    """Apply a depthwise causal 1-D convolution using ``F.conv1d``.

    The convolution is computed in ``weight.dtype`` and the result is cast
    back to ``input_tensor.dtype``.

    Args:
        input_tensor: Tensor whose last dimension is the sequence axis and
            whose channel count equals ``weight.shape[0]`` (required by
            ``groups=dim``). Presumably laid out as (batch, dim, seqlen);
            confirm against the benchmark harness.
        weight: Tensor of shape ``(dim, width)``: one filter of ``width``
            taps per channel.
        bias: Per-channel bias tensor, or ``None`` for no bias. It is cast
            to ``weight.dtype`` when its dtype differs.

    Returns:
        A tensor with the same sequence length and dtype as
        ``input_tensor``.
    """
    compute_dtype = weight.dtype
    x = input_tensor.to(compute_dtype)

    # F.conv1d requires input and bias to share a dtype; the input was just
    # moved to the weight dtype, so the bias has to follow or the call raises.
    # ``Tensor.to`` returns the same tensor when the dtype already matches.
    if bias is not None:
        bias = bias.to(compute_dtype)

    dim = weight.shape[0]
    width = weight.shape[1]
    seqlen = input_tensor.shape[-1]

    # groups=dim makes the convolution depthwise (one filter per channel).
    # padding=width - 1 pads both ends, giving seqlen + width - 1 outputs.
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)

    # Keep the first seqlen outputs: position t then depends only on inputs
    # at positions <= t, which is what makes the convolution causal.
    out = out[..., :seqlen]

    return out.to(input_tensor.dtype)


# Top-level call, executed whenever this script runs: passes the eager
# PyTorch implementation above to the benchmark harness under the
# CAUSAL_CONV1D kernel type.
# NOTE(review): impl_name / impl_tags look like labels the harness uses to
# identify this implementation in its results; confirm against
# kernels_benchmark_tools.
run_benchmark(
    kernel_type=KernelTypeEnum.CAUSAL_CONV1D,
    impl_name="torch_eager",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=torch_causal_conv1d,
)