drbh
		
	committed on
		
		
					Commit 
							
							·
						
						846d481
	
0
								Parent(s):
							
							
feat: small grayscale kernel
Browse files- .gitattributes +2 -0
- .gitignore +10 -0
- build.toml +16 -0
- csrc/img2gray.cu +38 -0
- flake.lock +168 -0
- flake.nix +17 -0
- scripts/sanity.py +23 -0
- torch-ext/img2gray/__init__.py +18 -0
- torch-ext/torch_binding.cpp +12 -0
- torch-ext/torch_binding.h +3 -0
    	
        .gitattributes
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *.so filter=lfs diff=lfs merge=lfs -text
         | 
| 2 | 
            +
            *.png filter=lfs diff=lfs merge=lfs -text
         | 
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            cmake
         | 
| 2 | 
            +
            .venv
         | 
| 3 | 
            +
            __pycache__
         | 
| 4 | 
            +
            *.pyc
         | 
| 5 | 
            +
            torch-ext/img2gray/*.abi3.so
         | 
| 6 | 
            +
            torch-ext/img2gray/_ops.py
         | 
| 7 | 
            +
            torch-ext/registration.h
         | 
| 8 | 
            +
            CMakeLists.txt
         | 
| 9 | 
            +
            pyproject.toml
         | 
| 10 | 
            +
            setup.py
         | 
    	
        build.toml
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            [general]
         | 
| 2 | 
            +
            name = "img2gray"
         | 
| 3 | 
            +
            universal = false
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            [torch]
         | 
| 6 | 
            +
            src = [
         | 
| 7 | 
            +
              "torch-ext/torch_binding.cpp",
         | 
| 8 | 
            +
              "torch-ext/torch_binding.h"
         | 
| 9 | 
            +
            ]
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            [kernel.img2gray]
         | 
| 12 | 
            +
            backend = "cuda"
         | 
| 13 | 
            +
            depends = ["torch"]
         | 
| 14 | 
            +
            src = [
         | 
| 15 | 
            +
                "csrc/img2gray.cu",
         | 
| 16 | 
            +
            ]
         | 
    	
        csrc/img2gray.cu
    ADDED
    
    | @@ -0,0 +1,38 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #include <cstdint>
         | 
| 2 | 
            +
            #include <torch/torch.h>
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            // Converts an interleaved 8-bit RGB image to 8-bit grayscale, one thread per pixel.
            //
            // input  : 3 bytes per pixel in R, G, B order; rows of `width` pixels stored back to back.
            // output : 1 byte per pixel, in the same row-major pixel order as `input`.
            // Launch : 2D grid; thread (x, y) handles the pixel in column x of row y. Threads
            //          that land outside width x height exit at once, so the grid may overshoot.
            __global__ void img2gray_kernel(const uint8_t* input, uint8_t* output, int width, int height) {
                const int x = blockIdx.x * blockDim.x + threadIdx.x;
                const int y = blockIdx.y * blockDim.y + threadIdx.y;

                // Tail guard: the grid is rounded up to whole blocks.
                if (x >= width || y >= height) {
                    return;
                }

                // Widen before multiplying: (y * width + x) * 3 leaves the 32-bit int range
                // once the image holds more than roughly 715 million pixels.
                const size_t pixel = static_cast<size_t>(y) * static_cast<size_t>(width) + static_cast<size_t>(x);
                const uint8_t* rgb = input + pixel * 3;  // RGB has 3 channels

                const uint8_t r = rgb[0];
                const uint8_t g = rgb[1];
                const uint8_t b = rgb[2];

                // Luminosity method. The weights nominally sum to 1.0, so the weighted sum
                // stays in the 0..255 range; the cast truncates instead of rounding.
                output[pixel] = static_cast<uint8_t>(0.21f * r + 0.72f * g + 0.07f * b);
            }
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            // Host entry point exposed through the PyTorch extension interface: fills `output`
            // with the grayscale version of `input` by launching img2gray_kernel.
            //
            // input  : CUDA uint8 tensor, contiguous, shape (height, width, 3).
            // output : CUDA uint8 tensor, contiguous, height * width elements, same device.
            //
            // Throws (via TORCH_CHECK) when a precondition is violated or when the CUDA
            // runtime reports a launch or execution failure.
            //
            // NOTE(review): the kernel is launched without a stream argument and without
            // switching devices, so it runs on the default stream of the current device.
            // Confirm this matches how callers use non-default streams / multiple GPUs.
            void img2gray_cuda(torch::Tensor input, torch::Tensor output) {
                // The kernel dereferences raw byte pointers, so layout and dtype must be exact.
                TORCH_CHECK(input.is_cuda() && output.is_cuda(),
                            "img2gray: input and output must be CUDA tensors");
                TORCH_CHECK(input.device() == output.device(),
                            "img2gray: input and output must be on the same device");
                TORCH_CHECK(input.scalar_type() == torch::kUInt8 && output.scalar_type() == torch::kUInt8,
                            "img2gray: input and output must be uint8 tensors");
                TORCH_CHECK(input.dim() == 3 && input.size(2) == 3,
                            "img2gray: input must have shape (height, width, 3)");
                TORCH_CHECK(input.is_contiguous() && output.is_contiguous(),
                            "img2gray: input and output must be contiguous");

                const int64_t height64 = input.size(0);
                const int64_t width64 = input.size(1);
                TORCH_CHECK(output.numel() == height64 * width64,
                            "img2gray: output must hold exactly height * width elements");

                // The kernel takes int dimensions; reject sizes that would be narrowed.
                const int width = static_cast<int>(width64);
                const int height = static_cast<int>(height64);
                TORCH_CHECK(width == width64 && height == height64,
                            "img2gray: image dimensions do not fit in a 32-bit int");

                // A zero-sized grid is an invalid launch configuration, and there is no work.
                if (width == 0 || height == 0) {
                    return;
                }

                const dim3 blockSize(16, 16);
                // Ceil-divide so partially filled edge blocks are still launched.
                const dim3 gridSize((width + blockSize.x - 1) / blockSize.x,
                                    (height + blockSize.y - 1) / blockSize.y);

                img2gray_kernel<<<gridSize, blockSize>>>(
                    input.data_ptr<uint8_t>(),
                    output.data_ptr<uint8_t>(),
                    width,
                    height
                );

                // Launch-configuration errors are only visible through cudaGetLastError().
                const cudaError_t launchStatus = cudaGetLastError();
                TORCH_CHECK(launchStatus == cudaSuccess,
                            "img2gray: kernel launch failed: ", cudaGetErrorString(launchStatus));

                // Block the host until the kernel has finished, and surface any fault that
                // occurred while it was running instead of discarding the status.
                const cudaError_t syncStatus = cudaDeviceSynchronize();
                TORCH_CHECK(syncStatus == cudaSuccess,
                            "img2gray: kernel execution failed: ", cudaGetErrorString(syncStatus));
            }
         | 
    	
        flake.lock
    ADDED
    
    | @@ -0,0 +1,168 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "nodes": {
         | 
| 3 | 
            +
                "flake-compat": {
         | 
| 4 | 
            +
                  "locked": {
         | 
| 5 | 
            +
                    "lastModified": 1747046372,
         | 
| 6 | 
            +
                    "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
         | 
| 7 | 
            +
                    "owner": "edolstra",
         | 
| 8 | 
            +
                    "repo": "flake-compat",
         | 
| 9 | 
            +
                    "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
         | 
| 10 | 
            +
                    "type": "github"
         | 
| 11 | 
            +
                  },
         | 
| 12 | 
            +
                  "original": {
         | 
| 13 | 
            +
                    "owner": "edolstra",
         | 
| 14 | 
            +
                    "repo": "flake-compat",
         | 
| 15 | 
            +
                    "type": "github"
         | 
| 16 | 
            +
                  }
         | 
| 17 | 
            +
                },
         | 
| 18 | 
            +
                "flake-compat_2": {
         | 
| 19 | 
            +
                  "locked": {
         | 
| 20 | 
            +
                    "lastModified": 1733328505,
         | 
| 21 | 
            +
                    "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
         | 
| 22 | 
            +
                    "owner": "edolstra",
         | 
| 23 | 
            +
                    "repo": "flake-compat",
         | 
| 24 | 
            +
                    "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
         | 
| 25 | 
            +
                    "type": "github"
         | 
| 26 | 
            +
                  },
         | 
| 27 | 
            +
                  "original": {
         | 
| 28 | 
            +
                    "owner": "edolstra",
         | 
| 29 | 
            +
                    "repo": "flake-compat",
         | 
| 30 | 
            +
                    "type": "github"
         | 
| 31 | 
            +
                  }
         | 
| 32 | 
            +
                },
         | 
| 33 | 
            +
                "flake-utils": {
         | 
| 34 | 
            +
                  "inputs": {
         | 
| 35 | 
            +
                    "systems": "systems"
         | 
| 36 | 
            +
                  },
         | 
| 37 | 
            +
                  "locked": {
         | 
| 38 | 
            +
                    "lastModified": 1731533236,
         | 
| 39 | 
            +
                    "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
         | 
| 40 | 
            +
                    "owner": "numtide",
         | 
| 41 | 
            +
                    "repo": "flake-utils",
         | 
| 42 | 
            +
                    "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
         | 
| 43 | 
            +
                    "type": "github"
         | 
| 44 | 
            +
                  },
         | 
| 45 | 
            +
                  "original": {
         | 
| 46 | 
            +
                    "owner": "numtide",
         | 
| 47 | 
            +
                    "repo": "flake-utils",
         | 
| 48 | 
            +
                    "type": "github"
         | 
| 49 | 
            +
                  }
         | 
| 50 | 
            +
                },
         | 
| 51 | 
            +
                "flake-utils_2": {
         | 
| 52 | 
            +
                  "inputs": {
         | 
| 53 | 
            +
                    "systems": "systems_2"
         | 
| 54 | 
            +
                  },
         | 
| 55 | 
            +
                  "locked": {
         | 
| 56 | 
            +
                    "lastModified": 1731533236,
         | 
| 57 | 
            +
                    "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
         | 
| 58 | 
            +
                    "owner": "numtide",
         | 
| 59 | 
            +
                    "repo": "flake-utils",
         | 
| 60 | 
            +
                    "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
         | 
| 61 | 
            +
                    "type": "github"
         | 
| 62 | 
            +
                  },
         | 
| 63 | 
            +
                  "original": {
         | 
| 64 | 
            +
                    "owner": "numtide",
         | 
| 65 | 
            +
                    "repo": "flake-utils",
         | 
| 66 | 
            +
                    "type": "github"
         | 
| 67 | 
            +
                  }
         | 
| 68 | 
            +
                },
         | 
| 69 | 
            +
                "hf-nix": {
         | 
| 70 | 
            +
                  "inputs": {
         | 
| 71 | 
            +
                    "flake-compat": "flake-compat_2",
         | 
| 72 | 
            +
                    "flake-utils": "flake-utils_2",
         | 
| 73 | 
            +
                    "nixpkgs": "nixpkgs"
         | 
| 74 | 
            +
                  },
         | 
| 75 | 
            +
                  "locked": {
         | 
| 76 | 
            +
                    "lastModified": 1750234878,
         | 
| 77 | 
            +
                    "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
         | 
| 78 | 
            +
                    "owner": "huggingface",
         | 
| 79 | 
            +
                    "repo": "hf-nix",
         | 
| 80 | 
            +
                    "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
         | 
| 81 | 
            +
                    "type": "github"
         | 
| 82 | 
            +
                  },
         | 
| 83 | 
            +
                  "original": {
         | 
| 84 | 
            +
                    "owner": "huggingface",
         | 
| 85 | 
            +
                    "repo": "hf-nix",
         | 
| 86 | 
            +
                    "type": "github"
         | 
| 87 | 
            +
                  }
         | 
| 88 | 
            +
                },
         | 
| 89 | 
            +
                "kernel-builder": {
         | 
| 90 | 
            +
                  "inputs": {
         | 
| 91 | 
            +
                    "flake-compat": "flake-compat",
         | 
| 92 | 
            +
                    "flake-utils": "flake-utils",
         | 
| 93 | 
            +
                    "hf-nix": "hf-nix",
         | 
| 94 | 
            +
                    "nixpkgs": [
         | 
| 95 | 
            +
                      "kernel-builder",
         | 
| 96 | 
            +
                      "hf-nix",
         | 
| 97 | 
            +
                      "nixpkgs"
         | 
| 98 | 
            +
                    ]
         | 
| 99 | 
            +
                  },
         | 
| 100 | 
            +
                  "locked": {
         | 
| 101 | 
            +
                    "lastModified": 1750790603,
         | 
| 102 | 
            +
                    "narHash": "sha256-m7FoTYWDV811Y7FiuJPa/uCOV63rf6LHxWportuI9h0=",
         | 
| 103 | 
            +
                    "owner": "huggingface",
         | 
| 104 | 
            +
                    "repo": "kernel-builder",
         | 
| 105 | 
            +
                    "rev": "37cad313efea84e213b2fc13b2ec808d273a126d",
         | 
| 106 | 
            +
                    "type": "github"
         | 
| 107 | 
            +
                  },
         | 
| 108 | 
            +
                  "original": {
         | 
| 109 | 
            +
                    "owner": "huggingface",
         | 
| 110 | 
            +
                    "repo": "kernel-builder",
         | 
| 111 | 
            +
                    "type": "github"
         | 
| 112 | 
            +
                  }
         | 
| 113 | 
            +
                },
         | 
| 114 | 
            +
                "nixpkgs": {
         | 
| 115 | 
            +
                  "locked": {
         | 
| 116 | 
            +
                    "lastModified": 1747820358,
         | 
| 117 | 
            +
                    "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
         | 
| 118 | 
            +
                    "owner": "danieldk",
         | 
| 119 | 
            +
                    "repo": "nixpkgs",
         | 
| 120 | 
            +
                    "rev": "d3c1681180717528068082103bf323147de6ab0b",
         | 
| 121 | 
            +
                    "type": "github"
         | 
| 122 | 
            +
                  },
         | 
| 123 | 
            +
                  "original": {
         | 
| 124 | 
            +
                    "owner": "danieldk",
         | 
| 125 | 
            +
                    "ref": "cudatoolkit-12.9-kernel-builder",
         | 
| 126 | 
            +
                    "repo": "nixpkgs",
         | 
| 127 | 
            +
                    "type": "github"
         | 
| 128 | 
            +
                  }
         | 
| 129 | 
            +
                },
         | 
| 130 | 
            +
                "root": {
         | 
| 131 | 
            +
                  "inputs": {
         | 
| 132 | 
            +
                    "kernel-builder": "kernel-builder"
         | 
| 133 | 
            +
                  }
         | 
| 134 | 
            +
                },
         | 
| 135 | 
            +
                "systems": {
         | 
| 136 | 
            +
                  "locked": {
         | 
| 137 | 
            +
                    "lastModified": 1681028828,
         | 
| 138 | 
            +
                    "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
         | 
| 139 | 
            +
                    "owner": "nix-systems",
         | 
| 140 | 
            +
                    "repo": "default",
         | 
| 141 | 
            +
                    "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
         | 
| 142 | 
            +
                    "type": "github"
         | 
| 143 | 
            +
                  },
         | 
| 144 | 
            +
                  "original": {
         | 
| 145 | 
            +
                    "owner": "nix-systems",
         | 
| 146 | 
            +
                    "repo": "default",
         | 
| 147 | 
            +
                    "type": "github"
         | 
| 148 | 
            +
                  }
         | 
| 149 | 
            +
                },
         | 
| 150 | 
            +
                "systems_2": {
         | 
| 151 | 
            +
                  "locked": {
         | 
| 152 | 
            +
                    "lastModified": 1681028828,
         | 
| 153 | 
            +
                    "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
         | 
| 154 | 
            +
                    "owner": "nix-systems",
         | 
| 155 | 
            +
                    "repo": "default",
         | 
| 156 | 
            +
                    "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
         | 
| 157 | 
            +
                    "type": "github"
         | 
| 158 | 
            +
                  },
         | 
| 159 | 
            +
                  "original": {
         | 
| 160 | 
            +
                    "owner": "nix-systems",
         | 
| 161 | 
            +
                    "repo": "default",
         | 
| 162 | 
            +
                    "type": "github"
         | 
| 163 | 
            +
                  }
         | 
| 164 | 
            +
                }
         | 
| 165 | 
            +
              },
         | 
| 166 | 
            +
              "root": "root",
         | 
| 167 | 
            +
              "version": 7
         | 
| 168 | 
            +
            }
         | 
    	
        flake.nix
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              description = "Flake for img2gray kernel";
         | 
| 3 | 
            +
             | 
| 4 | 
            +
              inputs = {
         | 
| 5 | 
            +
                kernel-builder.url = "github:huggingface/kernel-builder";
         | 
| 6 | 
            +
              };
         | 
| 7 | 
            +
             | 
| 8 | 
            +
              outputs =
         | 
| 9 | 
            +
                {
         | 
| 10 | 
            +
                  self,
         | 
| 11 | 
            +
                  kernel-builder,
         | 
| 12 | 
            +
                }:
         | 
| 13 | 
            +
                kernel-builder.lib.genFlakeOutputs {
         | 
| 14 | 
            +
                  path = ./.;
         | 
| 15 | 
            +
                  rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
         | 
| 16 | 
            +
                };
         | 
| 17 | 
            +
            }
         | 
    	
        scripts/sanity.py
    ADDED
    
    | @@ -0,0 +1,23 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import img2gray
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from PIL import Image
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Smoke test: run the img2gray op on a sample image and save the result.
            print(dir(img2gray))

            img = Image.open("/home/ubuntu/Projects/img2gray/kernel-builder-logo-color.png").convert("RGB")
            img = np.array(img)
            img_tensor = torch.from_numpy(img)
            print(img_tensor.shape)  # HWC
            img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).contiguous().cuda()  # BCHW
            print(img_tensor.shape)  # BCHW

            # The op returns B1HW. Index the single batch/channel entry explicitly; a bare
            # squeeze() would also drop H or W whenever the image is one pixel tall or wide.
            gray_tensor = img2gray.img2gray(img_tensor)[0, 0]
            print(gray_tensor.shape)  # HW

            # save the output image
            gray_img = gray_tensor.cpu().numpy()  # HW array on the host
            gray_img = Image.fromarray(gray_img.astype(np.uint8), mode="L")

            gray_img.save("/home/ubuntu/Projects/img2gray/kernel-builder-logo-gray.png")
         | 
    	
        torch-ext/img2gray/__init__.py
    ADDED
    
    | @@ -0,0 +1,18 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from ._ops import ops
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def img2gray(input: torch.Tensor) -> torch.Tensor:
                """Convert a batch of RGB images to grayscale, one image at a time.

                Args:
                    input: Tensor laid out as (batch, 3, height, width).

                Returns:
                    A newly allocated tensor of shape (batch, 1, height, width) with the
                    same device and dtype as ``input``.

                Raises:
                    AssertionError: if the channel dimension is not 3.
                    ValueError: if ``input`` does not have exactly four dimensions
                        (raised while unpacking its shape).
                """
                n_images, n_channels, rows, cols = input.shape
                assert n_channels == 3, "Input image must have 3 channels (RGB)"

                gray = torch.empty(
                    (n_images, 1, rows, cols),
                    device=input.device,
                    dtype=input.dtype,
                )

                for index in range(n_images):
                    # Channels-last (HWC) contiguous copy of the current image.
                    hwc_image = input[index].permute(1, 2, 0).contiguous()
                    # (rows, cols) view of this image's slot in the result tensor.
                    gray_plane = gray[index].reshape(rows, cols)
                    ops.img2gray(hwc_image, gray_plane)

                return gray
         | 
    	
        torch-ext/torch_binding.cpp
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #include <torch/library.h>
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            #include "registration.h"
         | 
| 4 | 
            +
            #include "torch_binding.h"
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            // Declares the img2gray operator schema and binds its CUDA implementation.
            TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
                // `Tensor!` marks `output` as written in place. The op returns nothing and
                // delivers its result through that argument, so the schema must say so for
                // PyTorch to treat the call as a mutation.
                ops.def("img2gray(Tensor input, Tensor! output) -> ()");
                ops.impl("img2gray", torch::kCUDA, &img2gray_cuda);
            }
         | 
| 11 | 
            +
              
         | 
| 12 | 
            +
            REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
         | 
    	
        torch-ext/torch_binding.h
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #include <torch/torch.h>
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            void img2gray_cuda(torch::Tensor input, torch::Tensor output);
         | 
