# file: ComfyUI/custom_nodes/HybridFP8Loader/hybrid_fp8_ops.py

import torch
import comfy.ops
import comfy.model_management

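# torch.float8_e4m3fn (4 exponent bits, 3 mantissa bits, finite-only) is the
# FP8 variant used here for weight storage.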
TARGET_FP8_DTYPE = torch.float8_e4m3fn
_high_precision_keynames = []

def set_high_precision_keynames(keynames):
    """Sets the list of substrings used to identify high-precision layers."""
    global _high_precision_keynames
    _high_precision_keynames = keynames
    print(f"[Hybrid FP8 Ops] High precision keynames set: {keynames}")

def get_hybrid_fp8_ops(scale_input_enabled=False):
    """Dynamically creates and returns a hybrid operations class.

    The 'scale_input_enabled' flag is passed in from the loader.
    """
    print(f"[Hybrid FP8 Ops] Configuring with scale_input_enabled: {scale_input_enabled}")

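    # True FP8 matrix multiplies are only enabled on hardware that supports
    # them (roughly Ada/Hopper-class GPUs); elsewhere the scaled-FP8 ops
    # upcast weights before the matmul.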
    fp8_mat_mult_supported = comfy.model_management.supports_fp8_compute()

    base_ops_class = comfy.ops.scaled_fp8_ops(
        fp8_matrix_mult=fp8_mat_mult_supported,
        scale_input=scale_input_enabled,
        override_dtype=TARGET_FP8_DTYPE
    )

    class HybridScaledFP8Linear(base_ops_class.Linear):
        """A Linear layer that loads either scaled-FP8 or high-precision weights,
        depending on which key prefixes match the configured keynames.
        """
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
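            # PyTorch calls this hook once per submodule during load_state_dict();
            # popping keys here keeps the scaled-FP8 base loader from consuming them.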
            is_excluded = any(name in prefix for name in _high_precision_keynames)

            if is_excluded:
                print(f"[Hybrid FP8 Ops] Intercepting high-precision layer: {prefix}")

                weight_key = prefix + 'weight'
                bias_key = prefix + 'bias'

                weight_tensor = state_dict.pop(weight_key, None)
                bias_tensor = state_dict.pop(bias_key, None)

                if weight_tensor is None:
                    missing_keys.append(weight_key)
                else:
                    self.weight = torch.nn.Parameter(weight_tensor, requires_grad=False)

                if bias_tensor is not None:
                    self.bias = torch.nn.Parameter(bias_tensor, requires_grad=False)
                else:
                    self.bias = None

                # Discard this layer's FP8 scale tensors; they have no meaning
                # for a weight kept in high precision.
                state_dict.pop(prefix + 'scale_weight', None)
                state_dict.pop(prefix + 'scale_input', None)

                self.scale_weight = None
                self.scale_input = None

                self.is_high_precision_layer = True
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def forward(self, input):
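            # Excluded layers bypass the FP8 path and run a plain linear in the
            # activation's dtype; everything else uses the scaled-FP8 forward.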
            if getattr(self, 'is_high_precision_layer', False):
                weight_hp = self.weight.to(input.device, input.dtype)
                bias_hp = self.bias.to(input.device, input.dtype) if self.bias is not None else None
                return torch.nn.functional.linear(input, weight_hp, bias_hp)
            else:
                return super().forward(input)

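    # Re-export the hybrid Linear under the attribute name ComfyUI resolves on
    # the ops class; other layer types fall through to the scaled-FP8 base.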
    class HybridOps(base_ops_class):
        class Linear(HybridScaledFP8Linear):
            pass

    return HybridOps
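
# --- Usage sketch (illustrative) ---
# A minimal sketch of how a loader node might wire these helpers together,
# assuming ComfyUI's "custom_operations" hook in model_options. The keyname
# substrings and checkpoint path below are hypothetical.
#
# import comfy.sd
# from .hybrid_fp8_ops import set_high_precision_keynames, get_hybrid_fp8_ops
#
# set_high_precision_keynames(["norm", "proj_out"])  # hypothetical substrings
# hybrid_ops = get_hybrid_fp8_ops(scale_input_enabled=True)
# model_patcher = comfy.sd.load_diffusion_model(
#     "models/diffusion_models/model_fp8.safetensors",  # hypothetical path
#     model_options={"custom_operations": hybrid_ops},
# )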