drbh committed on
Commit
846d481
·
0 Parent(s):

feat: small grayscale kernel

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.so filter=lfs diff=lfs merge=lfs -text
2
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake
2
+ .venv
3
+ __pycache__
4
+ *.pyc
5
+ torch-ext/img2gray/*.abi3.so
6
+ torch-ext/img2gray/_ops.py
7
+ torch-ext/registration.h
8
+ CMakeLists.txt
9
+ pyproject.toml
10
+ setup.py
build.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "img2gray"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h"
9
+ ]
10
+
11
+ [kernel.img2gray]
12
+ backend = "cuda"
13
+ depends = ["torch"]
14
+ src = [
15
+ "csrc/img2gray.cu",
16
+ ]
csrc/img2gray.cu ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#include <cstdint>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>
3
+
4
// Converts an interleaved 8-bit RGB image to an 8-bit grayscale image.
//
// Launch layout: 2D grid of 2D blocks, one thread per pixel. The x dimension
// indexes columns and the y dimension indexes rows. Threads that land outside
// the width x height image return immediately, so the grid may overshoot.
// No shared memory is used.
//
// input  : height * width * 3 bytes, row-major, channels interleaved as R, G, B.
// output : height * width bytes, row-major.
// width  : number of pixels per row.
// height : number of rows.
__global__ void img2gray_kernel(const uint8_t* input, uint8_t* output, int width, int height) {
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;

    // Guard the grid tail: the launch rounds the grid up to whole blocks.
    if (x >= width || y >= height) {
        return;
    }

    // Compute the flat offsets in 64 bits. The former 32-bit form
    // `(y * width + x) * 3` overflows once the image holds more than
    // INT_MAX / 3 pixels.
    const size_t pixel = static_cast<size_t>(y) * static_cast<size_t>(width) + static_cast<size_t>(x);
    const size_t rgb = pixel * 3;  // RGB has 3 channels

    const uint8_t r = input[rgb];
    const uint8_t g = input[rgb + 1];
    const uint8_t b = input[rgb + 2];

    // Luminosity method: weighted sum of the channels, truncated to 8 bits.
    output[pixel] = static_cast<uint8_t>(0.21f * r + 0.72f * g + 0.07f * b);
}
20
+
21
+
22
// Host-side launcher that exposes img2gray_kernel through the PyTorch
// extension interface.
//
// input  : contiguous uint8 CUDA tensor of shape (H, W, 3), channels R, G, B.
// output : contiguous uint8 CUDA tensor on the same device holding H * W
//          elements; it is overwritten with the grayscale image.
//
// The kernel is enqueued on PyTorch's current CUDA stream for the input's
// device and the function returns without waiting for it. Work submitted to
// that stream afterwards is ordered after the kernel.
// Raises a c10::Error (via TORCH_CHECK) when a precondition is violated or
// when the launch itself is rejected.
void img2gray_cuda(torch::Tensor input, torch::Tensor output) {
    TORCH_CHECK(input.is_cuda(), "img2gray: input must be a CUDA tensor");
    TORCH_CHECK(output.is_cuda(), "img2gray: output must be a CUDA tensor");
    TORCH_CHECK(input.device() == output.device(),
                "img2gray: input and output must be on the same device");
    TORCH_CHECK(input.scalar_type() == torch::kUInt8, "img2gray: input must be uint8");
    TORCH_CHECK(output.scalar_type() == torch::kUInt8, "img2gray: output must be uint8");
    TORCH_CHECK(input.dim() == 3 && input.size(2) == 3,
                "img2gray: input must have shape (H, W, 3)");
    // The kernel indexes raw memory, so strided views would be read or written incorrectly.
    TORCH_CHECK(input.is_contiguous(), "img2gray: input must be contiguous");
    TORCH_CHECK(output.is_contiguous(), "img2gray: output must be contiguous");
    TORCH_CHECK(output.numel() == input.size(0) * input.size(1),
                "img2gray: output must hold exactly H * W elements");

    // A zero-sized grid is an invalid launch configuration; there is nothing to do.
    if (output.numel() == 0) {
        return;
    }

    const int width = static_cast<int>(input.size(1));
    const int height = static_cast<int>(input.size(0));

    // Make the tensors' device current and launch on the stream PyTorch is
    // using for it, instead of the legacy default stream.
    const c10::cuda::CUDAGuard device_guard(input.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const dim3 blockSize(16, 16);
    // Ceil-divide so partially filled blocks cover the right and bottom edges.
    const dim3 gridSize((width + blockSize.x - 1) / blockSize.x,
                        (height + blockSize.y - 1) / blockSize.y);

    img2gray_kernel<<<gridSize, blockSize, 0, stream>>>(
        input.data_ptr<uint8_t>(),
        output.data_ptr<uint8_t>(),
        width,
        height
    );

    // Surface launch-configuration errors here instead of at an unrelated later call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
flake.lock ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1750234878,
77
+ "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1750790603,
102
+ "narHash": "sha256-m7FoTYWDV811Y7FiuJPa/uCOV63rf6LHxWportuI9h0=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "37cad313efea84e213b2fc13b2ec808d273a126d",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1747820358,
117
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
+ "owner": "danieldk",
119
+ "repo": "nixpkgs",
120
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "danieldk",
125
+ "ref": "cudatoolkit-12.9-kernel-builder",
126
+ "repo": "nixpkgs",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for img2gray kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+ };
17
+ }
scripts/sanity.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Manual smoke test: run the img2gray kernel on a PNG and save the result.

Usage: sanity.py [INPUT_PNG [OUTPUT_PNG]]

Both paths are optional; when omitted, the original development paths below
are used.
"""
import sys

import torch
import img2gray

from PIL import Image
import numpy as np

DEFAULT_INPUT = "/home/ubuntu/Projects/img2gray/kernel-builder-logo-color.png"
DEFAULT_OUTPUT = "/home/ubuntu/Projects/img2gray/kernel-builder-logo-gray.png"

input_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_INPUT
output_path = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_OUTPUT

print(dir(img2gray))

img = Image.open(input_path).convert("RGB")
img = np.array(img)
img_tensor = torch.from_numpy(img)
print(img_tensor.shape)  # HWC
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).contiguous().cuda()  # BCHW with B == 1
print(img_tensor.shape)  # BCHW

# img2gray returns B1HW; squeeze() drops the size-1 batch and channel dims.
gray_tensor = img2gray.img2gray(img_tensor).squeeze()
print(gray_tensor.shape)  # HW

# save the output image
gray_img = gray_tensor.cpu().numpy()  # HW array on the host
gray_img = Image.fromarray(gray_img.astype(np.uint8), mode="L")

gray_img.save(output_path)
torch-ext/img2gray/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ._ops import ops
4
+
5
def img2gray(input: torch.Tensor) -> torch.Tensor:
    """Convert a batch of RGB images to grayscale with the img2gray kernel.

    Args:
        input: Tensor in BCHW layout with exactly 3 channels. The registered
            kernel is CUDA-only and reads the data as uint8, so the tensor is
            expected to be a uint8 CUDA tensor.

    Returns:
        A new tensor of shape (B, 1, H, W) with the same device and dtype as
        ``input``.

    Raises:
        ValueError: if ``input`` is not 4-dimensional (from shape unpacking).
        AssertionError: if the channel dimension is not 3.
    """
    # we expect input to be in BCHW format
    batch, channels, height, width = input.shape

    assert channels == 3, "Input image must have 3 channels (RGB)"

    output = torch.empty((batch, 1, height, width), device=input.device, dtype=input.dtype)

    # Zero-sized images have no pixels to convert; skip the kernel entirely
    # rather than asking it to launch an empty grid.
    if output.numel() == 0:
        return output

    for b in range(batch):
        # The kernel reads interleaved channels, so hand it a contiguous HWC copy.
        single_image = input[b].permute(1, 2, 0).contiguous()  # HWC
        # Plain indexing always yields a view, so the kernel writes straight
        # into `output` (reshape may silently copy when a view is impossible).
        single_output = output[b, 0]  # HW
        ops.img2gray(single_image, single_output)

    return output
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+ #include "torch_binding.h"
5
+
6
+
7
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
8
+ ops.def("img2gray(Tensor input, Tensor output) -> ()");
9
+ ops.impl("img2gray", torch::kCUDA, &img2gray_cuda);
10
+ }
11
+
12
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #include <torch/torch.h>
2
+
3
+ void img2gray_cuda(torch::Tensor input, torch::Tensor output);