cbensimon (HF Staff) committed
Commit c72054d · 1 Parent(s): 48a26b4

fp8e4m3 (disable aoti)

Files changed (2):
  1. fa3.py +2 -1
  2. optimization.py +1 -1
fa3.py CHANGED
@@ -10,7 +10,8 @@ _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_f
 
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
 def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-    outputs, lse = _flash_attn_func(q, k, v)
+    dtype = torch.float8_e4m3fn
+    outputs, lse = _flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype))
     return outputs
 
 @flash_attn_func.register_fake
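
Taken together, the patched fa3.py custom op looks roughly like the sketch below. The get_kernel lookup, the op signature, and the fp8 cast are taken from this diff; the register_fake body is not shown in the diff and is an assumption here (shape/dtype propagation only), as is the premise that the vllm-flash-attn3 kernel accepts float8_e4m3fn inputs directly.

import torch
from kernels import get_kernel

# Fetch the FlashAttention-3 kernel from the Hub (as in fa3.py).
_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Cast q/k/v to fp8 e4m3 before dispatching to the FA3 kernel;
    # the log-sum-exp returned by the kernel is discarded.
    dtype = torch.float8_e4m3fn
    outputs, lse = _flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype))
    return outputs


@flash_attn_func.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Assumed fake (meta) implementation: propagate the shape/dtype of q
    # so torch.compile can trace through the custom op.
    return torch.empty_like(q)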
optimization.py CHANGED
@@ -41,4 +41,4 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
 
     pipeline.transformer.fuse_qkv_projections()
     pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
-    spaces.aoti_apply(compile_transformer(), pipeline.transformer)
+    # spaces.aoti_apply(compile_transformer(), pipeline.transformer)
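
For callers nothing changes at the Python level: q, k and v are still passed in their original dtype and the cast to float8_e4m3fn happens inside the op, while the commented-out spaces.aoti_apply call means the transformer now runs without the AoT-compiled artifact. A minimal usage sketch continuing the code above, assuming a CUDA device with fp8 support (e.g. H100) and the (batch, seq_len, num_heads, head_dim) layout FA3 expects; the shapes and bfloat16 dtype are illustrative:

# Illustrative shapes/dtype; not taken from the repo.
q = torch.randn(1, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v)  # fp8 e4m3 cast happens inside the custom op
print(out.shape, out.dtype)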