| from gluonts.dataset.common import Dataset | |
| from .models import ( | |
| AbstractPredictor, | |
| AutoGluonPredictor, | |
| AutoPyTorchPredictor, | |
| DeepARPredictor, | |
| TFTPredictor, | |
| AutoARIMAPredictor, | |
| AutoETSPredictor, | |
| AutoThetaPredictor, | |
| StatsEnsemblePredictor, | |
| ) | |
# Registry of supported predictors, keyed by lowercase model identifier.
# fit_predict_with_model lowercases the requested name before looking it up
# here, so every key must stay lowercase.
MODEL_NAME_TO_CLASS = dict(
    autogluon=AutoGluonPredictor,
    autopytorch=AutoPyTorchPredictor,
    deepar=DeepARPredictor,
    tft=TFTPredictor,
    autoarima=AutoARIMAPredictor,
    autoets=AutoETSPredictor,
    autotheta=AutoThetaPredictor,
    statsensemble=StatsEnsemblePredictor,
)
def fit_predict_with_model(
    model_name: str,
    dataset: Dataset,
    prediction_length: int,
    freq: str,
    seasonality: int,
    **model_kwargs,
):
    """Build the named predictor, fit it on ``dataset`` and return its output.

    Args:
        model_name: Name of a registered model (a key of
            ``MODEL_NAME_TO_CLASS``); matched case-insensitively.
        dataset: GluonTS ``Dataset`` passed to the model's ``fit_predict``.
        prediction_length: Forwarded to the model constructor.
        freq: Forwarded to the model constructor.
        seasonality: Forwarded to the model constructor.
        **model_kwargs: Extra keyword arguments forwarded unchanged to the
            model constructor.

    Returns:
        A ``(predictions, info)`` tuple. ``predictions`` is whatever the
        model's ``fit_predict`` returns; ``info`` is a dict whose
        ``"run_time"`` entry holds ``model.get_runtime()`` (units are defined
        by the predictor implementation — not visible here).

    Raises:
        KeyError: If ``model_name`` does not name a registered model. The
            message lists the accepted names.
    """
    try:
        model_class = MODEL_NAME_TO_CLASS[model_name.lower()]
    except KeyError:
        # Keep KeyError so existing callers that catch it still work, but
        # report the valid choices instead of only echoing the bad key.
        raise KeyError(
            f"Unknown model name {model_name!r}; expected one of "
            f"{sorted(MODEL_NAME_TO_CLASS)}"
        ) from None

    model: AbstractPredictor = model_class(
        prediction_length=prediction_length,
        freq=freq,
        seasonality=seasonality,
        **model_kwargs,
    )
    predictions = model.fit_predict(dataset)
    # Runtime is read after fit_predict has completed.
    info = {"run_time": model.get_runtime()}
    return predictions, info