murcherful commited on
Commit
0bf21ce
·
1 Parent(s): 5b81799
P3-SAM/demo/auto_mask.py CHANGED
@@ -1253,7 +1253,8 @@ class AutoMask:
1253
  self.model = P3SAM()
1254
  self.model.load_state_dict(ckpt_path)
1255
  self.model.eval()
1256
- self.model_parallel = torch.nn.DataParallel(self.model)
 
1257
  self.model.cuda()
1258
  self.model_parallel.cuda()
1259
  self.point_num = point_num
 
1253
  self.model = P3SAM()
1254
  self.model.load_state_dict(ckpt_path)
1255
  self.model.eval()
1256
+ # self.model_parallel = torch.nn.DataParallel(self.model)
1257
+ self.model_parallel = self.model
1258
  self.model.cuda()
1259
  self.model_parallel.cuda()
1260
  self.point_num = point_num
P3-SAM/demo/auto_mask_no_postprocess.py CHANGED
@@ -760,7 +760,8 @@ class AutoMask:
760
  self.model = P3SAM()
761
  self.model.load_state_dict(ckpt_path)
762
  self.model.eval()
763
- self.model_parallel = torch.nn.DataParallel(self.model)
 
764
  self.model.cuda()
765
  self.model_parallel.cuda()
766
  self.point_num = point_num
 
760
  self.model = P3SAM()
761
  self.model.load_state_dict(ckpt_path)
762
  self.model.eval()
763
+ # self.model_parallel = torch.nn.DataParallel(self.model)
764
+ self.model_parallel = self.model
765
  self.model.cuda()
766
  self.model_parallel.cuda()
767
  self.point_num = point_num
XPart/partgen/bbox_estimator/auto_mask_api.py CHANGED
@@ -1366,7 +1366,8 @@ class AutoMask:
1366
  state_dict=torch.load(ckpt_path, weights_only=False, map_location="cpu")["state_dict"]
1367
  )
1368
  self.model.eval()
1369
- self.model_parallel = torch.nn.DataParallel(self.model)
 
1370
  self.model.cuda()
1371
  self.model_parallel.cuda()
1372
  self.point_num = point_num
 
1366
  state_dict=torch.load(ckpt_path, weights_only=False, map_location="cpu")["state_dict"]
1367
  )
1368
  self.model.eval()
1369
+ # self.model_parallel = torch.nn.DataParallel(self.model)
1370
+ self.model_parallel = self.model
1371
  self.model.cuda()
1372
  self.model_parallel.cuda()
1373
  self.point_num = point_num
app.py CHANGED
@@ -8,9 +8,9 @@ from pathlib import Path
8
  import torch
9
  import pytorch_lightning as pl
10
  import spaces
11
- import multiprocessing
12
- multiprocessing.set_start_method('spawn')
13
- print('using spawm')
14
 
15
 
16
  sys.path.append('P3-SAM')
 
8
  import torch
9
  import pytorch_lightning as pl
10
  import spaces
11
+ import torch.multiprocessing as mp
12
+ mp.set_start_method('spawn')
13
+ print('using torch spawm')
14
 
15
 
16
  sys.path.append('P3-SAM')