import datetime

import pytorch_lightning as pl
from pytorch_lightning import loggers

from src import config


def _get_wandb_logger(trainer_config: config.TrainerConfig) -> loggers.WandbLogger:
    """Build a W&B logger whose run name encodes the model name and start time."""
    # Format the timestamp explicitly so the run name contains no spaces or colons.
    name = f"{config.MODEL_NAME}-{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"
    if trainer_config.debug:
        name = "debug-" + name
    return loggers.WandbLogger(
        entity=config.WANDB_ENTITY,
        save_dir=config.WANDB_LOG_PATH,
        project=config.MODEL_NAME,
        name=name,
        config=trainer_config._model_config.to_dict(),
    )


def get_trainer(trainer_config: config.TrainerConfig) -> pl.Trainer:
    """Build a Trainer; debug mode runs a single epoch on a handful of batches."""
    return pl.Trainer(
        max_epochs=1 if trainer_config.debug else trainer_config.epochs,
        logger=_get_wandb_logger(trainer_config),
        log_every_n_steps=trainer_config.log_every_n_steps,
        gradient_clip_val=1.0,  # clip gradient norm at 1.0 for stability
        limit_train_batches=5 if trainer_config.debug else 1.0,
        limit_val_batches=5 if trainer_config.debug else 1.0,
        accelerator="auto",  # auto-select GPU/TPU/MPS/CPU
    )
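

# Usage sketch (hypothetical): `build_model` and `build_dataloaders` are
# illustrative names assumed for this example, not defined in this module.
#
#     trainer_config = config.TrainerConfig(...)
#     model = build_model(trainer_config)
#     train_loader, val_loader = build_dataloaders(trainer_config)
#     trainer = get_trainer(trainer_config)
#     trainer.fit(model, train_loader, val_loader)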