Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
7e16671
1
Parent(s):
2085d6e
flash attention upgraded for cuda 12.8
Browse files
- app.py +21 -10
- requirements.txt +0 -1
app.py
CHANGED
|
@@ -24,20 +24,31 @@ importlib.invalidate_caches()
|
|
| 24 |
|
| 25 |
def sh(cmd):
    """Run ``cmd`` through the system shell.

    Returns ``None`` when the command exits with status 0 and raises
    ``subprocess.CalledProcessError`` for any non-zero exit status.
    """
    subprocess.run(cmd, shell=True, check=True)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
sh(f"pip install {flash_attention_wheel}")
|
| 34 |
|
| 35 |
-
# tell Python to re-scan site-packages now that the egg-link exists
|
| 36 |
-
import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
|
| 37 |
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
import torch.nn as nn
|
|
|
|
| 24 |
|
| 25 |
def sh(cmd):
    """Run ``cmd`` through the system shell.

    Returns ``None`` when the command exits with status 0 and raises
    ``subprocess.CalledProcessError`` for any non-zero exit status.
    """
    subprocess.run(cmd, shell=True, check=True)
|
| 26 |
|
| 27 |
+
flash_attention_installed = False
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
print("Attempting to download and install FlashAttention wheel...")
|
| 31 |
+
flash_attention_wheel = hf_hub_download(
|
| 32 |
+
repo_id="alexnasa/flash-attn-3",
|
| 33 |
+
repo_type="model",
|
| 34 |
+
filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
|
| 35 |
+
)
|
| 36 |
|
| 37 |
+
sh(f"pip install {flash_attention_wheel}")
|
| 38 |
|
| 39 |
+
# tell Python to re-scan site-packages now that the wheel has been installed
|
| 40 |
+
import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
|
| 41 |
|
| 42 |
+
flash_attention_installed = True
|
| 43 |
+
print("FlashAttention installed successfully.")
|
| 44 |
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"⚠️ Could not install FlashAttention: {e}")
|
| 47 |
+
print("Continuing without FlashAttention...")
|
| 48 |
+
|
| 49 |
+
import torch
|
| 50 |
+
print(f"Torch version: {torch.__version__}")
|
| 51 |
+
print(f"FlashAttention available: {flash_attention_installed}")
|
| 52 |
|
| 53 |
|
| 54 |
import torch.nn as nn
|
requirements.txt
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
torch==2.7.1
|
| 2 |
tqdm
|
| 3 |
librosa==0.10.2.post1
|
| 4 |
peft==0.15.1
|
|
|
|
|
|
|
| 1 |
tqdm
|
| 2 |
librosa==0.10.2.post1
|
| 3 |
peft==0.15.1
|