murcherful commited on
Commit
bc42d0a
·
1 Parent(s): 0bf21ce
Files changed (3) hide show
  1. P3-SAM/demo/auto_mask.py +4 -3
  2. app.py +33 -28
  3. requirements.txt +3 -1
P3-SAM/demo/auto_mask.py CHANGED
@@ -1253,10 +1253,11 @@ class AutoMask:
1253
  self.model = P3SAM()
1254
  self.model.load_state_dict(ckpt_path)
1255
  self.model.eval()
1256
- # self.model_parallel = torch.nn.DataParallel(self.model)
1257
  self.model_parallel = self.model
1258
- self.model.cuda()
1259
- self.model_parallel.cuda()
 
 
1260
  self.point_num = point_num
1261
  self.prompt_num = prompt_num
1262
  self.threshold = threshold
 
1253
  self.model = P3SAM()
1254
  self.model.load_state_dict(ckpt_path)
1255
  self.model.eval()
 
1256
  self.model_parallel = self.model
1257
+ # self.model.cuda()
1258
+ self.model.to('cuda')
1259
+ self.model_parallel.to('cuda')
1260
+ print('p3sam to cuda')
1261
  self.point_num = point_num
1262
  self.prompt_num = prompt_num
1263
  self.threshold = threshold
app.py CHANGED
@@ -8,39 +8,44 @@ from pathlib import Path
8
  import torch
9
  import pytorch_lightning as pl
10
  import spaces
11
- import torch.multiprocessing as mp
12
- mp.set_start_method('spawn')
13
- print('using torch spawm')
14
-
 
 
15
 
16
  sys.path.append('P3-SAM')
17
  from demo.auto_mask import AutoMask
18
- from demo.auto_mask_no_postprocess import AutoMask as AutoMaskNoPostProcess
19
- sys.path.append('XPart')
20
- from partgen.partformer_pipeline import PartFormerPipeline
21
- from partgen.utils.misc import get_config_from_file
 
22
 
23
  automask = AutoMask()
24
- automask_no_postprocess = AutoMaskNoPostProcess(automask_instance=automask)
25
-
26
- def _load_pipeline():
27
- pl.seed_everything(2026, workers=True)
28
- cfg_path = str(Path(__file__).parent / "XPart/partgen/config" / "infer.yaml")
29
- config = get_config_from_file(cfg_path)
30
- assert hasattr(config, "ckpt") or hasattr(
31
- config, "ckpt_path"
32
- ), "ckpt or ckpt_path must be specified in config"
33
- pipeline = PartFormerPipeline.from_pretrained(
34
- config=config,
35
- verbose=True,
36
- ignore_keys=config.get("ignore_keys", []),
37
- )
38
-
39
- device = "cuda"
40
- pipeline.to(device=device, dtype=torch.float32)
41
- return pipeline
42
-
43
- _PIPELINE = _load_pipeline()
 
 
44
 
45
  output_path = 'P3-SAM/results/gradio'
46
  os.makedirs(output_path, exist_ok=True)
 
8
  import torch
9
  import pytorch_lightning as pl
10
  import spaces
11
+ # from spaces import zero
12
+ # zero.startup()
13
+ # import torch.multiprocessing as mp
14
+ # mp.set_start_method('spawn')
15
+ # print('using torch spawm')
16
+ # print('zero gpu startup')
17
 
18
  sys.path.append('P3-SAM')
19
  from demo.auto_mask import AutoMask
20
+ # from demo.auto_mask_no_postprocess import AutoMask as AutoMaskNoPostProcess
21
+ # sys.path.append('XPart')
22
+ # from partgen.partformer_pipeline import PartFormerPipeline
23
+ # from partgen.utils.misc import get_config_from_file
24
+ print('no automask no postprocess')
25
 
26
  automask = AutoMask()
27
+ # automask_no_postprocess = AutoMaskNoPostProcess(automask_instance=automask)
28
+
29
+ # def _load_pipeline():
30
+ # pl.seed_everything(2026, workers=True)
31
+ # cfg_path = str(Path(__file__).parent / "XPart/partgen/config" / "infer.yaml")
32
+ # config = get_config_from_file(cfg_path)
33
+ # assert hasattr(config, "ckpt") or hasattr(
34
+ # config, "ckpt_path"
35
+ # ), "ckpt or ckpt_path must be specified in config"
36
+ # pipeline = PartFormerPipeline.from_pretrained(
37
+ # config=config,
38
+ # verbose=True,
39
+ # ignore_keys=config.get("ignore_keys", []),
40
+ # )
41
+
42
+ # device = "cuda"
43
+ # pipeline.to(device=device, dtype=torch.float32)
44
+ # return pipeline
45
+
46
+ # _PIPELINE = _load_pipeline()
47
+
48
+ print('no xpart pipeline')
49
 
50
  output_path = 'P3-SAM/results/gradio'
51
  os.makedirs(output_path, exist_ok=True)
requirements.txt CHANGED
@@ -1,3 +1,5 @@
 
 
1
  # Build Tools
2
  ninja==1.11.1.1
3
 
@@ -43,5 +45,5 @@ scikit-image
43
 
44
  # sonata
45
  spconv-cu126
46
- torch-scatter -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
47
  git+https://github.com/Dao-AILab/flash-attention.git
 
1
+ torch==2.7.0 --index-url https://download.pytorch.org/whl/cu126
2
+
3
  # Build Tools
4
  ninja==1.11.1.1
5
 
 
45
 
46
  # sonata
47
  spconv-cu126
48
+ torch-scatter -f https://data.pyg.org/whl/torch-2.7.0+cu126.html
49
  git+https://github.com/Dao-AILab/flash-attention.git