# NOTE(review): removed non-code paste residue ("Spaces:" / "Runtime error"
# banners) that preceded the module and was never part of the source.
| import torch | |
| from torch import nn | |
| import cupy as cp | |
| import math | |
| from os import path | |
class Quantizer(nn.Module):
    """Product (sub-vector) quantizer whose hot paths run as hand-written CUDA
    kernels compiled at construction time via CuPy/NVRTC.

    ``codebook`` has shape ``(nsq, nc, d)``: ``nsq`` subquantizers, ``nc``
    centroids each, ``d`` dimensions per subvector.  Kernel sources live as
    ``.cu`` templates next to this file; placeholder tokens such as
    ``__NSQ__`` are substituted textually before compilation, so every model
    constant is baked into the compiled kernel.
    """

    def __init__(self, config, codebook):
        """Compile all kernels for the given model ``config`` and codebook.

        Args:
            config: HF-style model config; must provide ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads`` and
                ``rope_theta``.  ``window_length`` is optional (default 32).
            codebook: centroid tensor of shape ``(nsq, nc, d)``.
        """
        super().__init__()
        self.nsq, nc, self.d = codebook.shape
        # Bits per code; assumes nc is an exact power of two.
        self.b = int(math.log2(nc))
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        # Queries per key/value head (GQA/MQA replication factor).
        qpk = config.num_attention_heads // config.num_key_value_heads
        self.window_length = getattr(config, 'window_length', 32)
        self.register_buffer('codebook', codebook)
        # Dynamic shared memory needed to stage one sub-codebook:
        # (2**b centroids) * d elements * 2 bytes.  NOTE(review): the factor
        # of 2 suggests the codebook is fp16 on device -- confirm upstream.
        self._smem = (2 ** self.b) * self.d * 2

        # Placeholder substitution sets, from least to most specialized.
        base = {'__NSQ__': str(self.nsq), '__B__': str(self.b), '__D__': str(self.d)}
        rope = {**base, '__HEAD_DIM__': str(head_dim)}
        mqa = {**rope, '__QPK__': str(qpk)}

        self._quantize = self._load_kernel('quantize', base)
        self._dequantize = self._load_kernel('dequantize', base)
        self._dequantize_rope = self._load_kernel('dequantize_rope', rope)
        self._fused_rope_mult = self._load_kernel('fused_rope_mult', rope)
        self._fused_rope_pos_mult = self._load_kernel(
            'fused_rope_pos_mult',
            {**mqa, '__ROPE_THETA__': str(config.rope_theta)},
            filename='fused_rope_pos_mult_mqa',
        )
        self._fused_mult = self._load_kernel('fused_mult', mqa, filename='fused_mult_len')

    @staticmethod
    def _load_kernel(name, substitutions, filename=None):
        """Read ``<filename>.cu`` from this file's directory, apply the given
        ``{placeholder: value}`` substitutions, and compile it with NVRTC.

        ``filename`` defaults to ``name`` (the kernel's entry-point symbol);
        it exists because two kernels live in files named differently from
        their entry points.
        """
        source_path = path.join(path.dirname(__file__), (filename or name) + ".cu")
        with open(source_path, "r") as f:
            kernel_code = f.read()
        for placeholder, value in substitutions.items():
            kernel_code = kernel_code.replace(placeholder, value)
        return cp.RawKernel(kernel_code, name, backend="nvrtc")

    def quantize(self, x):
        """Encode ``x`` into per-subquantizer uint8 codes.

        Args:
            x: tensor whose last dimension is the vector dimension; all
                leading dimensions are flattened into ``n`` vectors.
        Returns:
            uint8 tensor of shape ``(nsq, n)``.
        """
        n = x.numel() // x.shape[-1]
        codes = torch.empty(self.nsq, n, dtype=torch.uint8, device=x.device)
        # One thread block per subquantizer; each stages its codebook slice
        # in shared memory.
        self._quantize(
            grid=(self.nsq,),
            block=(1024,),
            shared_mem=self._smem,
            args=[self.codebook.data_ptr(), x.data_ptr(), codes.data_ptr(), n],
        )
        return codes

    def dequantize(self, codes):
        """Decode ``codes`` (shape ``(nsq, ...)``) back to fp16 vectors.

        Returns:
            float16 tensor of shape ``(n, nsq * d)`` where ``n`` is the
            number of encoded vectors per subquantizer.
        """
        n = codes.numel() // codes.shape[0]
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        self._dequantize(
            grid=(self.nsq,),
            block=(1024,),
            shared_mem=self._smem,
            args=[self.codebook.data_ptr(), codes.data_ptr(), x.data_ptr(), n],
        )
        return x

    def dequantize_rope(self, codes):
        """Decode codes and apply rotary embedding inside the kernel.

        Args:
            codes: tensor unpacked as ``(_, batch_size, seq_len)`` -- first
                dim is presumably ``nsq`` (TODO confirm against quantize()).
        Returns:
            float16 tensor of shape ``(batch_size * seq_len, nsq * d)``.
        """
        _, batch_size, seq_len = codes.shape
        n = batch_size * seq_len
        x = torch.zeros(n, self.nsq * self.d, dtype=torch.float16, device=codes.device)
        self._dequantize_rope(
            grid=(self.nsq,),
            block=(1024,),
            shared_mem=self._smem,
            args=[self.codebook.data_ptr(), codes.data_ptr(), x.data_ptr(),
                  batch_size, seq_len],
        )
        return x

    def fused_rope_mult(self, codes, queries):
        """Fused dequantize + RoPE + query-key matmul.

        Args:
            codes: key codes, unpacked as ``(_, batch_size, k_len)``.
            queries: tensor of shape ``(_, n_heads, q_len, head_dim)``.
        Returns:
            float16 attention scores of shape
            ``(batch_size, n_heads, q_len, k_len)``.
        """
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        out = torch.zeros(batch_size, n_heads, q_len, k_len,
                          dtype=torch.float16, device=codes.device)
        self._fused_rope_mult(
            grid=(self.nsq,),
            block=(1024,),
            shared_mem=self._smem,
            args=[self.codebook.data_ptr(), codes.data_ptr(), queries.data_ptr(),
                  out.data_ptr(), batch_size, q_len, k_len],
        )
        return out

    def fused_rope_pos_mult(self, codes, queries, position_ids):
        """Fused dequantize + position-aware RoPE + query-key matmul (MQA).

        Like :meth:`fused_rope_mult` but the kernel also receives per-batch
        position offsets derived from ``position_ids`` so rotation matches
        absolute positions.  Output is accumulated in float32 here (unlike
        the fp16 output of ``fused_rope_mult``) -- preserved as-is.
        """
        _, batch_size, k_len = codes.shape
        _, n_heads, q_len, _ = queries.shape
        # Offset of the first cached key position for each batch element.
        position_offsets = position_ids[:, -1] - k_len + 1
        out = torch.zeros(batch_size, n_heads, q_len, k_len,
                          dtype=torch.float32, device=codes.device)
        self._fused_rope_pos_mult(
            grid=(self.nsq,),
            block=(1024,),
            shared_mem=self._smem,
            args=[self.codebook.data_ptr(), codes.data_ptr(),
                  position_offsets.data_ptr(), queries.data_ptr(),
                  out.data_ptr(), batch_size, q_len, k_len],
        )
        return out

    def fused_mult(self, codes, weights, skip_last=0):
        """Fused dequantize + attention-weight x value matmul.

        Args:
            codes: value codes for the cached sequence.
            weights: attention weights of shape
                ``(batch_size, n_heads, q_len, k_len)``.
            skip_last: number of trailing key positions to exclude; the
                kernel receives ``k_len - skip_last`` as its effective length.
        Returns:
            float16 tensor of shape ``(batch_size, n_heads, q_len, head_dim)``.
        """
        batch_size, n_heads, q_len, k_len = weights.shape
        out = torch.zeros(batch_size, n_heads, q_len, self.head_dim,
                          dtype=torch.float16, device=codes.device)
        # NOTE(review): unlike the other launches this caps the block size at
        # batch_size, so the kernel presumably maps threads to batch elements
        # -- confirm against fused_mult_len.cu before changing.
        self._fused_mult(
            grid=(self.nsq,),
            block=(min(1024, batch_size),),
            shared_mem=self._smem,
            args=[self.codebook.data_ptr(), codes.data_ptr(), weights.data_ptr(),
                  out.data_ptr(), batch_size, q_len, k_len, k_len - skip_last],
        )
        return out