Spaces:
Runtime error
Runtime error
| #include <cuda_fp16.h> | |
// 8-bit unsigned type used for PQ codes. Declared locally instead of via
// <cstdint>; presumably because the runtime compiler used for this source
// does not provide that header by default — confirm.
using uint8_t = unsigned char;
| // extern "C" | |
| // __global__ void quantize( | |
| // const half* __restrict__ codebook, // nsq x 2^b x d | |
| // const half* __restrict__ vectors, // n x (nsq * d) | |
| // uint8_t* __restrict__ codes, // nsq x n | |
| // int n | |
| // ) { | |
| // extern __shared__ volatile half centroids[]; // 2^b x d | |
| // const int sq_id = blockIdx.x; | |
| // const int thread_id = threadIdx.x; | |
| // const int n_threads = blockDim.x; | |
| // const int n_floats_per_sq = (1 << __B__) * __D__; | |
| // #pragma unroll | |
| // for (int i = thread_id; i < n_floats_per_sq; i += n_threads) { | |
| // centroids[i] = codebook[sq_id * n_floats_per_sq + i]; | |
| // } | |
| // __syncthreads(); | |
| // half subvector[__D__]; | |
| // for (int i = thread_id; i < n; i += n_threads) { | |
| // #pragma unroll | |
| // for (int j = 0; j < __D__; ++j) { | |
| // subvector[j] = vectors[(i * __NSQ__ + sq_id) * __D__ + j]; | |
| // } | |
| // float min_dist = 1 << 16; | |
| // uint8_t min_idx; | |
| // #pragma unroll | |
| // for (int j = 0; j < (1 << __B__); ++j) { | |
| // float dist = 0; | |
| // #pragma unroll | |
| // for (int k = 0; k < __D__; ++k) { | |
| // float tmp = __half2float(subvector[k]) - __half2float(centroids[j * __D__ + k]); | |
| // dist += tmp * tmp; | |
| // } | |
| // min_dist = (dist <= min_dist) ? dist : min_dist; | |
| // min_idx = (dist == min_dist) ? j : min_idx; | |
| // } | |
| // // printf("%d %d %d %d\n", sq_id, n, i, min_idx); | |
| // codes[sq_id * n + i] = min_idx; | |
| // } | |
| // } | |
// dequantize: expand product-quantization codes back into half vectors.
//
// Template placeholders (not defined in this file; presumably substituted
// textually by the host before compilation — confirm):
//   __B__   bits per code; each sub-quantizer has (1 << __B__) centroids.
//           Codes are stored as uint8_t, so only centroids 0..255 are
//           addressable.
//   __D__   dimensionality of one sub-vector.
//   __NSQ__ number of sub-quantizers.
//
// Layouts:
//   codebook: nsq x 2^b x d, codes: nsq x n, vectors: n x (nsq * d).
//
// Launch contract:
//   - 1-D grid with one block per sub-quantizer (blockIdx.x selects it).
//     Blocks with blockIdx.x >= __NSQ__ exit immediately.
//   - Any 1-D blockDim.x; each thread strides over centroids and vectors.
//   - Dynamic shared memory of (1 << __B__) * __D__ * sizeof(half) bytes.
//
// Precondition (not checked): every code is < (1 << __B__). Otherwise the
// shared-memory centroid read goes out of bounds.
//
// NOTE(review): consecutive threads write vectors with a stride of
// __NSQ__ * __D__ halves, so global stores are not coalesced.
extern "C"
__global__ void dequantize(
    const half* __restrict__ codebook,  // nsq x 2^b x d
    const uint8_t* __restrict__ codes,  // nsq x n
    half* __restrict__ vectors,         // n x (nsq x d)
    int n
) {
    // Centroids of this block's sub-quantizer, staged in shared memory.
    // `volatile` is unnecessary: __syncthreads() below orders the writes
    // before any reads, and no warp-synchronous tricks are used.
    extern __shared__ half centroids[];  // 2^b x d

    const int sq_id = blockIdx.x;
    // All threads of a block share blockIdx.x, so this exit is block-uniform
    // and safe before __syncthreads(). It prevents an oversized grid from
    // reading past the end of codebook.
    if (sq_id >= __NSQ__) {
        return;
    }

    const int thread_id = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_floats_per_sq = (1 << __B__) * __D__;

    // Cooperative copy of this sub-quantizer's codebook slice.
    const half* sq_codebook = codebook + (long long)sq_id * n_floats_per_sq;
    for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
        centroids[i] = sq_codebook[i];
    }
    __syncthreads();

    // 64-bit offsets: sq_id * n and i * __NSQ__ * __D__ can exceed INT_MAX
    // for large n.
    const uint8_t* sq_codes = codes + (long long)sq_id * n;
    for (int i = thread_id; i < n; i += n_threads) {
        const half* centroid = centroids + __D__ * (int)sq_codes[i];
        half* out = vectors + ((long long)i * __NSQ__ + sq_id) * __D__;
        #pragma unroll
        for (int dim = 0; dim < __D__; ++dim) {
            out[dim] = centroid[dim];
        }
    }
}