from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

from mmcv import Config
import numpy as np
import torch

from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost
from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
from risk_biased.utils.cost import BaseCostTorch
from risk_biased.utils.planner_utils import (
    AbstractState,
    to_state,
    evaluate_risk,
    get_interaction_cost,
)
from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator


@dataclass
class CrossEntropySolverParams:
    """Dataclass for Cross Entropy Solver Parameters

    Args:
        num_control_samples: number of Monte Carlo samples for the control input
        num_elite: number of elite samples
        iter_max: maximum number of iterations
        smoothing_factor: smoothing factor in (0, 1) used to update the mean and the std of the
            control input distribution for the next iteration. If 0, the updated distribution is
            independent of the previous iteration. If 1, the updated distribution is the same as
            the previous iteration.
        mean_warm_start: internally saves control_input_mean of the last iteration of the current
            solve, so that control_input_mean will be warm-started in the next solve
        dt: duration of one time step
    """

    num_control_samples: int
    num_elite: int
    iter_max: int
    smoothing_factor: float
    mean_warm_start: bool
    dt: float

    @staticmethod
    def from_config(cfg: Config):
        return CrossEntropySolverParams(
            cfg.num_control_samples,
            cfg.num_elite,
            cfg.iter_max,
            cfg.smoothing_factor,
            cfg.mean_warm_start,
            cfg.dt,
        )
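
# A minimal construction sketch (not part of the original file); the values
# below are illustrative assumptions, not defaults shipped with the repository:
#
#     params = CrossEntropySolverParams(
#         num_control_samples=256,
#         num_elite=16,
#         iter_max=10,
#         smoothing_factor=0.2,
#         mean_warm_start=True,
#         dt=0.1,
#     )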


class CrossEntropySolver:
    """Cross Entropy Solver for MPC Planner

    Args:
        params: CrossEntropySolverParams object
        dynamics_model: dynamics model for control
        control_input_mean: (num_agents, num_steps_future, control_dim) tensor of control input mean
        control_input_std: (num_agents, num_steps_future, control_dim) tensor of control input std
        tracking_cost_function: deterministic tracking cost that does not involve the ado
        interaction_cost_function: interaction cost function between the ego and the (stochastic) ado
        risk_estimator (optional): Monte Carlo risk estimator for risk computation. If None,
            the risk-neutral expectation is used for the selection of elites. Defaults to None.
    """

    def __init__(
        self,
        params: CrossEntropySolverParams,
        dynamics_model: PositionVelocityDoubleIntegrator,
        control_input_mean: torch.Tensor,
        control_input_std: torch.Tensor,
        tracking_cost_function: TrackingCost,
        interaction_cost_function: BaseCostTorch,
        risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
    ) -> None:
        self.params = params
        self.control_input_mean_init = control_input_mean.detach().clone()
        self.control_input_std_init = control_input_std.detach().clone()
        assert (
            self.control_input_mean_init.shape == self.control_input_std_init.shape
        ), "control input mean and std must have the same size"
        assert (
            self.control_input_mean_init.shape[-1] == dynamics_model.control_dim
        ), f"control dimension must be {dynamics_model.control_dim}"
        self.dynamics_model = dynamics_model
        self.tracking_cost = tracking_cost_function
        self.interaction_cost = interaction_cost_function
        self.risk_estimator = risk_estimator

        self._iter_current = None
        self._control_input_mean = None
        self._control_input_std = None
        self._latest_ado_position_future_samples = None
        self.reset()

    def reset(self) -> None:
        """Resets the solver's internal state"""
        self._iter_current = 0
        self._control_input_mean = self.control_input_mean_init.clone()
        self._control_input_std = self.control_input_std_init.clone()
        self._latest_ado_position_future_samples = None

    def step(
        self,
        ego_state_history: AbstractState,
        ego_state_target_trajectory: AbstractState,
        ado_state_future_samples: AbstractState,
        weights: torch.Tensor,
        verbose: bool = False,
        risk_level: float = 0.0,
    ) -> Dict:
        """Performs one iteration step of the Cross Entropy Method

        Args:
            ego_state_history: (num_agents, num_steps) ego state history
            ego_state_target_trajectory: (num_agents, num_steps_future) ego target
                state trajectory
            ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
                predicted ado trajectory samples
            weights: (num_prediction_samples, num_agents) prediction sample weights
            verbose (optional): Print progress. Defaults to False.
            risk_level (optional): a risk-level float for the solver. If 0.0, the risk-neutral
                expectation is used for the selection of elites. Defaults to 0.0.

        Returns:
            Dictionary containing information about this solver step.
        """
        self._iter_current += 1
        ego_control_input = torch.normal(
            self._control_input_mean.expand(
                self.params.num_control_samples, -1, -1, -1
            ),
            self._control_input_std.expand(self.params.num_control_samples, -1, -1, -1),
        )
        if verbose:
            print(f"**Cross Entropy Iteration {self._iter_current}")
            print(
                f"****Drawing ego's control input samples of {ego_control_input.shape}"
            )
        ego_state_current = ego_state_history[..., -1]
        ego_state_future = self.dynamics_model.simulate(
            ego_state_current, ego_control_input
        )
        if verbose:
            print("****Simulating ego's future state trajectory")

        # state starts with x, y, angle, vx, vy
        tracking_cost = self.tracking_cost(
            ego_state_future.position,
            ego_state_target_trajectory.position,
            ego_state_target_trajectory.velocity,
        )
        if verbose:
            print(
                f"****Computing tracking cost of {tracking_cost.shape} for the control input samples"
            )

        # state starts with x, y
        interaction_cost = get_interaction_cost(
            ego_state_future,
            ado_state_future_samples,
            self.interaction_cost,
        )
        if verbose:
            print(
                f"****Computing interaction cost of {interaction_cost.shape} for the control input samples"
            )

        interaction_risk = evaluate_risk(
            risk_level,
            interaction_cost,
            weights.permute(1, 0).unsqueeze(0).expand_as(interaction_cost),
            self.risk_estimator,
        )
        total_risk = interaction_risk + tracking_cost
        elite_ego_control_input, elite_total_risk = self._get_elites(
            ego_control_input, total_risk
        )
        if verbose:
            print(f"****Selecting {self.params.num_elite} elite samples")
            print(f"****Elite Total_Risk Information: {elite_total_risk}")

        info = dict(
            iteration=self._iter_current,
            control_input_mean=self._control_input_mean.detach().cpu().numpy().copy(),
            control_input_std=self._control_input_std.detach().cpu().numpy().copy(),
            ego_state_future=ego_state_future.get_states(5)
            .detach()
            .cpu()
            .numpy()
            .copy(),
            ado_state_future_samples=ado_state_future_samples.get_states(5)
            .detach()
            .cpu()
            .numpy()
            .copy(),
            sample_weights=weights.detach().cpu().numpy().copy(),
            tracking_cost=tracking_cost.detach().cpu().numpy().copy(),
            interaction_cost=interaction_cost.detach().cpu().numpy().copy(),
            total_risk=total_risk.detach().cpu().numpy().copy(),
        )
        self._update_control_distribution(elite_ego_control_input)
        if verbose:
            print("****Updating ego's control distribution")
        return info
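
    # Shape walk-through for one step() call, inferred from the indexing above
    # (a sketch for orientation, not part of the original file):
    #   ego_control_input: (num_control_samples, num_agents, num_steps_future, control_dim)
    #   tracking_cost:     (num_control_samples, num_agents)
    #   interaction_cost:  (num_control_samples, num_agents, num_prediction_samples)
    #   interaction_risk:  (num_control_samples, num_agents), with the prediction
    #                      sample dimension reduced by evaluate_risk
    #   total_risk:        (num_control_samples, num_agents), ranked by _get_elites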

    def solve(
        self,
        predictor: LitTrajectoryPredictor,
        ego_state_history: AbstractState,
        ego_state_target_trajectory: AbstractState,
        ado_state_history: AbstractState,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
        num_prediction_samples: int = 1,
        verbose: bool = False,
        risk_level: float = 0.0,
        resample_prediction: bool = False,
        risk_in_predictor: bool = False,
    ) -> List[Dict]:
        """Performs Cross Entropy optimization of the ego's control input

        Args:
            predictor: LitTrajectoryPredictor object
            ego_state_history: (num_agents, num_steps, state_dim) ego state history
            ego_state_target_trajectory: (num_agents, num_steps_future, state_dim) ego target
                state trajectory
            ado_state_history: (num_agents, num_steps, state_dim) ado state history
            normalizer: function that takes in an unnormalized trajectory and outputs the
                normalized trajectory and the offset, in this order
            num_prediction_samples (optional): number of prediction samples. Defaults to 1.
            verbose (optional): Print progress. Defaults to False.
            risk_level (optional): a risk-level float for the entire prediction-planning pipeline.
                If 0.0, risk-neutral prediction and planning are used. Defaults to 0.0.
            resample_prediction (optional): If True, the prediction is re-sampled in each
                cross-entropy iteration. Defaults to False.
            risk_in_predictor (optional): If True, risk-biased prediction is used and the solver
                becomes risk-neutral. If False, risk-neutral prediction is used and the solver
                becomes risk-sensitive. Defaults to False.

        Returns:
            List of dictionaries, each containing information about the corresponding solver step.
        """
        if risk_level == 0.0:
            risk_level_planner, risk_level_predictor = 0.0, 0.0
        else:
            if risk_in_predictor:
                risk_level_planner, risk_level_predictor = 0.0, risk_level
            else:
                risk_level_planner, risk_level_predictor = risk_level, 0.0

        self.reset()
        infos = []
        ego_state_future = self.dynamics_model.simulate(
            ego_state_history[..., -1],
            self.control_sequence,
        )
        for iter in range(self.params.iter_max):
            assert iter == self._iter_current
            if resample_prediction or self._iter_current == 0:
                ado_state_future_samples, weights = self.sample_prediction(
                    predictor,
                    ado_state_history,
                    normalizer,
                    ego_state_history,
                    ego_state_future,
                    num_prediction_samples,
                    risk_level_predictor,
                )
                self._latest_ado_position_future_samples = ado_state_future_samples
            info = self.step(
                ego_state_history,
                ego_state_target_trajectory,
                ado_state_future_samples,
                weights,
                verbose=verbose,
                risk_level=risk_level_planner,
            )
            infos.append(info)
        if self.params.mean_warm_start:
            self.control_input_mean_init[:, :-1] = (
                self._control_input_mean[:, 1:].detach().clone()
            )
        return infos
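
    # End-to-end usage sketch (an assumption, not part of the original file);
    # `predictor`, `normalizer`, `ego_history`, `ego_target`, and `ado_history`
    # stand in for objects built elsewhere in the pipeline:
    #
    #     solver = CrossEntropySolver(
    #         params, dynamics_model, control_input_mean, control_input_std,
    #         tracking_cost_function, interaction_cost_function,
    #     )
    #     infos = solver.solve(
    #         predictor, ego_history, ego_target, ado_history, normalizer,
    #         num_prediction_samples=32, risk_level=0.95,
    #     )
    #     planned_controls = solver.control_sequence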

    @property
    def control_sequence(self) -> torch.Tensor:
        """Returns the planned control sequence, which is a detached copy of the control input
        mean tensor

        Returns:
            (num_agents, num_steps_future, control_dim) control sequence tensor
        """
        return self._control_input_mean.detach().clone()

    @staticmethod
    def sample_prediction(
        predictor: LitTrajectoryPredictor,
        ado_state_history: AbstractState,
        normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
        ego_state_history: AbstractState,
        ego_state_future: AbstractState,
        num_prediction_samples: int = 1,
        risk_level: float = 0.0,
    ) -> Tuple[AbstractState, torch.Tensor]:
        """Samples a prediction from the predictor given the history, the normalizer, and the
        desired risk-level

        Args:
            predictor: LitTrajectoryPredictor object
            ado_state_history: (num_agents, num_steps, state_dim) ado state history
            normalizer: function that takes in an unnormalized trajectory and outputs the
                normalized trajectory and the offset, in this order
            ego_state_history: (num_agents, num_steps, state_dim) ego state history
            ego_state_future: (num_agents, num_steps_future, state_dim) ego state future
            num_prediction_samples (optional): number of prediction samples. Defaults to 1.
            risk_level (optional): a risk-level float for the predictor. If 0.0, a risk-neutral
                prediction is sampled. Defaults to 0.0.

        Returns:
            state samples of shape (num_prediction_samples, num_agents, num_steps_future)
            probability weights of the samples of shape (num_prediction_samples, num_agents)
        """
        ado_position_history_normalized, offset = normalizer(
            ado_state_history.get_states(predictor.dynamic_state_dim)
            .unsqueeze(0)
            .expand(num_prediction_samples, -1, -1, -1)
        )
        x = ado_position_history_normalized.clone()
        mask_x = torch.ones_like(x[..., 0])
        map = torch.empty(num_prediction_samples, 0, 0, 2, device=x.device)
        mask_map = torch.empty(num_prediction_samples, 0, 0, device=x.device)
        batch = (
            x,
            mask_x,
            map,
            mask_map,
            offset,
            ego_state_history.get_states(predictor.dynamic_state_dim)
            .unsqueeze(0)
            .expand(num_prediction_samples, -1, -1, -1),
            ego_state_future.get_states(predictor.dynamic_state_dim)
            .unsqueeze(0)
            .expand(num_prediction_samples, -1, -1, -1),
        )
        ado_position_future_samples, weights = predictor.predict_step(
            batch,
            0,
            risk_level=risk_level,
            return_weights=True,
        )
        ado_position_future_samples = ado_position_future_samples.detach().cpu()
        weights = weights.detach().cpu()
        return to_state(ado_position_future_samples, predictor.dt), weights

    def fetch_latest_prediction(self):
        """Returns the ado prediction samples from the latest solver iteration, or None if no
        prediction has been sampled yet"""
        return self._latest_ado_position_future_samples

    def _get_elites(
        self, control_input: torch.Tensor, risk: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Selects elite control inputs based on the corresponding risk (the lower the better)

        Args:
            control_input: (num_control_samples, num_agents, num_steps_future, control_dim) control samples
            risk: (num_control_samples, num_agents) risk tensor

        Returns:
            elite_control_input: (num_elite, num_agents, num_steps_future, control_dim) elite control
            elite_risk: (num_elite, num_agents) elite risk
        """
        num_control_samples = self.params.num_control_samples
        assert (
            control_input.shape[0] == num_control_samples
        ), f"size of control_input tensor must be {num_control_samples} at dimension 0"
        assert (
            risk.shape[0] == num_control_samples
        ), f"size of risk tensor must be {num_control_samples} at dimension 0"
        _, sorted_risk_indices = torch.sort(risk, dim=0)
        elite_control_input = control_input[
            sorted_risk_indices[: self.params.num_elite], np.arange(risk.shape[1])
        ]
        elite_risk = risk[
            sorted_risk_indices[: self.params.num_elite], np.arange(risk.shape[1])
        ]
        return elite_control_input, elite_risk
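
    # Indexing note (a sketch, not part of the original file): with
    # sorted_risk_indices[: num_elite] of shape (num_elite, num_agents) and
    # np.arange(num_agents) broadcast against it, the advanced index gathers
    # the num_elite lowest-risk control samples independently per agent, i.e.
    #     elite_control_input[k, a] == control_input[sorted_risk_indices[k, a], a]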

    def _update_control_distribution(self, elite_control_input: torch.Tensor) -> None:
        """Updates the control input distribution using the elites

        Args:
            elite_control_input: (num_elite, num_agents, num_steps_future, control_dim) elite control
        """
        num_elite, smoothing_factor = (
            self.params.num_elite,
            self.params.smoothing_factor,
        )
        assert (
            elite_control_input.shape[0] == num_elite
        ), f"size of elite_control_input tensor must be {num_elite} at dimension 0"
        elite_control_input_mean = elite_control_input.mean(dim=0, keepdim=False)
        if num_elite < 2:
            # std is undefined for a single sample; fall back to zeros
            elite_control_input_std = torch.zeros_like(elite_control_input_mean)
        else:
            elite_control_input_std = elite_control_input.std(dim=0, keepdim=False)
        self._control_input_mean = (
            1.0 - smoothing_factor
        ) * elite_control_input_mean + smoothing_factor * self._control_input_mean
        self._control_input_std = (
            1.0 - smoothing_factor
        ) * elite_control_input_std + smoothing_factor * self._control_input_std
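

if __name__ == "__main__":
    # Standalone sketch (not part of the original file): illustrates the elite
    # selection and the smoothed distribution update on toy tensors, without
    # constructing a full solver. All sizes below are illustrative assumptions.
    num_samples, num_agents, num_steps_future, control_dim = 8, 3, 5, 2
    num_elite, smoothing_factor = 4, 0.2
    control_input = torch.randn(num_samples, num_agents, num_steps_future, control_dim)
    risk = torch.randn(num_samples, num_agents)

    # Per-agent elite selection, mirroring _get_elites.
    _, sorted_risk_indices = torch.sort(risk, dim=0)
    elites = control_input[sorted_risk_indices[:num_elite], np.arange(num_agents)]
    assert elites.shape == (num_elite, num_agents, num_steps_future, control_dim)

    # Smoothed update, mirroring _update_control_distribution: a convex
    # combination of the elite statistics and the previous parameters.
    old_mean = torch.zeros(num_agents, num_steps_future, control_dim)
    new_mean = (
        1.0 - smoothing_factor
    ) * elites.mean(dim=0) + smoothing_factor * old_mean
    print(new_mean.shape)  # torch.Size([3, 5, 2])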