Spaces: Running on Zero
10x demo speedup (#1)
- AOTI load (4d319cb3623ae213865e0c9729298ca0c85ce255)
Co-authored-by: Charles Bensimon <cbensimon@users.noreply.huggingface.co>
aoti.py
ADDED
@@ -0,0 +1,17 @@
+"""
+"""
+
+import torch
+from huggingface_hub import hf_hub_download
+from spaces.zero.torch.aoti import ZeroGPUCompiledModel
+from spaces.zero.torch.aoti import ZeroGPUWeights
+
+
+def aoti_load(module: torch.nn.Module, repo_id: str):
+    repeated_blocks = module._repeated_blocks
+    aoti_files = {name: hf_hub_download(repo_id, f'{name}.pt2') for name in repeated_blocks}
+    for block_name, aoti_file in aoti_files.items():
+        for block in module.modules():
+            if block.__class__.__name__ == block_name:
+                weights = ZeroGPUWeights(block.state_dict())
+                block.forward = ZeroGPUCompiledModel(aoti_file, weights)
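For reference, a minimal sketch of how the per-block `.pt2` artifacts that `aoti_load` downloads could be produced offline with AOT Inductor (not part of this change; the block, example inputs, and output directory are assumptions):

import torch
from torch._inductor import aoti_compile_and_package

# Hypothetical export step: compile one repeated transformer block so the
# resulting '<BlockClassName>.pt2' matches the name aoti_load() looks up
# via hf_hub_download(repo_id, f'{name}.pt2').
def export_block(block: torch.nn.Module, example_inputs: tuple, out_dir: str) -> str:
    exported = torch.export.export(block, example_inputs)
    return aoti_compile_and_package(exported, package_path=f"{out_dir}/{block.__class__.__name__}.pt2")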
app.py
CHANGED
@@ -1,9 +1,11 @@
+import os
+os.system("pip install --upgrade spaces")
+
 import sys
 sys.path.append('./')
 
 import gradio as gr
 import spaces
-import os
 import sys
 import subprocess
 import numpy as np
@@ -61,10 +63,11 @@ canny = CannyDetector()
 anyline = AnylineDetector.from_pretrained("TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline")
 open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
 
-
-
-
-pipe.
+import fa3
+from aoti import aoti_load
+
+pipe.transformer.fuse_qkv_projections()
+aoti_load(pipe.transformer, 'zerogpu-aoti/FLUX.1')
 
 def convert_from_image_to_cv2(img: Image) -> np.ndarray:
     return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
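Not shown in these hunks is how the patched pipeline is served; a minimal sketch of a ZeroGPU handler, assuming `pipe` is the FLUX pipeline already set up in app.py and the prompt-only signature is illustrative:

import gradio as gr
import spaces

@spaces.GPU  # the GPU is attached only for the duration of this call
def generate(prompt: str):
    # The AOTI-compiled transformer blocks installed by aoti_load() run here;
    # the rest of the pipeline executes eagerly.
    return pipe(prompt).images[0]

demo = gr.Interface(generate, gr.Textbox(label="Prompt"), gr.Image())
demo.launch()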
fa3.py
ADDED
@@ -0,0 +1,18 @@
+"""
+"""
+
+import torch
+from kernels import get_kernel
+
+
+_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
+
+
+@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
+def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    outputs, lse = _flash_attn_func(q, k, v)
+    return outputs
+
+@flash_attn_func.register_fake
+def _(q, k, v, **kwargs):
+    return torch.empty_like(q).contiguous()
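The `register_fake` entry is what lets this op be traced with fake tensors (e.g. by torch.compile or the torch.export path used for AOTI) without launching the CUDA kernel. A quick smoke test, assuming a CUDA device and the (batch, seq_len, num_heads, head_dim) fp16 layout flash-attention kernels typically expect:

import torch
from fa3 import flash_attn_func

# Illustrative shapes only.
q = torch.randn(1, 128, 24, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v)      # eager: dispatches to the vllm-flash-attn3 kernel
assert out.shape == q.shape

# Tracing works thanks to the fake kernel registered above.
compiled = torch.compile(flash_attn_func, fullgraph=True)
_ = compiled(q, k, v)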
requirements.txt
CHANGED
@@ -14,4 +14,5 @@ xformers
 sentencepiece
 peft
 scipy
-scikit-image
+scikit-image
+kernels