Update handler.py
Browse files- handler.py +47 -67
 
    	
        handler.py
    CHANGED
    
    | 
         @@ -1,20 +1,15 @@ 
     | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         
            -
            from typing import Any, Dict 
     | 
| 3 | 
         
             
            from PIL import Image
         
     | 
| 4 | 
         
             
            import torch
         
     | 
| 5 | 
         
             
            from diffusers import FluxPipeline
         
     | 
| 6 | 
         
             
            from huggingface_inference_toolkit.logging import logger
         
     | 
| 7 | 
         
             
            from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
         
     | 
| 8 | 
         
            -
            from torchao.quantization import autoquant
         
     | 
| 9 | 
         
             
            import time
         
     | 
| 10 | 
         
            -
            import  
     | 
| 11 | 
         
            -
             
     | 
| 12 | 
         
            -
             
     | 
| 13 | 
         
            -
             
     | 
| 14 | 
         
            -
            torch.set_float32_matmul_precision("high")
         
     | 
| 15 | 
         
            -
             
     | 
| 16 | 
         
            -
            import torch._dynamo
         
     | 
| 17 | 
         
            -
            torch._dynamo.config.suppress_errors = False # for debugging
         
     | 
| 18 | 
         | 
| 19 | 
         
             
            class EndpointHandler:
         
     | 
| 20 | 
         
             
                def __init__(self, path=""):
         
     | 
| 
         @@ -22,75 +17,60 @@ class EndpointHandler: 
     | 
|
| 22 | 
         
             
                        "NoMoreCopyrightOrg/flux-dev",
         
     | 
| 23 | 
         
             
                        torch_dtype=torch.bfloat16,
         
     | 
| 24 | 
         
             
                    ).to("cuda")
         
     | 
| 25 | 
         
            -
                     
     | 
| 26 | 
         
            -
             
     | 
| 27 | 
         
            -
             
     | 
| 28 | 
         
            -
                     
     | 
| 29 | 
         
            -
                     
     | 
| 30 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 31 | 
         
             
                    apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
         
     | 
| 
         | 
|
| 32 | 
         
             
                    self.pipe.transformer = torch.compile(
         
     | 
| 33 | 
         
             
                        self.pipe.transformer, mode="max-autotune-no-cudagraphs",
         
     | 
| 34 | 
         
             
                    )
         
     | 
| 35 | 
         
             
                    self.pipe.vae = torch.compile(
         
     | 
| 36 | 
         
             
                        self.pipe.vae, mode="max-autotune-no-cudagraphs",
         
     | 
| 37 | 
         
             
                    )
         
     | 
| 38 | 
         
            -
                    self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
         
     | 
| 39 | 
         
            -
                    self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
         
     | 
| 40 | 
         
            -
                    
         
     | 
| 41 | 
         
            -
                    gc.collect()
         
     | 
| 42 | 
         
            -
                    torch.cuda.empty_cache()
         
     | 
| 43 | 
         | 
| 44 | 
         
            -
             
     | 
| 45 | 
         
            -
                     
     | 
| 46 | 
         
            -
                    self.pipe("Hello world!") # Warm-up for compiling
         
     | 
| 47 | 
         
            -
                    end_time = time.time()
         
     | 
| 48 | 
         
            -
                    time_taken = end_time - start_time
         
     | 
| 49 | 
         
            -
                    print(f"Time taken: {time_taken:.2f} seconds")
         
     | 
| 50 | 
         
            -
                    self.record=0
         
     | 
| 51 | 
         | 
| 52 | 
         
            -
             
     | 
| 53 | 
         
            -
             
     | 
| 54 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 55 | 
         | 
| 56 | 
         
            -
             
     | 
| 57 | 
         
            -
                            prompt = data.pop("inputs")
         
     | 
| 58 | 
         
            -
                        elif "prompt" in data and isinstance(data["prompt"], str):
         
     | 
| 59 | 
         
            -
                            prompt = data.pop("prompt")
         
     | 
| 60 | 
         
            -
                        else:
         
     | 
| 61 | 
         
            -
                            raise ValueError(
         
     | 
| 62 | 
         
            -
                                "Provided input body must contain either the key `inputs` or `prompt` with the"
         
     | 
| 63 | 
         
            -
                                " prompt to use for the image generation, and it needs to be a non-empty string."
         
     | 
| 64 | 
         
            -
                            )
         
     | 
| 65 | 
         
            -
                        if prompt=="get_queue":
         
     | 
| 66 | 
         
            -
                            return self.record
         
     | 
| 67 | 
         
            -
                        parameters = data.pop("parameters", {})
         
     | 
| 68 | 
         | 
| 69 | 
         
            -
             
     | 
| 70 | 
         
            -
             
     | 
| 71 | 
         
            -
             
     | 
| 72 | 
         
            -
             
     | 
| 73 | 
         
            -
                        guidance_scale = parameters.get("guidance", 3.5)
         
     | 
| 74 | 
         | 
| 75 | 
         
            -
             
     | 
| 76 | 
         
            -
             
     | 
| 77 | 
         
            -
             
     | 
| 78 | 
         
            -
             
     | 
| 79 | 
         
            -
             
     | 
| 80 | 
         
            -
                         
     | 
| 81 | 
         
            -
             
     | 
| 82 | 
         
            -
             
     | 
| 83 | 
         
            -
             
     | 
| 84 | 
         
            -
             
     | 
| 85 | 
         
            -
             
     | 
| 86 | 
         
            -
             
     | 
| 87 | 
         
            -
             
     | 
| 88 | 
         
            -
             
     | 
| 
         | 
|
| 89 | 
         
             
                        time_taken = end_time - start_time
         
     | 
| 90 | 
         
             
                        print(f"Time taken: {time_taken:.2f} seconds")
         
     | 
