eswardivi committed on
Commit 1ce2a4a · verified · 1 Parent(s): 2e26df2
Files changed (1)
app.py +15 -2
app.py CHANGED
@@ -11,19 +11,32 @@ import spaces
 import time
 import subprocess
 
+print("\n=== Environment Setup ===")
+
 if torch.cuda.is_available():
+    print(f"GPU detected: {torch.cuda.get_device_name(0)}")
     try:
         subprocess.run(
             "pip install flash-attn --no-build-isolation",
-            env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
             shell=True,
             check=True,
         )
-        print("✅ flash-attn installed (GPU detected)")
+        print("✅ flash-attn installed successfully")
     except subprocess.CalledProcessError as e:
         print("⚠️ flash-attn installation failed:", e)
 else:
     print("⚙️ CPU detected — skipping flash-attn installation")
+    # Disable flash-attn references safely
+    os.environ["DISABLE_FLASH_ATTN"] = "1"
+    os.environ["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
+    try:
+        from transformers.utils import import_utils
+        if "flash_attn" not in import_utils.PACKAGE_DISTRIBUTION_MAPPING:
+            import_utils.PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] = "flash-attn"
+    except Exception as e:
+        print("⚠️ Patch skipped:", e)
+
+
 token = os.environ["HF_TOKEN"]
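The deleted env= line is worth a note: subprocess.run(..., env=...) replaces the child's entire environment rather than extending it, so passing only the flash-attn flag would hide PATH and the CUDA toolchain variables from pip. A minimal sketch of the merge pattern, in case the flag ever needs to be scoped to the install subprocess again instead of being exported through os.environ as this commit does:

import os
import subprocess

# Merge the flag into a copy of the current environment instead of
# replacing the environment outright, so pip still sees PATH, CUDA_HOME, etc.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
    check=True,
)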
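Downstream of this setup block, the model loader would typically branch on whether flash_attn is actually importable. A minimal sketch, assuming the Transformers attn_implementation keyword of from_pretrained; pick_attn_impl and the model id are hypothetical placeholders, not taken from this repo:

import importlib.util

import torch
from transformers import AutoModelForCausalLM

def pick_attn_impl() -> str:
    # Hypothetical helper: request FlashAttention-2 only when a GPU is
    # present and the flash_attn package is actually importable.
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"  # PyTorch scaled-dot-product attention also works on CPU

model = AutoModelForCausalLM.from_pretrained(
    "some/model-id",  # placeholder; the real model id lives elsewhere in app.py
    attn_implementation=pick_attn_impl(),
)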