Spaces:
Runtime error
Runtime error
| +++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py | |
| # | |
| # Contact: ps-license@tuebingen.mpg.de | |
| +from pathlib import Path | |
| + | |
| import torch | |
| import trimesh | |
| import torch.nn as nn | |
| from .utils.mesh import winding_numbers | |
| + | |
| +def load_pkl(path): | |
| + with open(path, "rb") as fin: | |
| + return pickle.load(fin) | |
| + | |
| + | |
| +def save_pkl(obj, path): | |
| + with open(path, "wb") as fout: | |
| + pickle.dump(obj, fout) | |
| + | |
| + | |
| class BodySegment(nn.Module): | |
| def __init__(self, | |
| name, | |
| self.register_buffer('segment_faces', segment_faces) | |
| # create vector to select vertices form faces | |
| - tri_vidx = [] | |
| - for ii in range(faces.max().item()+1): | |
| - tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] | |
| + segments_folder = Path(segments_folder) | |
| + tri_vidx_path = segments_folder / "tri_vidx.pkl" | |
| + if not tri_vidx_path.is_file(): | |
| + tri_vidx = [] | |
| + for ii in range(faces.max().item()+1): | |
| + tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] | |
| + | |
| + save_pkl(tri_vidx, tri_vidx_path) | |
| + else: | |
| + tri_vidx = load_pkl(tri_vidx_path) | |
| + | |
| self.register_buffer('tri_vidx', torch.tensor(tri_vidx)) | |
| def create_band_faces(self): | |
| self.segmentation = {} | |
| for idx, name in enumerate(names): | |
| self.segmentation[name] = BodySegment(name, faces, segments_folder, | |
| - model_type).to('cuda') | |
| + model_type).to(device) | |
| def batch_has_self_isec_verts(self, vertices): | |
| """ | |
| +++ venv/lib/python3.10/site-packages/selfcontact/selfcontact.py | |
| test_segments=True, | |
| compute_hd=False, | |
| buffer_geodists=False, | |
| + device="cuda", | |
| ): | |
| super().__init__() | |
| if self.test_segments: | |
| sxseg = pickle.load(open(segments_bounds_path, 'rb')) | |
| self.segments = BatchBodySegment( | |
| - [x for x in sxseg.keys()], faces, segments_folder, self.model_type | |
| + [x for x in sxseg.keys()], faces, segments_folder, self.model_type, device=device, | |
| ) | |
| # load regressor to get high density mesh | |
| torch.tensor(hd_operator['values']), | |
| torch.Size(hd_operator['size'])) | |
| self.register_buffer('hd_operator', | |
| - torch.tensor(hd_operator).float()) | |
| + hd_operator.clone().detach().float()) | |
| with open(point_vert_corres_path, 'rb') as f: | |
| hd_geovec = pickle.load(f)['faces_vert_is_sampled_from'] | |
| # split because of memory into two chunks | |
| exterior = torch.zeros((bs, nv), device=vertices.device, | |
| dtype=torch.bool) | |
| - exterior[:, :5000] = winding_numbers(vertices[:,:5000,:], | |
| + exterior[:, :3000] = winding_numbers(vertices[:,:3000,:], | |
| triangles).le(0.99) | |
| - exterior[:, 5000:] = winding_numbers(vertices[:,5000:,:], | |
| + exterior[:, 3000:6000] = winding_numbers(vertices[:,3000:6000,:], | |
| + triangles).le(0.99) | |
| + exterior[:, 6000:9000] = winding_numbers(vertices[:,6000:9000,:], | |
| + triangles).le(0.99) | |
| + exterior[:, 9000:] = winding_numbers(vertices[:,9000:,:], | |
| triangles).le(0.99) | |
| # check if intersections happen within segments | |
| # split because of memory into two chunks | |
| exterior = torch.zeros((bs, np), device=points.device, | |
| dtype=torch.bool) | |
| - exterior[:, :6000] = winding_numbers(points[:,:6000,:], | |
| + exterior[:, :3000] = winding_numbers(points[:,:3000,:], | |
| + triangles).le(0.99) | |
| + exterior[:, 3000:6000] = winding_numbers(points[:,3000:6000,:], | |
| triangles).le(0.99) | |
| - exterior[:, 6000:] = winding_numbers(points[:,6000:,:], | |
| + exterior[:, 6000:9000] = winding_numbers(points[:,6000:9000,:], | |
| + triangles).le(0.99) | |
| + exterior[:, 9000:] = winding_numbers(points[:,9000:,:], | |
| triangles).le(0.99) | |
| return exterior | |
| return hd_v2v_mins, hd_exteriors, hd_points, hd_faces_in_contacts | |
| + def verts_in_contact(self, vertices, return_idx=False): | |
| + | |
| + # get pairwise distances of vertices | |
| + v2v = self.get_pairwise_dists(vertices, vertices, squared=True) | |
| + | |
| + # mask v2v with eucledean and geodesic dsitance | |
| + euclmask = v2v < self.euclthres**2 | |
| + mask = euclmask * self.geomask | |
| + | |
| + # find closes vertex in contact | |
| + in_contact = mask.sum(1) > 0 | |
| + | |
| + if return_idx: | |
| + in_contact = torch.where(in_contact) | |
| + | |
| + return in_contact | |
| + | |
| class SelfContactSmall(nn.Module): | |
| +++ venv/lib/python3.10/site-packages/selfcontact/utils/mesh.py | |
| if valid_vals > 0: | |
| loss = (mask * dists).sum() / valid_vals | |
| else: | |
| - loss = torch.Tensor([0]).cuda() | |
| + loss = mask.new_tensor([0]) | |
| return loss | |
| def batch_index_select(inp, dim, index): | |
| xx = torch.bmm(x, x.transpose(2, 1)) | |
| yy = torch.bmm(y, y.transpose(2, 1)) | |
| zz = torch.bmm(x, y.transpose(2, 1)) | |
| + use_cuda = x.device.type == "cuda" | |
| if use_cuda: | |
| dtype = torch.cuda.LongTensor | |
| else: | |