from einops import rearrange, repeat
import torch
import torch.nn as nn

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.nn_blocks import (
    MCG,
    MAB,
    MHB,
    SequenceDecoderLSTM,
    SequenceDecoderMLP,
    SequenceEncoderLSTM,
    SequenceEncoderMLP,
    SequenceEncoderMaskedLSTM,
)

class DecoderNN(nn.Module):
    """Decoder neural network that decodes input tensors into a single output tensor.

    It contains an interaction layer that (re-)computes the interactions between the agents
    in the scene. This implies that a given latent sample for one agent also affects the
    predictions of the other agents.

    Args:
        params: dataclass defining the necessary parameters
    """
    def __init__(
        self,
        params: CVAEParams,
    ) -> None:
        super().__init__()
        self.dt = params.dt
        self.state_dim = params.state_dim
        self.dynamic_state_dim = params.dynamic_state_dim
        self.hidden_dim = params.hidden_dim
        self.num_steps_future = params.num_steps_future
        self.latent_dim = params.latent_dim

        if params.sequence_encoder_type == "MLP":
            self._agent_encoder_past = SequenceEncoderMLP(
                params.state_dim,
                params.hidden_dim,
                params.num_hidden_layers,
                params.num_steps,
                params.is_mlp_residual,
            )
        elif params.sequence_encoder_type == "LSTM":
            self._agent_encoder_past = SequenceEncoderLSTM(
                params.state_dim, params.hidden_dim
            )
        elif params.sequence_encoder_type == "maskedLSTM":
            self._agent_encoder_past = SequenceEncoderMaskedLSTM(
                params.state_dim, params.hidden_dim
            )
        else:
            raise RuntimeError(
                f"Got sequence encoder type {params.sequence_encoder_type} "
                "but expected one of: 'MLP', 'LSTM', 'maskedLSTM'"
            )
        self._combine_z_past = nn.Linear(
            params.hidden_dim + params.latent_dim, params.hidden_dim
        )

        if params.interaction_type == "Attention" or params.interaction_type == "MAB":
            self._interaction = MAB(
                params.hidden_dim, params.num_attention_heads, params.num_blocks
            )
        elif (
            params.interaction_type == "ContextGating"
            or params.interaction_type == "MCG"
        ):
            self._interaction = MCG(
                params.hidden_dim,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
            self._interaction = MHB(
                params.hidden_dim,
                params.num_attention_heads,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        else:
            # No interaction layer: fall back to the identity so the call site
            # can invoke self._interaction unconditionally.
            self._interaction = lambda x, *args, **kwargs: x
        if params.sequence_decoder_type == "MLP":
            self._decoder = SequenceDecoderMLP(
                params.hidden_dim,
                params.num_hidden_layers,
                params.num_steps_future,
                params.is_mlp_residual,
            )
        elif params.sequence_decoder_type in ("LSTM", "maskedLSTM"):
            # There is no masked variant of the LSTM decoder; "maskedLSTM"
            # falls back to the regular LSTM decoder.
            self._decoder = SequenceDecoderLSTM(params.hidden_dim)
        else:
            raise RuntimeError(
                f"Got sequence decoder type {params.sequence_decoder_type} "
                "but expected one of: 'MLP', 'LSTM', 'maskedLSTM'"
            )
    def forward(
        self,
        z_samples: torch.Tensor,
        mask_z: torch.Tensor,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, state_dim)

        Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
            mask_z: (batch_size, num_agents) tensor of bool mask
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask

        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, state_dim) output tensor
        """
        encoded_x = self._agent_encoder_past(x, mask_x)
        squeeze_output_sample_dim = False
        if z_samples.ndim == 3:
            batch_size, num_agents, latent_dim = z_samples.shape
            num_samples = 1
            z_samples = rearrange(z_samples, "b a l -> b a () l")
            squeeze_output_sample_dim = True
        else:
            batch_size, num_agents, num_samples, latent_dim = z_samples.shape

        # Fold the sample dimension into the batch dimension so all samples are
        # decoded in one pass; repeat the conditioning tensors to match.
        mask_z = repeat(mask_z, "b a -> (b s) a", s=num_samples)
        mask_map = repeat(mask_map, "b o -> (b s) o", s=num_samples)
        encoded_x = repeat(encoded_x, "b a l -> (b s) a l", s=num_samples)
        encoded_absolute = repeat(
            encoded_absolute, "b a l -> (b s) a l", s=num_samples
        )
        encoded_map = repeat(encoded_map, "b o l -> (b s) o l", s=num_samples)
        z_samples = rearrange(z_samples, "b a s l -> (b s) a l")

        h = self._combine_z_past(torch.cat([z_samples, encoded_x], dim=-1))
        h = self._interaction(h, mask_z, encoded_absolute, encoded_map, mask_map)
        h = self._decoder(h, self.num_steps_future)

        if not squeeze_output_sample_dim:
            h = rearrange(h, "(b s) a t l -> b a s t l", b=batch_size, s=num_samples)
        return h
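
# A minimal sketch (not part of the original module) of the sample-folding
# trick used in DecoderNN.forward above: the sample dimension is folded into
# the batch dimension with einops so all samples are decoded in parallel, and
# per-batch masks are repeated to match. Tensor sizes here are arbitrary.
def _demo_fold_sample_dim() -> None:
    batch_size, num_agents, num_samples, latent_dim = 2, 3, 4, 8
    z = torch.randn(batch_size, num_agents, num_samples, latent_dim)
    mask = torch.ones(batch_size, num_agents, dtype=torch.bool)
    # Fold samples into the batch: (b, a, s, l) -> (b*s, a, l).
    z_flat = rearrange(z, "b a s l -> (b s) a l")
    # Each sample must see the same mask, so the mask is repeated.
    mask_flat = repeat(mask, "b a -> (b s) a", s=num_samples)
    assert z_flat.shape == (batch_size * num_samples, num_agents, latent_dim)
    assert mask_flat.shape == (batch_size * num_samples, num_agents)
    # Unfolding restores the per-sample layout after decoding.
    z_back = rearrange(z_flat, "(b s) a l -> b a s l", b=batch_size)
    assert torch.equal(z_back, z)
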
class CVAEAccelerationDecoder(nn.Module):
    """Decoder architecture for a conditional variational autoencoder that predicts
    accelerations and integrates them into a state sequence.

    Args:
        model: decoder neural network that transforms input tensors to an output sequence
    """

    def __init__(
        self,
        model: nn.Module,
    ) -> None:
        super().__init__()
        self._model = model
        # Two outputs per future step: acceleration in x and y.
        self._output_layer = nn.Linear(model.hidden_dim, 2)
    def forward(
        self,
        z_samples: torch.Tensor,
        mask_z: torch.Tensor,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        offset: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim).

        It first predicts accelerations that are doubly integrated to produce the output
        state sequence with positions, angles, and velocities: (x, y, theta, vx, vy),
        (x, y, vx, vy), or (x, y) depending on dynamic_state_dim.

        Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
            mask_z: (batch_size, num_agents) tensor of bool mask
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            offset: (batch_size, num_agents, state_dim) tensor of offset states; its velocity
                components, when present, define the initial velocity for the integration

        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim) output
            tensor. The sample dimension does not exist if z_samples is a 3D tensor.
        """
        h = self._model(
            z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
        )
        h = self._output_layer(h)

        dt = self._model.dt
        initial_position = x[..., -1:, :2].clone()
        # If the offset has 5 features it should be (x, y, angle, vx, vy)
        if offset.shape[-1] == 5:
            initial_velocity = offset[..., 3:5].clone().unsqueeze(-2)
        # Else, if it has 4 features it should be (x, y, vx, vy)
        elif offset.shape[-1] == 4:
            initial_velocity = offset[..., 2:4].clone().unsqueeze(-2)
        elif x.shape[-1] == 5:
            initial_velocity = x[..., -1:, 3:5].clone()
        elif x.shape[-1] == 4:
            initial_velocity = x[..., -1:, 2:4].clone()
        else:
            # Position-only states: estimate the velocity by finite differences.
            initial_velocity = (x[..., -1:, :] - x[..., -2:-1, :]) / dt

        output = torch.zeros(
            (*h.shape[:-1], self._model.dynamic_state_dim), device=h.device
        )
        # There might be a sample dimension in the output tensor; if so, adapt
        # the shapes of the initial position and velocity.
        if output.ndim == 5:
            initial_position = initial_position.unsqueeze(-3)
            initial_velocity = initial_velocity.unsqueeze(-3)

        if self._model.dynamic_state_dim == 5:
            output[..., 3:5] = h.cumsum(-2) * dt
            output[..., :2] = (output[..., 3:5].clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
            output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
        elif self._model.dynamic_state_dim == 4:
            output[..., 2:4] = h.cumsum(-2) * dt
            output[..., :2] = (output[..., 2:4].clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
        else:
            velocity = h.cumsum(-2) * dt
            output = (velocity.clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
        return output
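
# A minimal sketch (not part of the original module) of the double integration
# in CVAEAccelerationDecoder.forward above: predicted accelerations are
# cumulatively summed into velocities, then velocities (plus the initial
# velocity) are cumulatively summed into positions. Values here are arbitrary.
def _demo_double_integration() -> None:
    dt = 0.1
    acceleration = torch.randn(1, 1, 10, 2)  # (batch, agents, num_steps_future, 2)
    initial_velocity = torch.zeros(1, 1, 1, 2)
    initial_position = torch.zeros(1, 1, 1, 2)
    # First integration: acceleration -> change in velocity over time.
    velocity = acceleration.cumsum(-2) * dt
    # Second integration: velocity -> position, offset by the initial state.
    position = (velocity + initial_velocity).cumsum(-2) * dt + initial_position
    assert position.shape == acceleration.shape
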
class CVAEParametrizedDecoder(nn.Module):
    """Decoder architecture for a conditional variational autoencoder that predicts
    a parametrized polynomial path and a speed profile along it.

    Args:
        model: decoder neural network that transforms input tensors to an output sequence
    """

    def __init__(
        self,
        model: nn.Module,
    ) -> None:
        super().__init__()
        self._model = model
        self._order = 3
        # Output: 2 * order polynomial coefficients (for x and y) plus one
        # speed value per future step.
        self._output_layer = nn.Linear(
            model.hidden_dim * model.num_steps_future,
            2 * self._order + model.num_steps_future,
        )
    def polynomial(self, x: torch.Tensor, params: torch.Tensor):
        """Polynomial function without constant term: sum_{k=1..order} a_k * x^k.

        Takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future)
        and a parameter tensor of shape (batch_size, num_agents, (n_samples), 2 * self._order),
        and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future, 2).
        """
        h = x.clone()
        squeeze = False
        if h.ndim == 3:
            h = h.unsqueeze(2)
            params = params.unsqueeze(2)
            squeeze = True
        # Cumulative product along the order axis turns repeats of x into the
        # powers (x, x^2, ..., x^order).
        h = repeat(
            h,
            "batch agents samples sequence -> batch agents samples sequence two order",
            order=self._order,
            two=2,
        ).cumprod(-1)
        h = h * params.view(*params.shape[:-1], 1, 2, self._order)
        h = h.sum(-1)
        if squeeze:
            h = h.squeeze(2)
        return h
    def dpolynomial(self, x: torch.Tensor, params: torch.Tensor):
        """Derivative of the polynomial function: sum_{k=1..order} k * a_k * x^(k-1).

        Takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future)
        and a parameter tensor of shape (batch_size, num_agents, (n_samples), 2 * self._order),
        and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future, 2).
        """
        h = x.clone()
        squeeze = False
        if h.ndim == 3:
            h = h.unsqueeze(2)
            params = params.unsqueeze(2)
            squeeze = True
        h = repeat(
            h,
            "batch agents samples sequence -> batch agents samples sequence two order",
            order=self._order - 1,
            two=2,
        )
        # Powers (1, x, ..., x^(order-1)) of the derivative terms.
        h = torch.cat((torch.ones_like(h[..., :1]), h.cumprod(-1)), -1)
        h = h * params.view(*params.shape[:-1], 1, 2, self._order)
        # Each coefficient a_k is scaled by its power k = 1..order; using
        # arange(self._order) here would drop the a_1 term of the derivative.
        h = (
            h
            * torch.arange(1, self._order + 1)
            .view(*([1] * params.ndim), -1)
            .to(x.device)
        )
        h = h.sum(-1)
        if squeeze:
            h = h.squeeze(2)
        return h
    def forward(
        self,
        z_samples: torch.Tensor,
        mask_z: torch.Tensor,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        offset: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim).

        It first predicts polynomial path parameters and a speed profile; the speed is
        cumulatively summed into a traveled distance, which is mapped through the polynomial
        to produce the output state sequence with positions, angles, and velocities:
        (x, y, theta, vx, vy), (x, y, vx, vy), or (x, y) depending on dynamic_state_dim.

        Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
            mask_z: (batch_size, num_agents) tensor of bool mask
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            offset: (batch_size, num_agents, state_dim) tensor of offset states (not used by this decoder)

        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim) output
            tensor. The sample dimension does not exist if z_samples is a 3D tensor.
        """
        squeeze_output_sample_dim = z_samples.ndim == 3
        batch_size = z_samples.shape[0]
        h = self._model(
            z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
        )
        if squeeze_output_sample_dim:
            h = rearrange(
                h, "batch agents sequence features -> batch agents (sequence features)"
            )
        else:
            h = rearrange(
                h,
                "(batch samples) agents sequence features -> batch agents samples (sequence features)",
                batch=batch_size,
            )
        h = self._output_layer(h)

        output = torch.zeros(
            (
                *h.shape[:-1],
                self._model.num_steps_future,
                self._model.dynamic_state_dim,
            ),
            device=h.device,
        )
        params = h[..., : 2 * self._order]
        # Non-negative speed along the path at each future step.
        dldt = torch.relu(h[..., 2 * self._order :])
        # Traveled distance: cumulative sum of the speed over time, which is
        # the last dimension here.
        distance = dldt.cumsum(-1)
        output[..., :2] = self.polynomial(distance, params)
        if self._model.dynamic_state_dim == 5:
            # Chain rule: velocity = dp/dl * dl/dt; unsqueeze the speed so it
            # broadcasts over the (x, y) components.
            output[..., 3:5] = dldt.unsqueeze(-1) * self.dpolynomial(distance, params)
            output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
        elif self._model.dynamic_state_dim == 4:
            output[..., 2:4] = dldt.unsqueeze(-1) * self.dpolynomial(distance, params)
        return output
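
# A minimal sketch (not part of the original module) of the cumprod trick used
# in CVAEParametrizedDecoder.polynomial above: repeating the input along an
# "order" axis and taking a cumulative product yields the powers
# (x, x^2, ..., x^order), which are then weighted by the coefficients and
# summed. Values here are arbitrary.
def _demo_polynomial_powers() -> None:
    order = 3
    x = torch.tensor([[2.0]])  # a single scalar input, batched
    powers = repeat(x, "n s -> n s order", order=order).cumprod(-1)
    assert torch.allclose(powers, torch.tensor([[[2.0, 4.0, 8.0]]]))  # x, x^2, x^3
    coefficients = torch.tensor([[[1.0, 0.5, 0.25]]])
    # p(x) = 1.0 * x + 0.5 * x^2 + 0.25 * x^3 = 2 + 2 + 2 = 6
    value = (powers * coefficients).sum(-1)
    assert torch.allclose(value, torch.tensor([[6.0]]))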