| 91 | 
         
            -
                        self.record-=1
         
     | 
| 92 | 
         
            -
             
     | 
| 93 | 
         
             
                        return result
         
     | 
| 94 | 
         
            -
                     
     | 
| 95 | 
         
            -
                        print(e)
         
     | 
| 96 | 
         
            -
                        return None
         
     | 
| 
         | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         
            +
            from typing import Any, Dict
         
     | 
| 3 | 
         
             
            from PIL import Image
         
     | 
| 4 | 
         
             
            import torch
         
     | 
| 5 | 
         
             
            from diffusers import FluxPipeline
         
     | 
| 6 | 
         
             
            from huggingface_inference_toolkit.logging import logger
         
     | 
| 7 | 
         
             
            from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
         
     | 
| 
         | 
|
| 8 | 
         
             
            import time
         
     | 
| 9 | 
         
            +
            import torch.distributed as dist
         
     | 
| 10 | 
         
            +
            from para_attn.context_parallel import init_context_parallel_mesh
         
     | 
| 11 | 
         
            +
            from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
         
     | 
| 12 | 
         
            +
            from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 13 | 
         | 
| 14 | 
         
             
            class EndpointHandler:
         
     | 
| 15 | 
         
             
                def __init__(self, path=""):
         
     | 
| 
         | 
|
| 17 | 
         
             
                        "NoMoreCopyrightOrg/flux-dev",
         
     | 
| 18 | 
         
             
                        torch_dtype=torch.bfloat16,
         
     | 
| 19 | 
         
             
                    ).to("cuda")
         
     | 
| 20 | 
         
            +
                    mesh = init_context_parallel_mesh(
         
     | 
| 21 | 
         
            +
                        self.pipe.device.type,
         
     | 
| 22 | 
         
            +
                        max_ring_dim_size=2,
         
     | 
| 23 | 
         
            +
                    )
         
     | 
| 24 | 
         
            +
                    parallelize_pipe(
         
     | 
| 25 | 
         
            +
                        self.pipe,
         
     | 
| 26 | 
         
            +
                        mesh=mesh,
         
     | 
| 27 | 
         
            +
                    )
         
     | 
| 28 | 
         
            +
                    parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
         
     | 
| 29 | 
         
             
                    apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
         
     | 
| 30 | 
         
            +
                    torch._inductor.config.reorder_for_compute_comm_overlap = True
         
     | 
| 31 | 
         
             
                    self.pipe.transformer = torch.compile(
         
     | 
| 32 | 
         
             
                        self.pipe.transformer, mode="max-autotune-no-cudagraphs",
         
     | 
| 33 | 
         
             
                    )
         
     | 
| 34 | 
         
             
                    self.pipe.vae = torch.compile(
         
     | 
| 35 | 
         
             
                        self.pipe.vae, mode="max-autotune-no-cudagraphs",
         
     | 
| 36 | 
         
             
                    )
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 37 | 
         | 
| 38 | 
         
            +
                def __call__(self, data: Dict[str, Any]) -> Any:
                    """Run the pipeline for one request payload.

                    The prompt is read from ``data["inputs"]`` or, failing that, from
                    ``data["prompt"]``. Optional generation settings are read from
                    ``data["parameters"]``:

                    * ``num_inference_steps`` (default 28)
                    * ``width`` / ``height`` (default 1024 each)
                    * ``guidance_scale`` (default 3.5)
                    * ``seed`` (default 0)

                    The payload dict is not modified.

                    Returns:
                        On distributed rank 0, the first image produced by the
                        pipeline (requested with ``output_type="pil"``). On every
                        other rank, the placeholder string ``"123"``.

                    Raises:
                        ValueError: if neither ``inputs`` nor ``prompt`` holds a
                            non-empty string.
                    """
                    logger.info(f"Received incoming request with {data=}")

                    # Take the first key that carries a usable prompt. An empty string
                    # is rejected, as the error message below has always promised.
                    prompt = None
                    for key in ("inputs", "prompt"):
                        candidate = data.get(key)
                        if isinstance(candidate, str) and candidate:
                            prompt = candidate
                            break
                    if prompt is None:
                        raise ValueError(
                            "Provided input body must contain either the key `inputs` or `prompt` with the"
                            " prompt to use for the image generation, and it needs to be a non-empty string."
                        )

                    # `or {}` also covers an explicit `"parameters": null` in the payload.
                    parameters = data.get("parameters") or {}

                    num_inference_steps = parameters.get("num_inference_steps", 28)
                    width = parameters.get("width", 1024)
                    height = parameters.get("height", 1024)
                    guidance_scale = parameters.get("guidance_scale", 3.5)

                    # seed generator (seed cannot be provided as is but via a generator)
                    # NOTE(review): torch.manual_seed reseeds the global default
                    # generator; presumably every rank must use the same seed so the
                    # parallel run stays in sync - confirm before changing the default.
                    seed = parameters.get("seed", 0)
                    generator = torch.manual_seed(seed)

                    # Query the rank once; it selects the output format and decides
                    # which process reports the timing and returns the image.
                    is_main_rank = dist.get_rank() == 0

                    started = time.perf_counter()  # monotonic clock for interval timing
                    result = self.pipe(  # type: ignore
                        prompt,
                        height=height,
                        width=width,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        generator=generator,
                        output_type="pil" if is_main_rank else "pt",
                    ).images[0]
                    time_taken = time.perf_counter() - started

                    if is_main_rank:
                        print(f"Time taken: {time_taken:.2f} seconds")
                        return result

                    # Non-zero ranks still run the pipeline but their output is
                    # discarded; "123" is the existing placeholder return value.
                    return "123"
         
     | 
| 
         | 
|
| 
         |