Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
bc42d0a
1
Parent(s):
0bf21ce
fix req
Browse files- P3-SAM/demo/auto_mask.py +4 -3
- app.py +33 -28
- requirements.txt +3 -1
P3-SAM/demo/auto_mask.py
CHANGED
|
@@ -1253,10 +1253,11 @@ class AutoMask:
|
|
| 1253 |
self.model = P3SAM()
|
| 1254 |
self.model.load_state_dict(ckpt_path)
|
| 1255 |
self.model.eval()
|
| 1256 |
-
# self.model_parallel = torch.nn.DataParallel(self.model)
|
| 1257 |
self.model_parallel = self.model
|
| 1258 |
-
self.model.cuda()
|
| 1259 |
-
self.
|
|
|
|
|
|
|
| 1260 |
self.point_num = point_num
|
| 1261 |
self.prompt_num = prompt_num
|
| 1262 |
self.threshold = threshold
|
|
|
|
| 1253 |
self.model = P3SAM()
|
| 1254 |
self.model.load_state_dict(ckpt_path)
|
| 1255 |
self.model.eval()
|
|
|
|
| 1256 |
self.model_parallel = self.model
|
| 1257 |
+
# self.model.cuda()
|
| 1258 |
+
self.model.to('cuda')
|
| 1259 |
+
self.model_parallel.to('cuda')
|
| 1260 |
+
print('p3sam to cuda')
|
| 1261 |
self.point_num = point_num
|
| 1262 |
self.prompt_num = prompt_num
|
| 1263 |
self.threshold = threshold
|
app.py
CHANGED
|
@@ -8,39 +8,44 @@ from pathlib import Path
|
|
| 8 |
import torch
|
| 9 |
import pytorch_lightning as pl
|
| 10 |
import spaces
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
| 15 |
|
| 16 |
sys.path.append('P3-SAM')
|
| 17 |
from demo.auto_mask import AutoMask
|
| 18 |
-
from demo.auto_mask_no_postprocess import AutoMask as AutoMaskNoPostProcess
|
| 19 |
-
sys.path.append('XPart')
|
| 20 |
-
from partgen.partformer_pipeline import PartFormerPipeline
|
| 21 |
-
from partgen.utils.misc import get_config_from_file
|
|
|
|
| 22 |
|
| 23 |
automask = AutoMask()
|
| 24 |
-
automask_no_postprocess = AutoMaskNoPostProcess(automask_instance=automask)
|
| 25 |
-
|
| 26 |
-
def _load_pipeline():
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
_PIPELINE = _load_pipeline()
|
|
|
|
|
|
|
| 44 |
|
| 45 |
output_path = 'P3-SAM/results/gradio'
|
| 46 |
os.makedirs(output_path, exist_ok=True)
|
|
|
|
| 8 |
import torch
|
| 9 |
import pytorch_lightning as pl
|
| 10 |
import spaces
|
| 11 |
+
# from spaces import zero
|
| 12 |
+
# zero.startup()
|
| 13 |
+
# import torch.multiprocessing as mp
|
| 14 |
+
# mp.set_start_method('spawn')
|
| 15 |
+
# print('using torch spawm')
|
| 16 |
+
# print('zero gpu startup')
|
| 17 |
|
| 18 |
sys.path.append('P3-SAM')
|
| 19 |
from demo.auto_mask import AutoMask
|
| 20 |
+
# from demo.auto_mask_no_postprocess import AutoMask as AutoMaskNoPostProcess
|
| 21 |
+
# sys.path.append('XPart')
|
| 22 |
+
# from partgen.partformer_pipeline import PartFormerPipeline
|
| 23 |
+
# from partgen.utils.misc import get_config_from_file
|
| 24 |
+
print('no automask no postprocess')
|
| 25 |
|
| 26 |
automask = AutoMask()
|
| 27 |
+
# automask_no_postprocess = AutoMaskNoPostProcess(automask_instance=automask)
|
| 28 |
+
|
| 29 |
+
# def _load_pipeline():
|
| 30 |
+
# pl.seed_everything(2026, workers=True)
|
| 31 |
+
# cfg_path = str(Path(__file__).parent / "XPart/partgen/config" / "infer.yaml")
|
| 32 |
+
# config = get_config_from_file(cfg_path)
|
| 33 |
+
# assert hasattr(config, "ckpt") or hasattr(
|
| 34 |
+
# config, "ckpt_path"
|
| 35 |
+
# ), "ckpt or ckpt_path must be specified in config"
|
| 36 |
+
# pipeline = PartFormerPipeline.from_pretrained(
|
| 37 |
+
# config=config,
|
| 38 |
+
# verbose=True,
|
| 39 |
+
# ignore_keys=config.get("ignore_keys", []),
|
| 40 |
+
# )
|
| 41 |
+
|
| 42 |
+
# device = "cuda"
|
| 43 |
+
# pipeline.to(device=device, dtype=torch.float32)
|
| 44 |
+
# return pipeline
|
| 45 |
+
|
| 46 |
+
# _PIPELINE = _load_pipeline()
|
| 47 |
+
|
| 48 |
+
print('no xpart pipeline')
|
| 49 |
|
| 50 |
output_path = 'P3-SAM/results/gradio'
|
| 51 |
os.makedirs(output_path, exist_ok=True)
|
requirements.txt
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
# Build Tools
|
| 2 |
ninja==1.11.1.1
|
| 3 |
|
|
@@ -43,5 +45,5 @@ scikit-image
|
|
| 43 |
|
| 44 |
# sonata
|
| 45 |
spconv-cu126
|
| 46 |
-
torch-scatter -f https://data.pyg.org/whl/torch-2.
|
| 47 |
git+https://github.com/Dao-AILab/flash-attention.git
|
|
|
|
| 1 |
+
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu126
|
| 2 |
+
|
| 3 |
# Build Tools
|
| 4 |
ninja==1.11.1.1
|
| 5 |
|
|
|
|
| 45 |
|
| 46 |
# sonata
|
| 47 |
spconv-cu126
|
| 48 |
+
torch-scatter -f https://data.pyg.org/whl/torch-2.7.0+cu126.html
|
| 49 |
git+https://github.com/Dao-AILab/flash-attention.git
|