Spaces:
Paused
Paused
| # Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import importlib | |
| import inspect | |
| import logging | |
| import os | |
| import subprocess | |
| import sys | |
| from collections.abc import Iterable | |
| from dataclasses import dataclass, field | |
| from typing import Optional, Union | |
| import yaml | |
| from transformers import HfArgumentParser | |
| from transformers.hf_argparser import DataClass, DataClassType | |
| from transformers.utils import is_rich_available | |
# Module-level logger, named after this module per the standard `logging` convention.
logger = logging.getLogger(__name__)
@dataclass  # required: the `field(...)` defaults below are meaningless without the decorator
class ScriptArguments:
    """
    Arguments common to all scripts.

    Args:
        dataset_name (`str`):
            Dataset name.
        dataset_config (`str` or `None`, *optional*, defaults to `None`):
            Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function.
        dataset_train_split (`str`, *optional*, defaults to `"train"`):
            Dataset split to use for training.
        dataset_test_split (`str`, *optional*, defaults to `"test"`):
            Dataset split to use for evaluation.
        dataset_streaming (`bool`, *optional*, defaults to `False`):
            Whether to stream the dataset. If True, the dataset will be loaded in streaming mode.
        gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
            Whether to apply `use_reentrant` for gradient checkpointing.
        ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
            Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar
            type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
    """

    dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name."})
    dataset_config: Optional[str] = field(
        default=None,
        metadata={
            "help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` "
            "function."
        },
    )
    dataset_train_split: str = field(default="train", metadata={"help": "Dataset split to use for training."})
    dataset_test_split: str = field(default="test", metadata={"help": "Dataset split to use for evaluation."})
    dataset_streaming: bool = field(
        default=False,
        metadata={"help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode."},
    )
    gradient_checkpointing_use_reentrant: bool = field(
        default=False,
        metadata={"help": "Whether to apply `use_reentrant` for gradient checkpointing."},
    )
    ignore_bias_buffers: bool = field(
        default=False,
        metadata={
            "help": "Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid "
            "scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992."
        },
    )
def init_zero_verbose():
    """
    Perform zero verbose init - use this method on top of the CLI modules to make
    logging and warning output cleaner. Uses Rich if available, falls back otherwise.
    """
    import logging
    import warnings

    FORMAT = "%(message)s"

    # Prefer Rich's handler for prettier output; fall back to a plain stream handler.
    if is_rich_available():
        from rich.logging import RichHandler

        handler = RichHandler()
    else:
        handler = logging.StreamHandler()
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[handler], level=logging.ERROR)

    # Custom warning handler to redirect warnings to the logging system
    def warning_handler(message, category, filename, lineno, file=None, line=None):
        # Fix: report the actual origin of the warning instead of a "(unknown)" placeholder,
        # so warnings remain traceable to their source file.
        logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}")

    # Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well
    warnings.showwarning = warning_handler
