# file: ComfyUI/custom_nodes/HybridFP8Loader/hybrid_fp8_ops.py
import torch

import comfy.model_management
import comfy.ops

TARGET_FP8_DTYPE = torch.float8_e4m3fn

# Substrings matched against state_dict prefixes to identify layers that must
# stay in high precision; populated by the loader via set_high_precision_keynames().
_high_precision_keynames = []

def set_high_precision_keynames(keynames):
    """Sets the list of substrings used to identify high-precision layers."""
    global _high_precision_keynames
    _high_precision_keynames = keynames
    print(f"[Hybrid FP8 Ops] High precision keynames set: {keynames}")

def get_hybrid_fp8_ops(scale_input_enabled=False):
    """
    Dynamically creates and returns a hybrid operations class.

    The 'scale_input_enabled' flag is passed in by the loader node.
    """
    print(f"[Hybrid FP8 Ops] Configuring with scale_input_enabled: {scale_input_enabled}")
    fp8_mat_mult_supported = comfy.model_management.supports_fp8_compute()
    base_ops_class = comfy.ops.scaled_fp8_ops(
        fp8_matrix_mult=fp8_mat_mult_supported,
        scale_input=scale_input_enabled,
        override_dtype=TARGET_FP8_DTYPE,
    )
    class HybridScaledFP8Linear(base_ops_class.Linear):
        """
        A Linear layer that handles both scaled-FP8 and high-precision weights.

        Layers whose state_dict prefix contains one of the configured keynames
        are loaded as plain high-precision tensors; all other layers fall
        through to the scaled-FP8 loading path of the base class.
        """

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs):
            # any() over an empty keyname list is False, so no separate
            # emptiness check is needed.
            is_excluded = any(name in prefix for name in _high_precision_keynames)
            if is_excluded:
                print(f"[Hybrid FP8 Ops] Intercepting high-precision layer: {prefix}")
                weight_key = prefix + 'weight'
                bias_key = prefix + 'bias'
                weight_tensor = state_dict.pop(weight_key, None)
                bias_tensor = state_dict.pop(bias_key, None)
                if weight_tensor is None:
                    missing_keys.append(weight_key)
                else:
                    # Inference only, so gradients are never needed.
                    self.weight = torch.nn.Parameter(weight_tensor, requires_grad=False)
                    if bias_tensor is not None:
                        self.bias = torch.nn.Parameter(bias_tensor, requires_grad=False)
                    else:
                        self.bias = None
                # Discard any FP8 scale tensors so they are not reported as
                # unexpected keys. dict.pop(key, default) takes at most two
                # arguments, so each key is popped separately.
                state_dict.pop(prefix + 'scale_weight', None)
                state_dict.pop(prefix + 'scale_input', None)
                self.scale_weight = None
                self.scale_input = None
                self.is_high_precision_layer = True
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                              missing_keys, unexpected_keys, error_msgs)
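
        # For reference, the keys handled above imply this checkpoint layout
        # (dtypes are assumptions based on TARGET_FP8_DTYPE, not read from
        # the file):
        #   "<prefix>weight"        FP8 (float8_e4m3fn) for scaled layers
        #   "<prefix>scale_weight"  dequantization scale for the weight
        #   "<prefix>scale_input"   optional input scale (when enabled)
        # An excluded high-precision layer keeps only "<prefix>weight" and
        # optionally "<prefix>bias" in a 16- or 32-bit dtype.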

        def forward(self, input):
            if getattr(self, 'is_high_precision_layer', False):
                # Bypass the FP8 path: cast the stored weight/bias to the
                # activation's device and dtype, then run a standard linear op.
                weight_hp = self.weight.to(input.device, input.dtype)
                bias_hp = self.bias.to(input.device, input.dtype) if self.bias is not None else None
                return torch.nn.functional.linear(input, weight_hp, bias_hp)
            else:
                return super().forward(input)

    class HybridOps(base_ops_class):
        # Expose the hybrid layer under the attribute name that model
        # construction code looks up on an operations class.
        class Linear(HybridScaledFP8Linear):
            pass

    return HybridOps
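
# Usage sketch (hypothetical loader-side code; the keyname list is an example,
# and "custom_operations" is the model_options key ComfyUI loaders accept for
# overriding the operations class):
#
#     import comfy.sd
#
#     set_high_precision_keynames(["norm", "time_embed"])
#     ops = get_hybrid_fp8_ops(scale_input_enabled=True)
#     model = comfy.sd.load_diffusion_model(
#         unet_path, model_options={"custom_operations": ops})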