Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2020 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: ps-license@tuebingen.mpg.de | |
| from typing import Optional | |
| from torch import Tensor | |
| import smplx | |
| from .base import Datastruct, dataclass, Transform | |
| from .rots2rfeats import Rots2Rfeats | |
| from .rots2joints import Rots2Joints | |
| from .joints2jfeats import Joints2Jfeats | |
class SMPLTransform(Transform):
    """Transform holding the SMPL rotation/joint feature converters.

    ``Datastruct`` builds an ``SMPLDatastruct`` wired to the three converters
    stored on this transform.
    """

    def __init__(self, rots2rfeats: Rots2Rfeats,
                 rots2joints: Rots2Joints,
                 joints2jfeats: Joints2Jfeats,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.rots2rfeats, self.rots2joints, self.joints2jfeats = (
            rots2rfeats, rots2joints, joints2jfeats)

    def Datastruct(self, **kwargs):
        """Return an ``SMPLDatastruct`` using this transform's converters.

        ``kwargs`` are forwarded as the remaining datastruct fields
        (e.g. ``features``).
        """
        converters = dict(_rots2rfeats=self.rots2rfeats,
                          _rots2joints=self.rots2joints,
                          _joints2jfeats=self.joints2jfeats)
        return SMPLDatastruct(transforms=self, **converters, **kwargs)

    def __repr__(self):
        return "SMPLTransform()"
class RotIdentityTransform(Transform):
    """Identity transform: data is kept as raw rotations and translations."""

    def __init__(self, **kwargs):
        # Stateless: keyword arguments are accepted only for interface
        # compatibility with the other transforms and are ignored.
        pass

    def Datastruct(self, **kwargs):
        """Wrap the given fields in a ``RotTransDatastruct``."""
        datastruct = RotTransDatastruct(**kwargs)
        return datastruct

    def __repr__(self):
        return "RotIdentityTransform()"
@dataclass
class RotTransDatastruct(Datastruct):
    """Datastruct holding raw rotations and translations.

    Must be a dataclass: ``RotIdentityTransform.Datastruct`` builds it from
    keyword arguments and ``__post_init__`` is the dataclass init hook.

    Fields:
        rots: rotation tensor; its first dimension is the length reported by
            ``len()`` (exact layout not checked here — TODO confirm).
        trans: translation tensor (shape not checked here — TODO confirm).
        transforms: the identity transform attached to this datastruct.
    """
    rots: Tensor
    trans: Tensor

    # One shared default instance; RotIdentityTransform stores no state.
    transforms: RotIdentityTransform = RotIdentityTransform()

    def __post_init__(self):
        # Names of the fields the Datastruct base treats as data
        # (presumably for device moves / detaching — verify against .base).
        self.datakeys = ["rots", "trans"]

    def __len__(self):
        """Return the size of the first dimension of ``rots``."""
        return len(self.rots)
@dataclass
class SMPLDatastruct(Datastruct):
    """Lazily converts between SMPL motion representations.

    Construct it with one starting representation: ``features`` (treated as
    rotation features) or any of the cached ``*_`` fields. The ``rots``,
    ``rfeats``, ``joints``, ``jfeats`` and ``vertices`` properties compute a
    missing representation on first access, using the converters supplied
    by ``SMPLTransform``. Each result is cached in the matching ``*_`` field.

    Must be a dataclass: ``SMPLTransform.Datastruct`` builds it from keyword
    arguments and ``__post_init__`` is the dataclass init hook.
    """
    transforms: SMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    _joints2jfeats: Joints2Jfeats

    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    jfeats_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        self.datakeys = ['features', 'rots_', 'rfeats_',
                         'joints_', 'jfeats_', 'vertices_']
        # Starting point: input features are the rotation features.
        if self.features is not None and self.rfeats_ is None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations, computed as ``Rots2Rfeats.inverse(rfeats)`` if not cached."""
        if self.rots_ is not None:
            return self.rots_
        assert self.rfeats_ is not None, \
            "SMPLDatastruct needs rots_ or rfeats_/features to compute rots"
        # Put the converter on the same device as its input before calling it.
        self._rots2rfeats.to(self.rfeats.device)
        self.rots_ = self._rots2rfeats.inverse(self.rfeats)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features, computed as ``Rots2Rfeats(rots)`` if not cached."""
        if self.rfeats_ is not None:
            return self.rfeats_
        assert self.rots_ is not None, \
            "SMPLDatastruct needs rfeats_/features or rots_ to compute rfeats"
        # NOTE(review): rots_ is annotated RotTransDatastruct but is used via
        # `.device` here; assumes the Datastruct base exposes `.device` — confirm.
        self._rots2rfeats.to(self.rots.device)
        self.rfeats_ = self._rots2rfeats(self.rots)
        return self.rfeats_

    @property
    def joints(self):
        """Joint positions, computed as ``Rots2Joints(rots)`` if not cached."""
        if self.joints_ is not None:
            return self.joints_
        self._rots2joints.to(self.rots.device)
        self.joints_ = self._rots2joints(self.rots)
        return self.joints_

    @property
    def jfeats(self):
        """Joint features, computed as ``Joints2Jfeats(joints)`` if not cached."""
        if self.jfeats_ is not None:
            return self.jfeats_
        self._joints2jfeats.to(self.joints.device)
        self.jfeats_ = self._joints2jfeats(self.joints)
        return self.jfeats_

    @property
    def vertices(self):
        """Mesh vertices, computed as ``Rots2Joints(rots, jointstype="vertices")``."""
        if self.vertices_ is not None:
            return self.vertices_
        self._rots2joints.to(self.rots.device)
        self.vertices_ = self._rots2joints(self.rots, jointstype="vertices")
        return self.vertices_

    def __len__(self):
        """Return ``len`` of the rotation features (computing them if needed)."""
        return len(self.rfeats)
def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'):
    '''Create an smplx body model stored under ``data/smpl_models``.

    The model file is read from
    ``data/smpl_models/{model_type}/{MODEL_TYPE}_{GENDER}.{ext}``.

    Args:
        model_type: smpl, smplx, smplh and others; refer to the smplx
            tutorial. Used as the sub-directory name, upper-cased as the file
            prefix, and forwarded to ``smplx.create``.
        gender: male, female or neutral (any case). A non-``str`` value is
            converted with ``.astype(str)`` first (presumably a numpy value
            loaded from a dataset — confirm against callers).
        batch_size: a positive integer, forwarded to ``smplx.create``.
        device: any device accepted by ``torch.nn.Module.to``
            ('cpu', 'cuda', 'cuda:1', a ``torch.device``, ...).
        ext: file extension of the gendered model files. Neutral models are
            always loaded from ``.npz``.

    Returns:
        The body model created by ``smplx.create``, moved to ``device``.
    '''
    mtype = model_type.upper()

    # Normalise gender to an upper-case str *before* comparing, so numpy
    # values and 'Neutral'/'NEUTRAL' take the same path as 'neutral'.
    if not isinstance(gender, str):
        gender = str(gender.astype(str))
    gender = gender.upper()
    if gender == 'NEUTRAL':
        # The neutral model is stored as .npz; `ext` selects the gendered files.
        ext = 'npz'

    body_model_path = f'data/smpl_models/{model_type}/{mtype}_{gender}.{ext}'
    # Pass the requested model type (previously the builtin `type` was passed).
    body_model = smplx.create(body_model_path, model_type=model_type,
                              gender=gender, ext=ext,
                              use_pca=False,
                              num_pca_comps=12,
                              create_global_orient=True,
                              create_body_pose=True,
                              create_betas=True,
                              create_left_hand_pose=True,
                              create_right_hand_pose=True,
                              create_expression=True,
                              create_jaw_pose=True,
                              create_leye_pose=True,
                              create_reye_pose=True,
                              create_transl=True,
                              batch_size=batch_size)
    # nn.Module.to returns the module itself; 'cpu' is a no-op for a freshly
    # created model, and 'cuda:N' / torch.device are now honoured as well.
    return body_model.to(device)