# Hugging Face Space status: Runtime error
"""Download the pretrained risk-biased predictor from the Hugging Face Hub and load it.

Fetches the training config and the last checkpoint from the
``jmercat/risk_biased_model`` repository, builds the predictor from that
config, and restores its weights. This runs at import time; the resulting
model is bound to the module-level name ``predictor``.
"""

from huggingface_hub import hf_hub_download

# NOTE(review): mmcv>=2.0 moved Config to mmengine.Config; this import only
# works with mmcv<2. Pin the version or switch the import.
from mmcv import Config
import torch

from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders

REPO_ID = "jmercat/risk_biased_model"

# hf_hub_download replaces the deprecated cached_download/hf_hub_url pair,
# which recent huggingface_hub releases have removed. It keeps the original
# file name in the local cache, so the config path still ends in ".py" as
# Config.fromfile requires. No force_filename is needed (the old one was also
# misspelled as "learing_config.py").
config_file = hf_hub_download(repo_id=REPO_ID, filename="learning_config.py")
ckpt_file = hf_hub_download(repo_id=REPO_ID, filename="last.ckpt")

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code, so only load checkpoints from a trusted repository.
# On torch>=2.6, weights_only defaults to True, which may reject a checkpoint
# that stores non-tensor objects. Confirm against the installed torch.
# map_location="cpu" lets the checkpoint load on machines without a GPU.
ckpt = torch.load(ckpt_file, map_location="cpu")

cfg = Config.fromfile(config_file)
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
predictor = load_weights(predictor, ckpt)
print("Model loaded")