Spaces:
Sleeping
Sleeping
feat: use bfloat16
Browse files
app.py
CHANGED
|
@@ -16,6 +16,7 @@ from src.plot_utils import export_mask
|
|
| 16 |
@spaces.GPU()
|
| 17 |
def predict(model_choice, annotations: Dict[str, Any]):
|
| 18 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 19 |
sam2_model = load_model(
|
| 20 |
variant=model_choice,
|
| 21 |
ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
|
|
|
|
| 16 |
@spaces.GPU()
|
| 17 |
def predict(model_choice, annotations: Dict[str, Any]):
|
| 18 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
| 20 |
sam2_model = load_model(
|
| 21 |
variant=model_choice,
|
| 22 |
ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
|