import gc
from typing import Tuple
import copy
import torch
import tqdm


def cleanup_memory():
    """Run Python garbage collection and release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize a tensor to FP8 using a per-tensor scale computed from its absmax.

    Args:
        tensor: The input tensor.

    Returns:
        A tuple of (quantized tensor in torch.float8_e4m3fn, inverse scale as a
        float32 scalar tensor).
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    if tensor.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts)
        min_val, max_val = (
            torch.tensor(-16.0, dtype=tensor.dtype),
            torch.tensor(16.0, dtype=tensor.dtype),
        )
    else:
        min_val, max_val = tensor.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it into
    # the representable range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale
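
# Illustrative sketch (not part of the module's API): dynamic quantization of a
# bf16 weight, followed by an approximate reconstruction from the FP8 values and
# the returned inverse scale. The tensor shape and the CUDA device are assumptions.
#
#   w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
#   qw, inv_scale = per_tensor_quantize(w)
#   w_approx = qw.to(torch.bfloat16) * inv_scale  # dequantized approximation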


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
    """Quantizes a floating-point tensor to FP8 (E4M3 format) using static scaling.
    
    Performs uniform quantization of the input tensor by:
    1. Scaling the tensor values using the provided inverse scale factor
    2. Clamping values to the representable range of FP8 E4M3 format
    3. Converting to FP8 data type
    
    Args:
        tensor (torch.Tensor): Input tensor to be quantized (any floating-point dtype)
        inv_scale (float): Inverse of the quantization scale factor (1/scale)
                         (Must be pre-calculated based on tensor statistics)
    
    Returns:
        torch.Tensor: Quantized tensor in torch.float8_e4m3fn format
    
    Note:
        - Uses the E4M3FN format (4 exponent bits, 3 mantissa bits; no infinities, NaN representable)
        - This is static quantization (the scale factor must be pre-determined)
        - For dynamic quantization, see per_tensor_quantize()
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)
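
# Illustrative sketch: static quantization reuses a pre-computed inverse scale,
# e.g. one obtained from per_tensor_quantize on a calibration tensor. The names
# calib_tensor and new_tensor below are hypothetical.
#
#   _, inv_scale = per_tensor_quantize(calib_tensor)
#   q = static_per_tensor_quantize(new_tensor, inv_scale)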


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
    """Performs FP8 GEMM (General Matrix Multiplication) operation with optional native hardware support.
    Args:
        A (torch.Tensor): Input tensor A (activations; FP8 or another dtype)
        A_scale (torch.Tensor/float): Scale factor for tensor A
        B (torch.Tensor): Weight tensor in (out_features, in_features) layout, as in nn.Linear (FP8 or another dtype)
        B_scale (torch.Tensor/float): Scale factor for tensor B
        bias (torch.Tensor/None): Optional bias tensor
        out_dtype (torch.dtype): Output data type
        native_fp8_support (bool): Whether to use hardware-accelerated FP8 operations (torch._scaled_mm)

    Returns:
        torch.Tensor: Result of the GEMM operation
    """
    if A.numel() == 0:
        # Deal with empty tensors (triggered by empty MoE experts)
        return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

    if native_fp8_support:
        need_reshape = A.dim() == 3
        if need_reshape:
            batch_size = A.shape[0]
            A_input = A.reshape(-1, A.shape[-1])
        else:
            batch_size = None
            A_input = A
        output = torch._scaled_mm(
            A_input,
            B.t(),
            out_dtype=out_dtype,
            scale_a=A_scale,
            scale_b=B_scale,
            bias=bias,
        )
        if need_reshape:
            output = output.reshape(
                batch_size, output.shape[0] // batch_size, output.shape[1]
            )
    else:
        output = torch.nn.functional.linear(
            A.to(out_dtype) * A_scale,
            B.to(out_dtype) * B_scale.to(out_dtype),
            bias=bias,
        )

    return output
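
# Illustrative sketch: quantize both operands, then run the GEMM. The shapes, the
# names x and weight, and native_fp8_support=False are assumptions for the example.
#
#   x_q, x_scale = per_tensor_quantize(x)          # x: (batch, in_features)
#   w_q, w_scale = per_tensor_quantize(weight)     # weight: (out_features, in_features)
#   y = fp8_gemm(x_q, x_scale, w_q, w_scale, bias=None,
#                out_dtype=torch.bfloat16, native_fp8_support=False)
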

def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    """Replace the submodule identified by its dotted name with new_module."""
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1:]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name
    setattr(parent, child_name, new_module)


# Linear layer that stores pre-quantized FP8 weights and dynamically quantizes
# activations to FP8 at inference time.
class FP8DynamicLinear(torch.nn.Module):
    def __init__(
            self,
            weight: torch.Tensor,
            weight_scale: torch.Tensor,
            bias: torch.nn.Parameter,
            native_fp8_support: bool = False,
            dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.bias = bias
        self.native_fp8_support = native_fp8_support
        self.dtype = dtype

    # @torch.compile
    def forward(self, x):
        if x.dtype != self.dtype:
            x = x.to(self.dtype)
        qinput, x_scale = per_tensor_quantize(x)
        output = fp8_gemm(
            A=qinput,
            A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=self.bias,
            out_dtype=x.dtype,
            native_fp8_support=self.native_fp8_support,
        )
        return output


def FluxFp8GeMMProcessor(model: torch.nn.Module):
    """Processes a PyTorch model to convert eligible Linear layers to FP8 precision.
    
    This function performs the following operations:
    1. Checks for native FP8 support on the current GPU
    2. Identifies target Linear layers in transformer blocks
    3. Quantizes weights to FP8 format
    4. Replaces original Linear layers with FP8DynamicLinear versions
    5. Performs memory cleanup
    
    Args:
        model (torch.nn.Module): The neural network model to be processed.
                                Should contain transformer blocks with Linear layers.
    """
    # Use torch._scaled_mm only on GPUs with compute capability >= 9.0 (Hopper);
    # other devices fall back to the dequantize-then-matmul path in fp8_gemm.
    native_fp8_support = (
            torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    )
    named_modules = list(model.named_modules())
    for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights to fp8"):
        if isinstance(linear, torch.nn.Linear) and "blocks" in name:
            quant_weight, weight_scale = per_tensor_quantize(linear.weight)
            bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
            quant_linear = FP8DynamicLinear(
                weight=quant_weight,
                weight_scale=weight_scale,
                bias=bias,
                native_fp8_support=native_fp8_support,
                dtype=linear.weight.dtype
            )
            replace_module(model, name, quant_linear)
            del linear.weight
            del linear.bias
            del linear
    cleanup_memory()
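

if __name__ == "__main__":
    # Illustrative sketch, not part of the module: apply FluxFp8GeMMProcessor to a
    # small toy model whose Linear layers live under a "blocks" prefix, mirroring
    # the naming convention the processor filters on. Assumes a CUDA device and a
    # PyTorch build with float8_e4m3fn support; the toy classes are hypothetical.
    class _ToyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(64, 64, dtype=torch.bfloat16)

        def forward(self, x):
            return self.proj(x)

    class _ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = torch.nn.ModuleList([_ToyBlock() for _ in range(2)])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    toy = _ToyModel().cuda()
    FluxFp8GeMMProcessor(toy)  # replaces blocks.*.proj with FP8DynamicLinear
    sample = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")
    with torch.no_grad():
        out = toy(sample)
    print(out.shape, out.dtype)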