Spaces:
Runtime error
Runtime error
| import os | |
| from unittest.mock import patch | |
| from transformers.dynamic_module_utils import get_imports | |
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of *filename*, dropping ``flash_attn`` for Florence-2 files.

    Workaround for the flash_attn import issue: this is a drop-in replacement
    for ``transformers.dynamic_module_utils.get_imports``, intended to be
    patched in while a Florence-2 model is loaded. For any other file it
    defers to the real ``get_imports`` unchanged.

    Args:
        filename: Path of the Python module whose imports are being collected.

    Returns:
        The list produced by ``get_imports(filename)``. When the file is
        ``modeling_florence2.py`` or ``configuration_florence2.py``, every
        ``"flash_attn"`` entry is removed from it.
    """
    imports = get_imports(filename)
    # Match on the exact file name rather than a "/"-prefixed suffix: a
    # hard-coded "/" never matches Windows (backslash) paths, and a suffix
    # without a separator would also match names like "xconfiguration_florence2.py".
    basename = os.path.basename(os.fspath(filename))
    if basename not in ("modeling_florence2.py", "configuration_florence2.py"):
        return imports
    return [name for name in imports if name != "flash_attn"]
def load_model_without_flash_attn(model_loader):
    """Invoke *model_loader* with the flash_attn import workaround active.

    While the loader runs, ``transformers.dynamic_module_utils.get_imports``
    is replaced by ``fixed_get_imports``. The original function is restored
    when the call finishes, whether it returns or raises.

    Args:
        model_loader: Zero-argument callable that performs the actual load,
            e.g. a ``lambda`` wrapping a ``from_pretrained`` call.

    Returns:
        Whatever ``model_loader()`` returns.
    """
    import_patcher = patch(
        "transformers.dynamic_module_utils.get_imports",
        fixed_get_imports,
    )
    with import_patcher:
        loaded = model_loader()
    return loaded