Spaces:
Running
on
Zero
Running
on
Zero
| # file: ComfyUI/custom_nodes/HybridFP8Loader/hybrid_fp8_ops.py | |
| import torch | |
| import comfy.ops | |
| import comfy.model_management | |
# FP8 dtype handed to comfy.ops.scaled_fp8_ops as ``override_dtype``.
TARGET_FP8_DTYPE = torch.float8_e4m3fn
# Substrings matched against each layer's state-dict prefix to decide which
# Linear layers are loaded as high precision. Module-level state, replaced
# wholesale by set_high_precision_keynames().
_high_precision_keynames = []
def set_high_precision_keynames(keynames):
    """Set the substrings used to identify high-precision layers.

    A layer is treated as high precision when any stored substring occurs in
    its state-dict prefix, so the stored value is normalized to a list of
    non-empty entries.

    Args:
        keynames: An iterable of substrings, a single substring, or ``None``.
            ``None`` clears the list and a bare string is treated as one key
            name (instead of being iterated character by character). Empty
            entries are dropped because ``"" in prefix`` is always true and
            would mark every layer as high precision.
    """
    global _high_precision_keynames
    if keynames is None:
        names = []
    elif isinstance(keynames, str):
        names = [keynames]
    else:
        names = list(keynames)
    # Store a filtered copy so later mutation of the caller's list cannot
    # silently change which layers are excluded.
    _high_precision_keynames = [name for name in names if name]
    print(f"[Hybrid FP8 Ops] High precision keynames set: {_high_precision_keynames}")
def get_hybrid_fp8_ops(scale_input_enabled=False):
    """Build and return a hybrid operations class.

    The returned class derives from the class produced by
    ``comfy.ops.scaled_fp8_ops(...)`` and replaces its ``Linear`` with a
    variant that loads layers whose state-dict prefix contains one of the
    configured high-precision key names as plain (unscaled) weights. Every
    other layer is delegated to the base class unchanged.

    Args:
        scale_input_enabled: Forwarded as ``scale_input`` to
            ``comfy.ops.scaled_fp8_ops``; supplied by the loader.

    Returns:
        The ``HybridOps`` class (not an instance).
    """
    print(f"[Hybrid FP8 Ops] Configuring with scale_input_enabled: {scale_input_enabled}")
    fp8_mat_mult_supported = comfy.model_management.supports_fp8_compute()
    base_ops_class = comfy.ops.scaled_fp8_ops(
        fp8_matrix_mult=fp8_mat_mult_supported,
        scale_input=scale_input_enabled,
        override_dtype=TARGET_FP8_DTYPE,
    )

    class HybridScaledFP8Linear(base_ops_class.Linear):
        """Linear layer that handles both scaled FP8 and high-precision weights."""

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            """Load this layer, intercepting high-precision layers.

            If ``prefix`` contains any configured high-precision key name, the
            ``weight`` and ``bias`` entries are popped from ``state_dict`` and
            installed directly as non-trainable parameters, any
            ``scale_weight`` / ``scale_input`` entries are discarded, and the
            layer is flagged so ``forward`` uses a plain linear op. Otherwise
            loading is delegated to the base class.
            """
            # Skip empty entries: "" is a substring of every prefix and would
            # turn every layer into a high-precision layer. ``or ()`` keeps
            # this safe if the module-level list was set to None.
            is_excluded = any(
                name and name in prefix for name in (_high_precision_keynames or ())
            )
            if not is_excluded:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
                return

            print(f"[Hybrid FP8 Ops] Intercepting high-precision layer: {prefix}")
            weight_key = prefix + 'weight'
            bias_key = prefix + 'bias'
            weight_tensor = state_dict.pop(weight_key, None)
            bias_tensor = state_dict.pop(bias_key, None)

            if weight_tensor is None:
                # Report the gap; the existing weight attribute is left as is.
                missing_keys.append(weight_key)
            else:
                self.weight = torch.nn.Parameter(weight_tensor, requires_grad=False)

            if bias_tensor is not None:
                self.bias = torch.nn.Parameter(bias_tensor, requires_grad=False)
            else:
                self.bias = None

            # High-precision layers carry no FP8 scales: drop any stored ones
            # from the state dict and clear them on the module.
            state_dict.pop(prefix + 'scale_weight', None)
            state_dict.pop(prefix + 'scale_input', None)
            self.scale_weight = None
            self.scale_input = None
            self.is_high_precision_layer = True

        def forward(self, input):
            """Apply the layer.

            Layers flagged as high precision cast weight and bias to the
            input's device and dtype and run ``torch.nn.functional.linear``;
            all other layers use the base class's forward.
            """
            if not getattr(self, 'is_high_precision_layer', False):
                return super().forward(input)

            # NOTE(review): this path does not go through the base class's
            # forward/casting logic — confirm nothing applied there (e.g.
            # weight patching) is required for high-precision layers.
            weight_hp = self.weight.to(input.device, input.dtype)
            bias_hp = self.bias.to(input.device, input.dtype) if self.bias is not None else None
            return torch.nn.functional.linear(input, weight_hp, bias_hp)

    class HybridOps(base_ops_class):
        """Base ops class with ``Linear`` swapped for the hybrid implementation."""

        class Linear(HybridScaledFP8Linear):
            pass

    return HybridOps