# file: ComfyUI/custom_nodes/HybridFP8Loader/__init__.py

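"""Hybrid FP8 loader nodes for ComfyUI.

Loads fp8_scaled diffusion models and checkpoints while keeping selected
layers (the keyname lists below) in higher precision instead of quantizing
them to FP8.
"""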
import folder_paths
import comfy.sd
from safetensors import safe_open
from . import hybrid_fp8_ops

# --- Exclusion Lists: layer-name substrings kept in high precision (not quantized to FP8) ---
DISTILL_LAYER_KEYNAMES_LARGE = ["distilled_guidance_layer", "final_layer", "img_in", "txt_in"]
NERF_LAYER_KEYNAMES_LARGE = ["distilled_guidance_layer", "img_in_patch", "nerf_blocks", "nerf_final_layer_conv", "nerf_image_embedder", "txt_in"]
DISTILL_LAYER_KEYNAMES_SMALL = ["distilled_guidance_layer"]
NERF_LAYER_KEYNAMES_SMALL = ["distilled_guidance_layer", "img_in_patch", "nerf_blocks", "nerf_final_layer_conv", "nerf_image_embedder"]
WAN_LAYER_KEYNAMES = ["patch_embedding", "text_embedding", "time_embedding", "time_projection", "head.head"]
QWEN_LAYER_KEYNAMES = ["time_text_embed", "img_in", "norm_out", "proj_out", "txt_in"]

def detect_fp8_optimizations(model_path):
    """

    Peeks into the safetensors file to check the shape of 'scaled_fp8' tensor.

    Returns True if input scaling should be enabled, False otherwise.

    """
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            if "scaled_fp8" in f.keys():
                scaled_fp8_tensor = f.get_tensor("scaled_fp8")
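                # Heuristic: an empty 'scaled_fp8' marker tensor is treated as the
                # input-scaled variant of the format, so scale_input gets enabled.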
                if scaled_fp8_tensor.shape[0] == 0:
                    print("[Hybrid FP8 Loader] Scale Input model detected (scale_input enabled).")
                    return True
    except Exception as e:
        print(f"[Hybrid FP8 Loader] Warning: Could not inspect model file to determine FP8 type: {e}")

    print("[Hybrid FP8 Loader] Standard UNet-style model detected (scale_input disabled).")
    return False

def setup_hybrid_ops(model_path, chroma_hybrid_large, radiance_hybrid_large, chroma_hybrid_small, radiance_hybrid_small, wan, qwen):
    """A helper function to configure the hybrid ops based on user settings and model type."""
    excluded_layers = []
    if chroma_hybrid_large: excluded_layers.extend(DISTILL_LAYER_KEYNAMES_LARGE)
    if radiance_hybrid_large: excluded_layers.extend(NERF_LAYER_KEYNAMES_LARGE)
    if chroma_hybrid_small: excluded_layers.extend(DISTILL_LAYER_KEYNAMES_SMALL)
    if radiance_hybrid_small: excluded_layers.extend(NERF_LAYER_KEYNAMES_SMALL)
    if wan: excluded_layers.extend(WAN_LAYER_KEYNAMES)
    if qwen: excluded_layers.extend(QWEN_LAYER_KEYNAMES)

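    # Register the (deduplicated) layer-name substrings that should stay in high precision.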
    hybrid_fp8_ops.set_high_precision_keynames(list(set(excluded_layers)))

    # --- THIS IS THE KEY LOGIC ---
    # Detect model type from the file and pass the correct flag to get_hybrid_fp8_ops
    scale_input_enabled = detect_fp8_optimizations(model_path)
    return hybrid_fp8_ops.get_hybrid_fp8_ops(scale_input_enabled=scale_input_enabled)

class ScaledFP8HybridUNetLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("unet"), ),
                "chroma_hybrid_large": ("BOOLEAN", {"default": False, "tooltip": "Only the larger Chroma hybrid models which might have some LoRA issues and can only load pruned flash-heun LoRA"}),
                "radiance_hybrid_large": ("BOOLEAN", {"default": False, "tooltip": "Only the larger Radiance hybrid models which might have some LoRA issues and can only load pruned flash-heun LoRA"}),
                "chroma_hybrid_small": ("BOOLEAN", {"default": False, "tooltip": "Use only with the smaller Chroma hybrid models. LoRA should no longer be an issue"}),
                "radiance_hybrid_small": ("BOOLEAN", {"default": False, "tooltip": "Use only with the smaller Radiance hybrid models. LoRA should no longer be an issue"}),
                "wan": ("BOOLEAN", {"default": False, "tooltip": "for selective fp8_scaled quantization of WAN 2.2 models"}),
                "qwen": ("BOOLEAN", {"default": False, "tooltip": "for selective fp8_scaled quantization of Qwen Image models"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "loaders/FP8"

    def load_unet(self, model_name, chroma_hybrid_large, radiance_hybrid_large, chroma_hybrid_small, radiance_hybrid_small, wan, qwen):
        unet_path = folder_paths.get_full_path("unet", model_name)
        ops = setup_hybrid_ops(unet_path, chroma_hybrid_large, radiance_hybrid_large, chroma_hybrid_small, radiance_hybrid_small, wan, qwen)
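        # Pass the hybrid ops in via model_options as custom_operations so ComfyUI builds the model's layers with them.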
        model = comfy.sd.load_diffusion_model(unet_path, model_options={"custom_operations": ops})
        return (model,)

class ScaledFP8HybridCheckpointLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "chroma_hybrid_large": ("BOOLEAN", {"default": False, "tooltip": "Only the larger Chroma hybrid models which might have some LoRA issues and can only load pruned flash-heun LoRA"}),
                "radiance_hybrid_large": ("BOOLEAN", {"default": False, "tooltip": "Only the larger Radiance hybrid models which might have some LoRA issues and can only load pruned flash-heun LoRA"}),
                "chroma_hybrid_small": ("BOOLEAN", {"default": False, "tooltip": "Use only with the smaller Chroma hybrid models. LoRA should no longer be an issue"}),
                "radiance_hybrid_small": ("BOOLEAN", {"default": False, "tooltip": "Use only with the smaller Radiance hybrid models. LoRA should no longer be an issue"}),
                "wan": ("BOOLEAN", {"default": False, "tooltip": "for selective fp8_scaled quantization of WAN 2.2 models"}),
                "qwen": ("BOOLEAN", {"default": False, "tooltip": "for selective fp8_scaled quantization of Qwen Image models"}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "loaders/FP8"

    def load_checkpoint(self, ckpt_name, chroma_hybrid_large, radiance_hybrid_large, chroma_hybrid_small, radiance_hybrid_small, wan, qwen):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        ops = setup_hybrid_ops(ckpt_path, chroma_hybrid_large, radiance_hybrid_large, chroma_hybrid_small, radiance_hybrid_small, wan, qwen)
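        # As in the UNet loader, the hybrid ops are passed through model_options for the checkpoint's diffusion model.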
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": ops})
        return out[:3]

NODE_CLASS_MAPPINGS = {
    "ScaledFP8HybridUNetLoader": ScaledFP8HybridUNetLoader,
    "ScaledFP8HybridCheckpointLoader": ScaledFP8HybridCheckpointLoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "ScaledFP8HybridUNetLoader": "Load FP8 Scaled Diffusion Model (Choose One)",
    "ScaledFP8HybridCheckpointLoader": "Load FP8 Scaled Ckpt (Choose One)",
}