Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from typing import Callable, Tuple | |
| import torch | |
| from mmcv import Config | |
| from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor | |
| from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator | |
| from risk_biased.mpc_planner.planner_cost import TrackingCostParams | |
| from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams | |
| from risk_biased.mpc_planner.planner_cost import TrackingCost | |
| from risk_biased.utils.cost import TTCCostTorch, TTCCostParams | |
| from risk_biased.utils.planner_utils import AbstractState, to_state | |
| from risk_biased.utils.risk import get_risk_estimator | |
@dataclass
class MPCPlannerParams:
    """Dataclass for MPC-Planner Parameters

    Args:
        dt: discrete time interval in seconds that is used for planning
        num_steps: number of time steps for which history of ego's and the other actor's
            trajectories are stored
        num_steps_future: number of time steps into the future for which ego's and the other
            actor's trajectories are considered
        acceleration_std_x_m_s2: Acceleration noise standard deviation (m/s^2) in x-direction that
            is used to initialize the Cross Entropy solver
        acceleration_std_y_m_s2: Acceleration noise standard deviation (m/s^2) in y-direction that
            is used to initialize the Cross Entropy solver
        risk_estimator_params: parameters for the Monte Carlo risk estimator used in the planner
            for ego's control optimization
        solver_params: parameters for the CrossEntropySolver
        tracking_cost_params: parameters for the TrackingCost
        ttc_cost_params: parameters for the TTCCost (i.e., collision cost between ego and the other
            actor)
    """

    dt: float
    num_steps: int
    num_steps_future: int
    acceleration_std_x_m_s2: float
    acceleration_std_y_m_s2: float
    risk_estimator_params: dict
    solver_params: CrossEntropySolverParams
    tracking_cost_params: TrackingCostParams
    ttc_cost_params: TTCCostParams

    @staticmethod
    def from_config(cfg: Config) -> "MPCPlannerParams":
        """Builds the planner parameters from a configuration object.

        The scalar fields and the risk estimator settings are read directly from attributes of
        cfg (note that the risk estimator settings come from cfg.risk_estimator). The solver,
        tracking-cost and TTC-cost parameters are each built by the from_config of their own
        parameter class, called on the same cfg.

        Args:
            cfg: configuration object exposing the attributes read below

        Returns:
            A populated MPCPlannerParams
        """
        return MPCPlannerParams(
            dt=cfg.dt,
            num_steps=cfg.num_steps,
            num_steps_future=cfg.num_steps_future,
            acceleration_std_x_m_s2=cfg.acceleration_std_x_m_s2,
            acceleration_std_y_m_s2=cfg.acceleration_std_y_m_s2,
            risk_estimator_params=cfg.risk_estimator,
            solver_params=CrossEntropySolverParams.from_config(cfg),
            tracking_cost_params=TrackingCostParams.from_config(cfg),
            ttc_cost_params=TTCCostParams.from_config(cfg),
        )
