fp8e4m3 (disable aoti)
- fa3.py +2 -1
- optimization.py +1 -1

fa3.py
CHANGED

@@ -10,7 +10,8 @@ _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
 
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
 def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-    outputs, lse = _flash_attn_func(q, k, v)
+    dtype = torch.float8_e4m3fn
+    outputs, lse = _flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype))
     return outputs
 
 @flash_attn_func.register_fake
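
For reference, a minimal self-contained sketch of the custom op as it stands after this change. The `register_fake` body below is an assumption (the diff only shows the decorator line); its job is just to give torch.compile a meta implementation with a plausible output shape and dtype, assumed here to match `q`.

import torch
from kernels import get_kernel  # Hugging Face `kernels` package

_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func

@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Cast q/k/v to fp8 e4m3 so the FlashAttention-3 kernel takes its fp8 path.
    dtype = torch.float8_e4m3fn
    outputs, lse = _flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype))
    return outputs

@flash_attn_func.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Assumed fake (meta) implementation: an empty tensor shaped like q,
    # so the compiler can trace shapes without running the kernel.
    return torch.empty_like(q)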

optimization.py
CHANGED

@@ -41,4 +41,4 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
     pipeline.transformer.fuse_qkv_projections()
     pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
-    spaces.aoti_apply(compile_transformer(), pipeline.transformer)
+    # spaces.aoti_apply(compile_transformer(), pipeline.transformer)
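
Commenting out `spaces.aoti_apply` matches the "(disable aoti)" part of the commit title: the transformer now runs without ZeroGPU's ahead-of-time Inductor compilation, leaving only the QKV fusion and the FA3 fp8 attention processor active. A sketch of a switchable alternative, assuming a hypothetical `DISABLE_AOTI` environment flag that is not part of this Space:

import os
import spaces  # Hugging Face `spaces` package (ZeroGPU helpers)

def optimize_pipeline_(pipeline, *args, **kwargs):
    # FlashFusedFluxAttnProcessor3_0 and compile_transformer come from this
    # Space's own modules (fa3.py / optimization.py).
    pipeline.transformer.fuse_qkv_projections()
    pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
    if not os.environ.get("DISABLE_AOTI"):  # hypothetical flag
        spaces.aoti_apply(compile_transformer(), pipeline.transformer)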