Spaces:
Runtime error
Runtime error
flash attn fix
Browse files
- app.py +13 -1
- requirements.txt +1 -2
- utils.py +17 -0
app.py
CHANGED
|
@@ -10,6 +10,8 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
| 10 |
import cv2
|
| 11 |
import traceback
|
| 12 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# CUDA optimizations
|
| 15 |
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
|
@@ -26,9 +28,19 @@ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
|
| 26 |
image_predictor = SAM2ImagePredictor(sam2_model)
|
| 27 |
|
| 28 |
model_id = 'microsoft/Florence-2-large'
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
| 31 |
|
|
|
|
| 32 |
def apply_color_mask(frame, mask, obj_id):
|
| 33 |
cmap = plt.get_cmap("tab10")
|
| 34 |
color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
|
|
|
|
| 10 |
import cv2
|
| 11 |
import traceback
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
+
from utils import load_model_without_flash_attn
|
| 14 |
+
|
| 15 |
|
| 16 |
# CUDA optimizations
|
| 17 |
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
|
|
|
| 28 |
image_predictor = SAM2ImagePredictor(sam2_model)
|
| 29 |
|
| 30 |
model_id = 'microsoft/Florence-2-large'
|
| 31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
+
|
| 33 |
+
def load_florence_model():
    """Build the Florence-2 causal-LM checkpoint and place it on ``device``.

    Reads the module-level ``model_id`` and ``device``. The checkpoint is
    loaded with ``trust_remote_code=True``, switched to eval mode, and moved
    to ``device`` before being returned.
    """
    # Half precision only when running on CUDA; full precision otherwise.
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    return model.eval().to(device)
|
| 39 |
+
|
| 40 |
+
florence_model = load_model_without_flash_attn(load_florence_model)
|
| 41 |
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
| 42 |
|
| 43 |
+
|
| 44 |
def apply_color_mask(frame, mask, obj_id):
|
| 45 |
cmap = plt.get_cmap("tab10")
|
| 46 |
color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
|
requirements.txt
CHANGED
|
@@ -8,5 +8,4 @@ opencv-python
|
|
| 8 |
matplotlib
|
| 9 |
einops
|
| 10 |
timm
|
| 11 |
-
pytest
|
| 12 |
-
flash_attn
|
|
|
|
| 8 |
matplotlib
|
| 9 |
einops
|
| 10 |
timm
|
| 11 |
+
pytest
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from unittest.mock import patch
|
| 3 |
+
from transformers.dynamic_module_utils import get_imports
|
| 4 |
+
|
| 5 |
+
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports transformers reports for *filename*, minus ``flash_attn``.

    Drop-in replacement for ``transformers.dynamic_module_utils.get_imports``
    that works around the flash_attn import issue: for the remote-code files
    listed below, ``flash_attn`` is removed from the reported imports so that
    loading does not require the package to be installed. For every other
    file the result of ``get_imports`` is returned untouched.

    Args:
        filename: Path of the remote-code source file being inspected.

    Returns:
        The import names from ``get_imports(filename)``, without any
        ``"flash_attn"`` entries when *filename* is one of the listed files.
    """
    # NOTE(review): "modeling_florence2.py" is added because that is the
    # Florence-2 remote-code file expected to import flash_attn; the two
    # original suffixes are kept for backward compatibility. Confirm against
    # the model repository's file listing.
    flash_attn_optional_files = (
        "modeling_phi.py",
        "configuration_florence2.py",
        "modeling_florence2.py",
    )
    imports = get_imports(filename)
    if not str(filename).endswith(flash_attn_optional_files):
        return imports
    # Filter rather than list.remove() so duplicate entries are dropped too.
    return [name for name in imports if name != "flash_attn"]
|
| 13 |
+
|
| 14 |
+
def load_model_without_flash_attn(model_loader):
    """Run *model_loader* while the flash_attn import workaround is active.

    ``transformers.dynamic_module_utils.get_imports`` is replaced by
    ``fixed_get_imports`` only for the duration of the call; the patch is
    undone when the ``with`` block exits.

    Args:
        model_loader: Zero-argument callable that loads and returns a model.

    Returns:
        Whatever *model_loader* returns.
    """
    import_patch = patch(
        "transformers.dynamic_module_utils.get_imports",
        fixed_get_imports,
    )
    with import_patch:
        loaded = model_loader()
    return loaded
|