class MPCPlanner:
    """MPC Planner with a Cross Entropy solver

    The planner keeps a bounded history (at most params.num_steps entries) of ego and ado
    states, builds a constant-velocity target trajectory for the ego, and asks the
    CrossEntropySolver for a control sequence that is then rolled out through the dynamics model.

    Args:
        params: MPCPlannerParams object
        predictor: LitTrajectoryPredictor object
        normalizer: function that takes in an unnormalized trajectory and that outputs the
            normalized trajectory and the offset in this order
    """

    def __init__(
        self,
        params: MPCPlannerParams,
        predictor: LitTrajectoryPredictor,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        self.params = params
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)

        # Initial sampling distribution of the solver: zero-mean controls over the planning
        # horizon, with the configured per-axis acceleration standard deviations broadcast
        # to the same shape as the mean.
        self.control_input_mean_init = torch.zeros(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        self.control_input_std_init = torch.Tensor(
            [
                params.acceleration_std_x_m_s2,
                params.acceleration_std_y_m_s2,
            ]
        ).expand_as(self.control_input_mean_init)

        self.solver = CrossEntropySolver(
            params=params.solver_params,
            dynamics_model=self.dynamics_model,
            control_input_mean=self.control_input_mean_init,
            control_input_std=self.control_input_std_init,
            tracking_cost_function=TrackingCost(params.tracking_cost_params),
            interaction_cost_function=TTCCostTorch(params.ttc_cost_params),
            risk_estimator=get_risk_estimator(params.risk_estimator_params),
        )
        self.predictor = predictor
        self.normalizer = normalizer

        # Mutable planner state; reset() restores exactly these values.
        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None
        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def replan(
        self,
        current_ado_state: AbstractState,
        current_ego_state: AbstractState,
        target_velocity: torch.Tensor,
        num_prediction_samples: int = 1,
        risk_level: float = 0.0,
        resample_prediction: bool = False,
        risk_in_predictor: bool = False,
    ) -> None:
        """Performs re-planning given the current_ado_state, current_ego_state, and
        target_velocity. Updates ego_state_planned_trajectory. Note that all the information given
        to the solver.solve(...) is expressed in the ego-centric frame (see
        _map_to_ego_centric_frame).

        The solver is only run once the ado history reports at least params.num_steps steps;
        before that, the planned trajectory is rolled out from whatever control sequence the
        solver currently holds.

        Args:
            current_ado_state: ado state
            current_ego_state: ego state
            target_velocity: ((1), 2) tensor
            num_prediction_samples (optional): number of prediction samples. Defaults to 1.
            risk_level (optional): a risk-level float for the entire prediction-planning pipeline.
                If 0.0, risk-neutral prediction and planning are used. Defaults to 0.0.
            resample_prediction (optional): If True, prediction is re-sampled in each
                cross-entropy iteration. Defaults to False.
            risk_in_predictor (optional): If True, risk-biased prediction is used and the solver
                becomes risk-neutral. If False, risk-neutral prediction is used and the solver
                becomes risk-sensitive. Defaults to False.
        """
        self._update_ado_state_history(current_ado_state)
        self._update_ego_state_history(current_ego_state)
        self._update_ego_state_target_trajectory(current_ego_state, target_velocity)

        # NOTE(review): the source this was recovered from had lost its indentation; only the
        # solve(...) call is treated as guarded by this condition, because
        # get_planned_next_ego_state() expects a plan to exist after any replan() call.
        if self.ado_state_history.shape[-1] >= self.params.num_steps:
            self.solver.solve(
                self.predictor,
                self._map_to_ego_centric_frame(self.ego_state_history),
                self._map_to_ego_centric_frame(self._ego_state_target_trajectory),
                self._map_to_ego_centric_frame(self.ado_state_history),
                self.normalizer,
                num_prediction_samples=num_prediction_samples,
                risk_level=risk_level,
                resample_prediction=resample_prediction,
                risk_in_predictor=risk_in_predictor,
            )

        # Roll the solver's control sequence out from the latest ego state (in the ego frame),
        # then express the resulting plan in the world frame.
        ego_state_planned_trajectory_in_ego_frame = self.dynamics_model.simulate(
            self._map_to_ego_centric_frame(self.ego_state_history[..., -1]),
            self.solver.control_sequence,
        )
        self._ego_state_planned_trajectory = self._map_to_world_frame(
            ego_state_planned_trajectory_in_ego_frame
        )

        # The solver may have no prediction to report; mirror that as None.
        latest_ado_position_future_samples_in_ego_frame = (
            self.solver.fetch_latest_prediction()
        )
        if latest_ado_position_future_samples_in_ego_frame is not None:
            self._latest_ado_position_future_samples = self._map_to_world_frame(
                latest_ado_position_future_samples_in_ego_frame
            )
        else:
            self._latest_ado_position_future_samples = None

    def get_planned_next_ego_state(self) -> AbstractState:
        """Returns the next ego state according to the ego_state_planned_trajectory

        Returns:
            Planned state (index 0 along the last axis of the planned trajectory)

        Raises:
            AssertionError: if replan(...) has not produced a planned trajectory yet
        """
        assert (
            self._ego_state_planned_trajectory is not None
        ), "call self.replan(...) first"
        return self._ego_state_planned_trajectory[..., 0]

    def reset(self) -> None:
        """Resets the planner's internal state. This will fully reset the solver's internal state,
        including solver.control_input_mean_init and solver.control_input_std_init."""
        # Hand the solver fresh copies so it cannot modify the planner's own initial tensors.
        self.solver.control_input_mean_init = (
            self.control_input_mean_init.detach().clone()
        )
        self.solver.control_input_std_init = (
            self.control_input_std_init.detach().clone()
        )
        self.solver.reset()

        self._ego_state_history = []
        self._ego_state_target_trajectory = None
        self._ego_state_planned_trajectory = None
        self._ado_state_history = []
        self._latest_ado_position_future_samples = None

    def fetch_latest_prediction(self) -> torch.Tensor:
        """Returns the ado future samples stored by the last replan(...) call.

        Returns:
            The stored samples (mapped to the world frame by replan), or None if nothing is
            stored, e.g. before the first replan(...) or after reset().
        """
        return self._latest_ado_position_future_samples

    @property
    def ego_state_history(self) -> AbstractState:
        """Returns ego_state_history as a single stacked state

        Returns:
            result of to_state(...) applied to the stored ego states stacked along dim=-2

        Raises:
            AssertionError: if no ego state has been recorded yet
        """
        assert len(self._ego_state_history) > 0
        return to_state(
            torch.stack(
                [ego_state.get_states(4) for ego_state in self._ego_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    @property
    def ado_state_history(self) -> AbstractState:
        """Returns ado_state_history as a single stacked state

        Returns:
            result of to_state(...) applied to the stored ado states stacked along dim=-2

        Raises:
            AssertionError: if no ado state has been recorded yet
        """
        assert len(self._ado_state_history) > 0
        return to_state(
            torch.stack(
                [ado_state.get_states(4) for ado_state in self._ado_state_history],
                dim=-2,
            ),
            self.params.dt,
        )

    def _update_ego_state_history(self, current_ego_state: AbstractState) -> None:
        """Appends current_ego_state to ego_state_history, dropping the oldest entry first when
        the history already holds params.num_steps entries

        Args:
            current_ego_state: current state of the ego
        """
        if len(self._ego_state_history) >= self.params.num_steps:
            self._ego_state_history = self._ego_state_history[1:]
        self._ego_state_history.append(current_ego_state)
        assert len(self._ego_state_history) <= self.params.num_steps

    def _update_ado_state_history(self, current_ado_state: AbstractState) -> None:
        """Appends current_ado_state to ado_state_history, dropping the oldest entry first when
        the history already holds params.num_steps entries

        Args:
            current_ado_state: states of the current non-ego vehicles
        """
        if len(self._ado_state_history) >= self.params.num_steps:
            self._ado_state_history = self._ado_state_history[1:]
        self._ado_state_history.append(current_ado_state)
        assert len(self._ado_state_history) <= self.params.num_steps

    def _update_ego_state_target_trajectory(
        self, current_ego_state: AbstractState, target_velocity: torch.Tensor
    ) -> None:
        """Updates ego_state_target_trajectory based on the current_ego_state and the
        target_velocity

        The target moves at constant target_velocity: starting from the current ego position,
        each of the params.num_steps_future target positions is the previous one plus
        dt * target_velocity. The current position itself is not part of the target.

        Args:
            current_ego_state: state
            target_velocity: (1, 2) tensor
        """
        target_displacement = self.params.dt * target_velocity
        target_position_list = [current_ego_state.position]
        for _ in range(self.params.num_steps_future):
            target_position_list.append(target_position_list[-1] + target_displacement)
        # Drop the seed (current position); only future positions are targets.
        target_position_list = target_position_list[1:]
        target_position = torch.cat(target_position_list, dim=-2)
        target_state = to_state(
            torch.cat(
                (target_position, target_velocity.expand_as(target_position)), dim=-1
            ),
            self.params.dt,
        )
        self._ego_state_target_trajectory = target_state

    def _map_to_ego_centric_frame(
        self, trajectory_in_world_frame: AbstractState
    ) -> AbstractState:
        """Maps trajectory expressed in the world frame to the ego-centric frame

        The ego-centric frame is anchored at the entry of ego_state_history selected by index -1
        along its second-to-last axis (the axis the history is stacked along, so the most
        recently appended ego state): the origin is that state's position and the x-direction is
        parallel to that state's velocity.
        NOTE(review): earlier documentation described this anchor as the "initial" ego state;
        the code uses index -1 - confirm which one is intended.

        Args:
            trajectory_in_world_frame: sequence of states

        Returns:
            trajectory mapped to the ego-centric frame
        """
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        ego_vel_init = self.ego_state_history.velocity[..., -1, :]
        # Heading of the anchor state, taken from the direction of its velocity vector.
        ego_rot_init = torch.atan2(ego_vel_init[..., 1], ego_vel_init[..., 0])
        trajectory_in_ego_frame = trajectory_in_world_frame.translate(
            -ego_pos_init
        ).rotate(-ego_rot_init)
        return trajectory_in_ego_frame

    def _map_to_world_frame(
        self, trajectory_in_ego_frame: AbstractState
    ) -> AbstractState:
        """Maps trajectory expressed in the ego-centric frame to the world frame

        Applies the reverse of _map_to_ego_centric_frame: rotate first, then translate by the
        anchor position.

        Args:
            trajectory_in_ego_frame: trajectory expressed in the ego-centric frame, i.e. the
                frame anchored at the entry of ego_state_history selected by index -1 along its
                second-to-last axis

        Returns:
            trajectory mapped to the world frame
        """
        ego_pos_init = self.ego_state_history.position[..., -1, :]
        # NOTE(review): the rotation here comes from the state's `angle`, whereas
        # _map_to_ego_centric_frame derives it from atan2 of the velocity; the two mappings are
        # exact inverses only if these agree - confirm against the state implementation.
        ego_rot_init = self.ego_state_history.angle[..., -1, :]
        trajectory_in_world_frame = trajectory_in_ego_frame.rotate(
            ego_rot_init
        ).translate(ego_pos_init)
        return trajectory_in_world_frame