class TrlParser(HfArgumentParser):
    """
    A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed
    configurations, while also supporting configuration file loading and environment variable management.

    Args:
        dataclass_types (`Union[DataClassType, Iterable[DataClassType]]` or `None`, *optional*, defaults to `None`):
            Dataclass types to use for argument parsing.
        **kwargs:
            Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor.

    Examples:

    ```yaml
    # config.yaml
    env:
        VAR1: value1
    arg1: 23
    ```

    ```python
    # main.py
    import os
    from dataclasses import dataclass
    from trl import TrlParser

    @dataclass
    class MyArguments:
        arg1: int
        arg2: str = "alpha"

    parser = TrlParser(dataclass_types=[MyArguments])
    training_args = parser.parse_args_and_config()
    print(training_args, os.environ.get("VAR1"))
    ```

    ```bash
    $ python main.py --config config.yaml
    (MyArguments(arg1=23, arg2='alpha'),) value1

    $ python main.py --arg1 5 --arg2 beta
    (MyArguments(arg1=5, arg2='beta'),) None
    ```
    """

    def __init__(
        self,
        dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None,
        **kwargs,
    ):
        # Make sure dataclass_types is an iterable
        if dataclass_types is None:
            dataclass_types = []
        elif not isinstance(dataclass_types, Iterable):
            dataclass_types = [dataclass_types]

        # "config" is reserved for the YAML config file path (consumed by `parse_args_and_config`),
        # so no dataclass may declare a field with that name.
        for dataclass_type in dataclass_types:
            if "config" in dataclass_type.__dataclass_fields__:
                raise ValueError(
                    f"Dataclass {dataclass_type.__name__} has a field named 'config'. This field is reserved for the "
                    f"config file path and should not be used in the dataclass."
                )

        super().__init__(dataclass_types=dataclass_types, **kwargs)

    def parse_args_and_config(
        self,
        args: Optional[Iterable[str]] = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]:
        """
        Parse command-line args and config file into instances of the specified dataclass types.

        This method wraps [`transformers.HfArgumentParser.parse_args_into_dataclasses`] and also parses the config file
        specified with the `--config` flag. The config file (in YAML format) provides argument values that replace the
        default values in the dataclasses. Command line arguments can override values set by the config file. The
        method also sets any environment variables specified in the `env` field of the config file.

        Args:
            args (`Iterable[str]` or `None`, *optional*, defaults to `None`):
                Arguments to parse; defaults to `sys.argv[1:]`.
            return_remaining_strings (`bool`, *optional*, defaults to `False`):
                Whether to also return the list of argument strings not consumed by the parser.
            fail_with_unknown_args (`bool`, *optional*, defaults to `True`):
                Whether to raise when the config file contains keys unknown to the parser (only checked when
                `return_remaining_strings` is `False`).
        """
        args = list(args) if args is not None else sys.argv[1:]
        if "--config" in args:
            # Get the config file path from the command line arguments
            config_index = args.index("--config")
            args.pop(config_index)  # remove the --config flag
            config_path = args.pop(config_index)  # get the path to the config file
            with open(config_path) as yaml_file:
                config = yaml.safe_load(yaml_file)

            # An empty YAML file parses to `None`; treat it as "no overrides" instead of crashing below.
            if config is None:
                config = {}
            if not isinstance(config, dict):
                raise ValueError("The config file should contain a mapping of argument names to values.")

            # Set the environment variables specified in the config file
            if "env" in config:
                env_vars = config.pop("env", {})
                if not isinstance(env_vars, dict):
                    raise ValueError("`env` field should be a dict in the YAML file.")
                for key, value in env_vars.items():
                    os.environ[key] = str(value)

            # Set the defaults from the config values
            config_remaining_strings = self.set_defaults_with_config(**config)
        else:
            config_remaining_strings = []

        # Parse the arguments from the command line
        output = self.parse_args_into_dataclasses(args=args, return_remaining_strings=return_remaining_strings)

        # Merge remaining strings from the config file with the remaining strings from the command line
        if return_remaining_strings:
            args_remaining_strings = output[-1]
            return output[:-1] + (config_remaining_strings + args_remaining_strings,)
        elif fail_with_unknown_args and config_remaining_strings:
            raise ValueError(
                f"Unknown arguments from config file: {config_remaining_strings}. Please remove them, add them to the "
                "dataclass, or set `fail_with_unknown_args=False`."
            )
        else:
            return output

    def set_defaults_with_config(self, **kwargs) -> list[str]:
        """
        Overrides the parser's default values with those provided via keyword arguments, including for subparsers.

        Any argument with an updated default will also be marked as not required
        if it was previously required.

        Returns a list of strings that were not consumed by the parser.
        """

        def apply_defaults(parser, kw):
            used_keys = set()
            for action in parser._actions:
                # Handle subparsers recursively
                if isinstance(action, argparse._SubParsersAction):
                    for subparser in action.choices.values():
                        used_keys.update(apply_defaults(subparser, kw))
                elif action.dest in kw:
                    action.default = kw[action.dest]
                    action.required = False
                    used_keys.add(action.dest)
            return used_keys

        used_keys = apply_defaults(self, kwargs)
        # Remaining args not consumed by the parser
        remaining = [
            item for key, value in kwargs.items() if key not in used_keys for item in (f"--{key}", str(value))
        ]
        return remaining
def get_git_commit_hash(package_name):
    """
    Return the current Git commit hash for the repository containing a package.

    Imports `package_name`, locates it on disk, and looks for a `.git` directory one
    level above the package. Returns the commit hash string when found, `None` when
    the package is not inside a Git checkout, and an `"Error: ..."` string when
    anything goes wrong (e.g. the package cannot be imported or `git` fails).
    """
    try:
        # Resolve the package's on-disk location via its imported module object.
        module = importlib.import_module(package_name)
        package_dir = os.path.dirname(inspect.getfile(module))

        # Assume the repository root sits one directory above the package itself.
        repo_root = os.path.abspath(os.path.join(package_dir, ".."))
        if not os.path.isdir(os.path.join(repo_root, ".git")):
            # Not a Git checkout (e.g. installed from a wheel).
            return None

        # Ask git for the commit currently checked out.
        raw = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=repo_root)
        return raw.decode("utf-8").strip()
    except Exception as e:
        return f"Error: {str(e)}"