import copy
import gc
from typing import Tuple

import torch
import tqdm


def cleanup_memory():
    """Release unreferenced Python objects and empty the CUDA caching allocator."""
    gc.collect()
    torch.cuda.empty_cache()


def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Quantize a tensor to FP8 (E4M3) using a per-tensor scaling factor.

    The scale is computed from the tensor itself (dynamic quantization).

    Args:
        tensor: The input tensor.

    Returns:
        The FP8 tensor and the inverse scale required by torch._scaled_mm.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    if tensor.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts).
        min_val, max_val = (
            torch.tensor(-16.0, dtype=tensor.dtype),
            torch.tensor(16.0, dtype=tensor.dtype),
        )
    else:
        min_val, max_val = tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it into the representable range
    # of the float8 data type (the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # since both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale
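
# Usage sketch (illustrative, not part of the original API): dynamically
# quantize a bf16 tensor, then dequantize it with the returned inverse scale.
#
#   w = torch.randn(4096, 4096, dtype=torch.bfloat16)
#   qw, inv_scale = per_tensor_quantize(w)          # qw is torch.float8_e4m3fn
#   w_approx = qw.to(torch.bfloat16) * inv_scale    # rough round-trip check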


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
    """Quantize a floating-point tensor to FP8 (E4M3 format) using a static scale.

    Performs uniform quantization of the input tensor by:
    1. Dividing the tensor values by the provided inverse scale factor
    2. Clamping the values to the representable range of the FP8 E4M3 format
    3. Converting to the FP8 data type

    Args:
        tensor (torch.Tensor): Input tensor to be quantized (any floating-point dtype)
        inv_scale (float): Inverse of the quantization scale factor (1/scale),
            which must be pre-computed from tensor statistics

    Returns:
        torch.Tensor: Quantized tensor in torch.float8_e4m3fn format

    Note:
        - Uses the E4M3 format (4 exponent bits, 3 mantissa bits, no infinities)
        - This is static quantization (the scale factor is pre-determined)
        - For dynamic quantization, see per_tensor_quantize()
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)
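
# Usage sketch (illustrative): reuse a pre-computed inverse scale, e.g. one
# returned by per_tensor_quantize() on a calibration tensor; `calib` and
# `activations` below are hypothetical placeholders.
#
#   _, inv_scale = per_tensor_quantize(calib)
#   q_act = static_per_tensor_quantize(activations, inv_scale)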


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
    """Perform an FP8 GEMM (general matrix multiplication) with optional native hardware support.

    Args:
        A (torch.Tensor): Input tensor A (FP8 or other dtype)
        A_scale (torch.Tensor or float): Scale factor for tensor A
        B (torch.Tensor): Input tensor B (FP8 or other dtype)
        B_scale (torch.Tensor or float): Scale factor for tensor B
        bias (torch.Tensor or None): Optional bias tensor
        out_dtype (torch.dtype): Output data type
        native_fp8_support (bool): Whether to use hardware-accelerated FP8 operations

    Returns:
        torch.Tensor: Result of the GEMM operation
    """
    if A.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts).
        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

    if native_fp8_support:
        need_reshape = A.dim() == 3
        if need_reshape:
            batch_size = A.shape[0]
            A_input = A.reshape(-1, A.shape[-1])
        else:
            batch_size = None
            A_input = A
        output = torch._scaled_mm(
            A_input,
            B.t(),
            out_dtype=out_dtype,
            scale_a=A_scale,
            scale_b=B_scale,
            bias=bias,
        )
        if need_reshape:
            output = output.reshape(
                batch_size, output.shape[0] // batch_size, output.shape[1]
            )
    else:
        # Software fallback: dequantize both operands and use a regular linear.
        output = torch.nn.functional.linear(
            A.to(out_dtype) * A_scale,
            B.to(out_dtype) * B_scale.to(out_dtype),
            bias=bias,
        )
    return output
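
# Usage sketch (illustrative): quantize both operands, then multiply with the
# software fallback path. Pass native_fp8_support=True only on GPUs where
# torch._scaled_mm supports FP8 inputs.
#
#   qa, a_scale = per_tensor_quantize(torch.randn(4, 64, dtype=torch.bfloat16))
#   qb, b_scale = per_tensor_quantize(torch.randn(128, 64, dtype=torch.bfloat16))
#   out = fp8_gemm(qa, a_scale, qb, b_scale, bias=None, out_dtype=torch.bfloat16)
#   # out has shape (4, 128): fp8_gemm computes A @ B.t(), i.e. linear-layer semantics.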


def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    """Replace the submodule identified by its dotted `name` with `new_module`."""
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1:]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name
    setattr(parent, child_name, new_module)
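
# Example (illustrative; the module name below is hypothetical):
#
#   replace_module(model, "blocks.0.proj", torch.nn.Identity())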


class FP8DynamicLinear(torch.nn.Module):
    """Linear layer that stores a pre-quantized FP8 weight and dynamically quantizes its inputs."""

    def __init__(
        self,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        bias: torch.nn.Parameter,
        native_fp8_support: bool = False,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.bias = bias
        self.native_fp8_support = native_fp8_support
        self.dtype = dtype

    # @torch.compile
    def forward(self, x):
        if x.dtype != self.dtype:
            x = x.to(self.dtype)
        # Dynamically quantize the activations, then run the FP8 GEMM
        # against the pre-quantized weight.
        qinput, x_scale = per_tensor_quantize(x)
        output = fp8_gemm(
            A=qinput,
            A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=self.bias,
            out_dtype=x.dtype,
            native_fp8_support=self.native_fp8_support,
        )
        return output
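
# Usage sketch (illustrative): wrap an existing Linear's weight in FP8 form.
#
#   linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
#   qw, w_scale = per_tensor_quantize(linear.weight)
#   fp8_linear = FP8DynamicLinear(qw, w_scale, linear.bias, dtype=torch.bfloat16)
#   y = fp8_linear(torch.randn(4, 64, dtype=torch.bfloat16))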


def FluxFp8GeMMProcessor(model: torch.nn.Module):
    """Process a PyTorch model, converting eligible Linear layers to FP8 precision.

    This function performs the following operations:
    1. Checks for native FP8 support on the current GPU
    2. Identifies target Linear layers in transformer blocks
    3. Quantizes their weights to FP8 format
    4. Replaces the original Linear layers with FP8DynamicLinear versions
    5. Performs memory cleanup

    Args:
        model (torch.nn.Module): The neural network model to be processed.
            Should contain transformer blocks with Linear layers.
    """
    native_fp8_support = (
        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    )
    named_modules = list(model.named_modules())
    for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights to fp8"):
        if isinstance(linear, torch.nn.Linear) and "blocks" in name:
            quant_weight, weight_scale = per_tensor_quantize(linear.weight)
            bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
            quant_linear = FP8DynamicLinear(
                weight=quant_weight,
                weight_scale=weight_scale,
                bias=bias,
                native_fp8_support=native_fp8_support,
                dtype=linear.weight.dtype,
            )
            replace_module(model, name, quant_linear)
            del linear.weight
            del linear.bias
            del linear
    cleanup_memory()
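

if __name__ == "__main__":
    # Minimal self-test sketch (an illustrative addition, not part of the
    # original module): the toy model below is hypothetical and only mimics
    # the "*.blocks.*" naming that FluxFp8GeMMProcessor looks for; real usage
    # passes a Flux transformer instead.
    class _ToyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(64, 64, dtype=torch.bfloat16)

        def forward(self, x):
            return self.proj(x)

    class _ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = torch.nn.ModuleList([_ToyBlock() for _ in range(2)])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _ToyModel().to(device)
    FluxFp8GeMMProcessor(model)

    with torch.no_grad():
        x = torch.randn(4, 64, dtype=torch.bfloat16, device=device)
        print(model(x).shape)  # expected: torch.Size([4, 64])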