update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import pyrootutils | |
| root = pyrootutils.setup_root( | |
| search_from=__file__, | |
| indicator=[".project-root"], | |
| pythonpath=True, | |
| dotenv=True, | |
| ) | |
| # ------------------------------------------------------------------------------------ # | |
| # `pyrootutils.setup_root(...)` is an optional line at the top of each entry file | |
| # that helps to make the environment more robust and convenient | |
| # | |
| # the main advantages are: | |
| # - allows you to keep all entry files in "src/" without installing project as a package | |
| # - makes paths and scripts always work no matter where is your current work dir | |
| # - automatically loads environment variables from ".env" file if exists | |
| # | |
| # how it works: | |
| # - the line above recursively searches for either ".git" or "pyproject.toml" in present | |
| # and parent dirs, to determine the project root dir | |
| # - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from | |
| # any place without installing project as a package | |
| # - sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" | |
| # to make all paths always relative to the project root | |
| # - loads environment variables from ".env" file in root dir (if `dotenv=True`) | |
| # | |
| # you can remove `pyrootutils.setup_root(...)` if you: | |
| # 1. either install project as a package or move each entry file to the project root dir | |
| # 2. simply remove PROJECT_ROOT variable from paths in "configs/paths/default.yaml" | |
| # 3. always run entry files from the project root dir | |
| # | |
| # https://github.com/ashleve/pyrootutils | |
| # ------------------------------------------------------------------------------------ # | |
| import os.path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import hydra | |
| import pytorch_lightning as pl | |
| from omegaconf import DictConfig, OmegaConf | |
| from pie_datasets import DatasetDict | |
| from pie_modules.models import * # noqa: F403 | |
| from pie_modules.models import SimpleGenerativeModel | |
| from pie_modules.models.interface import RequiresTaskmoduleConfig | |
| from pie_modules.taskmodules import * # noqa: F403 | |
| from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE | |
| from pytorch_ie import PieDataModule, Pipeline | |
| from pytorch_ie.core import PyTorchIEModel, TaskModule | |
| from pytorch_ie.models import * # noqa: F403 | |
| from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses | |
| from pytorch_ie.taskmodules import * # noqa: F403 | |
| from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize | |
| from pytorch_lightning import Callback, Trainer | |
| from pytorch_lightning.loggers import Logger | |
| from src import utils | |
| from src.models import * # noqa: F403 | |
| from src.serializer.interface import DocumentSerializer | |
| from src.taskmodules import * # noqa: F403 | |
| log = utils.get_pylogger(__name__) | |
| def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]: | |
| """Safely retrieves value of the metric logged in LightningModule.""" | |
| if not metric_name: | |
| log.info("Metric name is None! Skipping metric value retrieval...") | |
| return None | |
| if metric_name not in metric_dict: | |
| raise Exception( | |
| f"Metric value not found! <metric_name={metric_name}>\n" | |
| "Make sure metric name logged in LightningModule is correct!\n" | |
| "Make sure `optimized_metric` name in `hparams_search` config is correct!" | |
| ) | |
| metric_value = metric_dict[metric_name].item() | |
| log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") | |
| return metric_value | |
| def flatten_nested_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]: | |
| """Flatten a nested dictionary. | |
| Args: | |
| d (Dict[str, Any]): The dictionary to flatten. | |
| parent_key (str): The parent key. | |
| sep (str): The separator. | |
| Returns: | |
| Dict[str, Any]: The flattened dictionary. | |
| """ | |
| items: List[Tuple[str, Any]] = [] | |
| for k, v in d.items(): | |
| new_key = f"{parent_key}{sep}{k}" if parent_key else k | |
| if isinstance(v, dict): | |
| items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) | |
| else: | |
| items.append((new_key, v)) | |
| return dict(items) | |
| def train(cfg: DictConfig) -> Tuple[dict, dict]: | |
| """Trains the model. Can additionally evaluate on a testset, using best weights obtained during | |
| training. | |
| This method is wrapped in optional @task_wrapper decorator which applies extra utilities | |
| before and after the call. | |
| Args: | |
| cfg (DictConfig): Configuration composed by Hydra. | |
| Returns: | |
| Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. | |
| """ | |
| # set seed for random number generators in pytorch, numpy and python.random | |
| if cfg.get("seed"): | |
| pl.seed_everything(cfg.seed, workers=True) | |
| # Init pytorch-ie taskmodule | |
| log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>") | |
| taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial") | |
| # Init pytorch-ie dataset | |
| log.info(f"Instantiating dataset <{cfg.dataset._target_}>") | |
| dataset: DatasetDict = hydra.utils.instantiate( | |
| cfg.dataset, | |
| _convert_="partial", | |
| ) | |
| # auto-convert the dataset if the taskmodule specifies a document type | |
| dataset = dataset.to_document_type(taskmodule, downcast=False) | |
| # Init pytorch-ie datamodule | |
| log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") | |
| datamodule: PieDataModule = hydra.utils.instantiate( | |
| cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial" | |
| ) | |
| # Use the train dataset split to prepare the taskmodule | |
| taskmodule.prepare(dataset[datamodule.train_split]) | |
| # Init the pytorch-ie model | |
| log.info(f"Instantiating model <{cfg.model._target_}>") | |
| # get additional model arguments | |
| additional_model_kwargs: Dict[str, Any] = {} | |
| model_cls = hydra.utils.get_class(cfg.model["_target_"]) | |
| # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE! | |
| # SEE EXAMPLES BELOW. | |
| if issubclass(model_cls, RequiresNumClasses): | |
| additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id) | |
| if issubclass(model_cls, RequiresModelNameOrPath): | |
| if "model_name_or_path" not in cfg.model: | |
| raise Exception( | |
| f"Please specify model_name_or_path in the model config for {model_cls.__name__}." | |
| ) | |
| if isinstance(taskmodule, ChangesTokenizerVocabSize): | |
| additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer) | |
| pooler_config = cfg["model"].get("pooler") | |
| if pooler_config is not None: | |
| if isinstance(pooler_config, str): | |
| pooler_config = {"type": pooler_config} | |
| pooler_config = dict(pooler_config) | |
| if pooler_config["type"] in ["start_tokens", "mention_pooling"]: | |
| # NOTE: This is very hacky, we should create a new interface class, e.g. RequiresPoolerNumIndices | |
| if hasattr(taskmodule, "argument_role2idx"): | |
| pooler_config["num_indices"] = len(taskmodule.argument_role2idx) | |
| else: | |
| pooler_config["num_indices"] = 1 | |
| elif pooler_config["type"] == "cls_token": | |
| pass | |
| else: | |
| raise Exception( | |
| f"unknown pooler type: {pooler_config['type']}. Please adjust the train.py script for that type." | |
| ) | |
| additional_model_kwargs["pooler"] = pooler_config | |
| if issubclass(model_cls, RequiresTaskmoduleConfig): | |
| additional_model_kwargs["taskmodule_config"] = taskmodule.config | |
| if model_cls == SimpleGenerativeModel: | |
| # There may be already some base_model_config entries in the model config. Also need to convert the | |
| # base_model_config to a dict, because it is a OmegaConf object which does not accept additional entries. | |
| base_model_config = ( | |
| dict(cfg.model.base_model_config) if "base_model_config" in cfg.model else {} | |
| ) | |
| if isinstance(taskmodule, PointerNetworkTaskModuleForEnd2EndRE): | |
| base_model_config.update( | |
| dict( | |
| bos_token_id=taskmodule.bos_id, | |
| eos_token_id=taskmodule.eos_id, | |
| pad_token_id=taskmodule.eos_id, | |
| target_token_ids=taskmodule.target_token_ids, | |
| embedding_weight_mapping=taskmodule.label_embedding_weight_mapping, | |
| ) | |
| ) | |
| additional_model_kwargs["base_model_config"] = base_model_config | |
| if issubclass(model_cls, SimpleSequenceClassificationModelWithInputTypeIds): # noqa: F405 | |
| # add the number of input type ids to the model: | |
| # 2 for B- and I-labels for each entity type, 1 for O labels, 1 for padding | |
| additional_model_kwargs["num_token_type_ids"] = len(taskmodule.entity_labels) * 2 + 1 + 1 | |
| # initialize the model | |
| model: PyTorchIEModel = hydra.utils.instantiate( | |
| cfg.model, _convert_="partial", **additional_model_kwargs | |
| ) | |
| log.info("Instantiating callbacks...") | |
| callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks") | |
| log.info("Instantiating loggers...") | |
| logger: List[Logger] = utils.instantiate_dict_entries(cfg, key="logger") | |
| log.info(f"Instantiating trainer <{cfg.trainer._target_}>") | |
| trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) | |
| object_dict = { | |
| "cfg": cfg, | |
| "dataset": dataset, | |
| "taskmodule": taskmodule, | |
| "model": model, | |
| "callbacks": callbacks, | |
| "logger": logger, | |
| "trainer": trainer, | |
| } | |
| if logger: | |
| log.info("Logging hyperparameters!") | |
| utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg) | |
| if cfg.paths.model_save_dir is not None: | |
| log.info(f"Save taskmodule to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]") | |
| taskmodule.save_pretrained( | |
| save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub | |
| ) | |
| else: | |
| log.warning("the taskmodule is not saved because no save_dir is specified") | |
| if cfg.get("train"): | |
| log.info("Starting training!") | |
| trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) | |
| train_metrics = trainer.callback_metrics | |
| best_ckpt_path = trainer.checkpoint_callback.best_model_path | |
| best_epoch = None | |
| if best_ckpt_path != "": | |
| log.info(f"Best ckpt path: {best_ckpt_path}") | |
| best_checkpoint_file = os.path.basename(best_ckpt_path) | |
| utils.log_hyperparameters( | |
| logger=logger, | |
| best_checkpoint=best_checkpoint_file, | |
| checkpoint_dir=trainer.checkpoint_callback.dirpath, | |
| ) | |
| # get epoch from best_checkpoint_file (e.g. "epoch_078.ckpt") | |
| try: | |
| best_epoch = int(os.path.splitext(best_checkpoint_file)[0].split("_")[-1]) | |
| except Exception as e: | |
| log.warning( | |
| f'Could not retrieve epoch from best checkpoint file name: "{e}". ' | |
| f"Expected format: " + '"epoch_{best_epoch}.ckpt"' | |
| ) | |
| if not cfg.trainer.get("fast_dev_run") or cfg.get("predict", False): | |
| if cfg.paths.model_save_dir is not None: | |
| if best_ckpt_path == "": | |
| log.warning("Best ckpt not found! Using current weights for saving...") | |
| else: | |
| model = type(model).load_from_checkpoint(best_ckpt_path) | |
| log.info(f"Save model to {cfg.paths.model_save_dir} [push_to_hub={cfg.push_to_hub}]") | |
| model.save_pretrained( | |
| save_directory=cfg.paths.model_save_dir, push_to_hub=cfg.push_to_hub | |
| ) | |
| else: | |
| log.warning("the model is not saved because no save_dir is specified") | |
| if cfg.get("validate"): | |
| log.info("Starting validation!") | |
| if best_ckpt_path == "": | |
| log.warning("Best ckpt not found! Using current weights for validation...") | |
| trainer.validate(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None) | |
| elif cfg.get("train"): | |
| log.warning( | |
| "Validation after training is skipped! That means, the finally reported validation scores are " | |
| "the values from the *last* checkpoint, not from the *best* checkpoint (which is saved)!" | |
| ) | |
| if cfg.get("test"): | |
| log.info("Starting testing!") | |
| if best_ckpt_path == "": | |
| log.warning("Best ckpt not found! Using current weights for testing...") | |
| trainer.test(model=model, datamodule=datamodule, ckpt_path=best_ckpt_path or None) | |
| test_metrics = trainer.callback_metrics | |
| test_metrics["best_epoch"] = best_epoch | |
| # merge train and test metrics | |
| metric_dict = {**train_metrics, **test_metrics} | |
| # add model_save_dir to the result so that it gets dumped to job_return_value.json | |
| # if we use hydra_callbacks.SaveJobReturnValueCallback | |
| if cfg.paths.get("model_save_dir") is not None: | |
| metric_dict["model_save_dir"] = cfg.paths.model_save_dir | |
| if cfg.get("predict"): | |
| # Init the inference pipeline | |
| pipeline: Optional[Pipeline] = None | |
| if cfg.get("pipeline") and cfg.pipeline.get("_target_"): | |
| log.info(f"Instantiating inference pipeline <{cfg.pipeline._target_}>") | |
| pipeline = hydra.utils.instantiate(cfg.pipeline, _convert_="partial") | |
| # Init the serializer | |
| serializer: Optional[DocumentSerializer] = None | |
| if cfg.get("serializer") and cfg.serializer.get("_target_"): | |
| log.info(f"Instantiating serializer <{cfg.serializer._target_}>") | |
| serializer = hydra.utils.instantiate(cfg.serializer, _convert_="partial") | |
| # predict and serialize | |
| predict_metrics: Dict[str, Any] = utils.predict_and_serialize( | |
| pipeline=pipeline, | |
| serializer=serializer, | |
| dataset=dataset[cfg.dataset_split], | |
| document_batch_size=cfg.get("document_batch_size", None), | |
| ) | |
| # flatten the predict_metrics dict | |
| predict_metrics_flat = flatten_nested_dict(predict_metrics, sep="/") | |
| metric_dict.update(predict_metrics_flat) | |
| if cfg.get("delete_model_dir"): | |
| import shutil | |
| log.info(f"Deleting model directory {cfg.paths.model_save_dir}") | |
| shutil.rmtree(cfg.paths.model_save_dir) | |
| return metric_dict, object_dict | |
| def main(cfg: DictConfig) -> Optional[float]: | |
| # train the model | |
| metric_dict, _ = train(cfg) | |
| # safely retrieve metric value for hydra-based hyperparameter optimization | |
| if cfg.get("optimized_metric") is not None: | |
| metric_value = get_metric_value( | |
| metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") | |
| ) | |
| # return optimized metric | |
| return metric_value | |
| else: | |
| return metric_dict | |
| if __name__ == "__main__": | |
| utils.replace_sys_args_with_values_from_files() | |
| utils.prepare_omegaconf() | |
| OmegaConf.register_new_resolver("eval", eval) | |
| main() | |