Source code for arctic_training.config.trainer

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import sys
import tempfile
import uuid
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal
from typing import Union
from typing import cast

import yaml
from pydantic import Field
from pydantic import ValidationInfo
from pydantic import field_validator
from pydantic import model_validator
from typing_extensions import Self

from arctic_training.config.base import BaseConfig
from arctic_training.config.checkpoint import CheckpointConfig
from arctic_training.config.data import DataConfig
from arctic_training.config.enums import DType
from arctic_training.config.logger import LoggerConfig
from arctic_training.config.model import ModelConfig
from arctic_training.config.optimizer import OptimizerConfig
from arctic_training.config.scheduler import SchedulerConfig
from arctic_training.config.tokenizer import TokenizerConfig
from arctic_training.config.utils import HumanInt
from arctic_training.config.utils import UniqueKeyLoader
from arctic_training.config.utils import parse_human_val
from arctic_training.config.wandb import WandBConfig
from arctic_training.registry import _get_class_attr_type_hints
from arctic_training.registry import get_registered_checkpoint_engine
from arctic_training.registry import get_registered_data_factory
from arctic_training.registry import get_registered_model_factory
from arctic_training.registry import get_registered_optimizer_factory
from arctic_training.registry import get_registered_scheduler_factory
from arctic_training.registry import get_registered_tokenizer_factory
from arctic_training.registry import get_registered_trainer

if TYPE_CHECKING:
    from arctic_training.checkpoint.engine import CheckpointEngine

# Trainer `type` assumed when the config does not specify one.
TRAINER_DEFAULT: str = "sft"
# Custom-code script looked up when the config has no `code` entry; a relative
# value is resolved against the config file's directory (or the cwd for dicts).
CUSTOM_CODE_DEFAULT: Path = Path("train.py")


class TrainerConfig(BaseConfig):
    """Base Trainer Configuration."""

    type: str = TRAINER_DEFAULT
    """ Trainer type. """

    code: Path = CUSTOM_CODE_DEFAULT
    """ Path to the python script containing custom trainer implementation. """

    skip_validation: bool = False
    """ Skips validation of types for subconfigs and registered classes. """

    model: ModelConfig
    """ Model configuration. """

    tokenizer: TokenizerConfig = Field(default_factory=TokenizerConfig)
    """ Tokenizer configuration. """

    data: DataConfig
    """ Train and eval data configuration. """

    logger: LoggerConfig = Field(default_factory=LoggerConfig)
    """ Logger configuration. """

    wandb: WandBConfig = Field(default_factory=WandBConfig)
    """ Weights and Biases configuration. """

    scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
    """ Scheduler configuration. """

    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
    """ Optimizer configuration. """

    deepspeed: Dict[str, Any] = {}
    """ DeepSpeed config dict. Will be automatically filled if not provided by the user. """

    epochs: int = Field(default=1, ge=0)
    """ Number of epochs to train. """

    loss_log_interval: HumanInt = Field(default=1, ge=0)
    """ Number of steps between logging loss. """

    train_log_iter_interval: Literal[0, 1] = 1
    """ Iters between training metric log outputs. `0` is off, only intervals of `1` currently supported. """

    # XXX: fixme: the default output dir is broken
    # train_log_metrics_path: Path = Field(
    #     default_factory=lambda data: data["logger"].output_dir / "train-log-metrics.jsonl"
    # )
    # """ .jsonl path to log precise metrics according to the `train_log_iter_interval` schedule. Defaults to `logger.output_dir/train-log-metrics.jsonl` """
    train_log_metrics_path: Path = Path("train-log-metrics.jsonl")
    """ .jsonl path to log precise metrics according to the `train_log_iter_interval` schedule.
    Defaults to `./train-log-metrics.jsonl` """

    gradient_accumulation_steps: int = Field(default=1, ge=1)
    """ Number of gradient accumulation steps. """

    micro_batch_size: int = Field(default=1, ge=1)
    """ Micro batch size per GPU. """

    sequence_parallel_size: int = Field(default=1, ge=1)
    """ Sequence Parallelism Degree. Disabled if set to 1 """

    activation_checkpoint_cpu_offload: bool = False
    """ Offload activation checkpoint tensors to cpu. Enables a much longer sequence length. It is not very beneficial if sequence length is <64k """

    tiled_mlp_compute: bool = False
    """ Tile the MLP computation to save GPU memory. Currently only limited architectures supported, but can be expanded to more. """

    seed: int = Field(default=42, ge=0)
    """ Random seed value for numpy, python.random, torch, and transformers. """

    checkpoint: List[CheckpointConfig] = []
    """ Checkpoint configurations. Multiple checkpoint engines may be used together. """

    train_iters: HumanInt = Field(default=0, ge=0)
    """ Maximum number of training iterations. """

    eval_frequency: int = Field(default=0, ge=0)
    """ How often to run evaluation. Must be set (> 0) when eval data is provided via `data.eval_sources` or `data.train_eval_split`. """

    exit_iteration: int = Field(default=0, ge=0)
    """ Force exit of training after specified iteration count (useful for debugging). """

    min_iterations: HumanInt = Field(default=0, ge=0)
    """ When >0, the training dataset will be replicated until there is enough data to run this many iterations. """

    overfit_first_batch: bool = False
    """ Train only on repetitions of the first training batch. Useful for development. """

    mem_profiler: Literal[None, "step", "e2e"] = None
    """ Enable memory profiling. """

    mem_profiler_dir: Path = Field(default_factory=lambda data: data["logger"].output_dir / "mem-prof")
    """ Path to save memory profiling results. Defaults to `logger.output_dir/mem-prof`. """

    mem_profiler_max_entries: HumanInt = Field(default=100_000, ge=1)
    """ Maximum number of entries to store in the memory profiler. """

    kill_switch_path: Path = Path("/tmp/at_kill_switch")
    """ Path to a file that can be used to trigger a graceful shutdown mid-training (sets early exit to True). """

    @model_validator(mode="after")
    def init_dist(self) -> Self:
        """Bind this process to its local device and initialize DeepSpeed distributed."""
        # Deferred imports: deepspeed is only needed once a config is validated.
        import deepspeed
        from deepspeed.accelerator import get_accelerator

        get_accelerator().set_device(self.local_rank)
        deepspeed.init_distributed()
        return self

    @property
    def checkpoint_engines(self) -> List[partial["CheckpointEngine"]]:
        """Registered checkpoint engine classes, each pre-bound to its checkpoint config."""
        return [
            partial(get_registered_checkpoint_engine(checkpoint.type), config=checkpoint)
            for checkpoint in self.checkpoint
        ]

    @property
    def zero_3_enabled(self) -> bool:
        """Whether the DeepSpeed config requests ZeRO stage 3."""
        return self.deepspeed.get("zero_optimization", {}).get("stage", 0) == 3

    @staticmethod
    def _get_subconfig_object(
        v: Union[Dict, BaseConfig],
        info: ValidationInfo,
        get_class_fn: Callable,
        attr_name: str,
    ) -> BaseConfig:
        """Build the concrete sub-config object for a factory-backed field.

        Args:
            v: Raw user value, either a dict or an already-built config object.
            info: Pydantic validation info; ``info.data`` must already contain ``type``.
            get_class_fn: Registry lookup that maps a ``type`` name to a factory class.
            attr_name: Attribute on the trainer class whose type hints list the
                compatible factory classes (e.g. ``"model_factory"``).

        Returns:
            An instance of the config class declared by the resolved factory class.

        Raises:
            ValueError: If the resolved factory class is not compatible with the
                trainer and ``skip_validation`` is not set.
        """
        # Get the trainer class as it will tell us which types of factory
        # classes (and thus configs) are default/compatible
        trainer_type = info.data["type"]
        trainer_cls = get_registered_trainer(trainer_type)

        # Get type hints for this factory class. This is a list of compatible
        # classes for the given attribute field.
        attribute_type_hints = _get_class_attr_type_hints(trainer_cls, attr_name)

        # Convert to a dictionary as default values are the base config classes
        # and we likely need to use a different class based on the trainer type
        # or user requested `type` field value.
        if isinstance(v, dict):
            # Copy so that setting `type` below does not mutate the caller's dict.
            config_dict = dict(v)
        else:
            # Must exclude computed fields to avoid validation errors
            config_dict = v.model_dump(exclude={"local_rank", "global_rank", "world_size"})

        # Determine which attribute class to use (e.g., for `model`:
        # HFModelFactory, LigerModelFactory, etc.)
        if config_dict.get("type", ""):
            # User explicitly specified the type
            attr_cls = get_class_fn(config_dict["type"])
        else:
            # User did not specify the type, use the first (maybe only) hint as default type
            attr_cls = attribute_type_hints[0]

        # Check that the requested/resolved type is compatible with the trainer
        if not info.data.get("skip_validation") and attr_cls not in attribute_type_hints:
            raise ValueError(
                f"{attr_cls.__name__} is not supported for {attr_name} in"
                f" {trainer_cls.__name__}. Supported types are"
                f" {[cls.__name__ for cls in attribute_type_hints]}."
            )

        # Make sure the `type` field is set in the config dict
        config_dict["type"] = attr_cls.name

        # Get the config class for the factory class and create the config
        config_cls = _get_class_attr_type_hints(attr_cls, "config")[0]
        return config_cls(**config_dict)

    @staticmethod
    def _to_list(v: Union[Any, List[Any]]) -> List[Any]:
        """Wrap a non-list value in a single-element list; return lists unchanged."""
        if not isinstance(v, list):
            return [v]
        return v

    @field_validator("checkpoint", mode="before")
    @classmethod
    def init_checkpoint_configs(
        cls,
        v: Union[Union[Dict, CheckpointConfig], List[Union[Dict, CheckpointConfig]]],
        info: ValidationInfo,
    ) -> List[CheckpointConfig]:
        """Accept one or many checkpoint entries and build a config object for each."""
        return [
            cast(
                CheckpointConfig,
                cls._get_subconfig_object(
                    v=sub_v,
                    info=info,
                    get_class_fn=get_registered_checkpoint_engine,
                    attr_name="checkpoint_engine",
                ),
            )
            for sub_v in cls._to_list(v)
        ]

    @field_validator("data", mode="before")
    @classmethod
    def init_data_config(cls, v: Union[Dict, DataConfig], info: ValidationInfo) -> DataConfig:
        """Build the data config for the trainer's data factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_data_factory,
            attr_name="data_factory",
        )
        return cast(DataConfig, subconfig)

    @field_validator("model", mode="before")
    @classmethod
    def init_model_config(cls, v: Union[Dict, ModelConfig], info: ValidationInfo) -> ModelConfig:
        """Build the model config for the trainer's model factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_model_factory,
            attr_name="model_factory",
        )
        return cast(ModelConfig, subconfig)

    @field_validator("optimizer", mode="before")
    @classmethod
    def init_optimizer_config(cls, v: Union[Dict, OptimizerConfig], info: ValidationInfo) -> OptimizerConfig:
        """Build the optimizer config for the trainer's optimizer factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_optimizer_factory,
            attr_name="optimizer_factory",
        )
        return cast(OptimizerConfig, subconfig)

    @field_validator("scheduler", mode="before")
    @classmethod
    def init_scheduler_config(cls, v: Union[Dict, SchedulerConfig], info: ValidationInfo) -> SchedulerConfig:
        """Build the scheduler config for the trainer's scheduler factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_scheduler_factory,
            attr_name="scheduler_factory",
        )
        return cast(SchedulerConfig, subconfig)

    @field_validator("tokenizer", mode="before")
    @classmethod
    def init_tokenizer_config(cls, v: Union[Dict, TokenizerConfig], info: ValidationInfo) -> TokenizerConfig:
        """Build the tokenizer config for the trainer's tokenizer factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_tokenizer_factory,
            attr_name="tokenizer_factory",
        )
        return cast(TokenizerConfig, subconfig)

    @model_validator(mode="after")
    def validate_eval_frequency(self) -> Self:
        """Require a positive ``eval_frequency`` whenever eval data is configured.

        Raises:
            ValueError: If eval data is present but ``eval_frequency`` is 0.
        """
        # Raise instead of `assert` so the check survives `python -O`;
        # pydantic converts the ValueError into a ValidationError.
        if self.data.eval_sources or self.data.train_eval_split[1] > 0.0:
            if self.eval_frequency <= 0:
                raise ValueError("eval_frequency must be set if eval dataset is provided.")
        return self

    @model_validator(mode="after")
    def set_tokenizer(self) -> Self:
        """Default the tokenizer path to the model path when none is given."""
        if not self.tokenizer.name_or_path:
            self.tokenizer.name_or_path = self.model.name_or_path
        return self

    @field_validator("logger", mode="after")
    @classmethod
    def initialize_logger(cls, v: LoggerConfig) -> LoggerConfig:
        """Configure logging as soon as the logger config has been validated."""
        from arctic_training.logging import setup_logger

        setup_logger(v)
        return v

    @field_validator("deepspeed", mode="before")
    @classmethod
    def coerce_deepspeed_human_friendly_values(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively parse human-friendly values (via ``parse_human_val``) in the DeepSpeed dict."""
        # Allow human friendly values for deepspeed config. This is a workaround
        # until we upstream this feature to the DeepSpeed pydantic configs.
        def coerce_dict_values(config_dict: Dict[str, Any]) -> Dict[str, Any]:
            coerced_dict: Dict[str, Any] = {}
            for key, value in config_dict.items():
                if isinstance(value, dict):
                    coerced_dict[key] = coerce_dict_values(value)
                else:
                    try:
                        coerced_dict[key] = parse_human_val(value)
                    except Exception:
                        # Values that are not human-friendly numbers are kept as-is.
                        coerced_dict[key] = value
            return coerced_dict

        return coerce_dict_values(v)

    @model_validator(mode="after")
    def build_deepspeed_config(self) -> Self:
        """Fill in the DeepSpeed config from the trainer settings.

        Batch-size, gradient-accumulation and sequence-parallel entries are always
        overwritten. The other entries are only set when the user did not provide them.
        """
        ds_config = self.deepspeed
        ds_config["train_micro_batch_size_per_gpu"] = self.micro_batch_size
        # world_size / sequence_parallel_size is the number of data-parallel
        # replicas. Integer division keeps the value an int; true division
        # produced a float such as 8.0.
        ds_config["train_batch_size"] = (
            self.micro_batch_size * self.gradient_accumulation_steps * self.world_size // self.sequence_parallel_size
        )
        ds_config["gradient_accumulation_steps"] = self.gradient_accumulation_steps
        ds_config["sequence_parallel_size"] = self.sequence_parallel_size
        ds_config["steps_per_print"] = ds_config.get("steps_per_print", 10)
        ds_config["zero_optimization"] = ds_config.get(
            "zero_optimization",
            {
                "stage": 2,
                "stage3_param_persistence_threshold": 1e4,
                "stage3_max_live_parameters": 3e7,
                "stage3_prefetch_bucket_size": 3e7,
                "memory_efficient_linear": False,
            },
        )
        if "bfloat16" not in ds_config:
            if self.model.dtype == DType.BF16:
                ds_config["bfloat16"] = {"enabled": True}
        if "fp16" not in ds_config:
            if self.model.dtype == DType.FP16:
                ds_config["fp16"] = {"enabled": True}
        ds_config["gradient_clipping"] = ds_config.get("gradient_clipping", 1.0)
        ds_config["prescale_gradients"] = ds_config.get("prescale_gradients", False)
        ds_config["wall_clock_breakdown"] = ds_config.get("wall_clock_breakdown", False)
        return self

    @model_validator(mode="after")
    def validate_single_checkpoint_resume(self) -> Self:
        """Ensure at most one checkpoint config has ``auto_resume`` enabled.

        Raises:
            ValueError: If more than one checkpoint config sets ``auto_resume``.
        """
        if sum(c.auto_resume for c in self.checkpoint) > 1:
            raise ValueError("Only one checkpoint can auto resume.")
        return self

    @model_validator(mode="after")
    def train_log_metrics_path_prep(self) -> Self:
        """On local rank 0, create the metrics log file and its parent directories."""
        if self.local_rank == 0:
            self.train_log_metrics_path.parent.mkdir(parents=True, exist_ok=True)
            # Create the file if missing, and close the handle immediately
            # instead of leaking it.
            with self.train_log_metrics_path.open(mode="a"):
                pass
        return self

    @model_validator(mode="after")
    def mem_profiler_mkdir(self) -> Self:
        """Create the memory profiler output directory when profiling is enabled."""
        if self.mem_profiler is not None:
            self.mem_profiler_dir.mkdir(parents=True, exist_ok=True)
        return self
def load_user_module_from_path(script_path: Path) -> None:
    """Import a user-provided Python script under a stable, unique module name.

    The script is symlinked into a shared temporary directory. The link name is
    derived from the script's resolved path, so every process that loads the
    same script uses the same module name without clashing with other modules.
    The module is then executed and registered in ``sys.modules``.

    Args:
        script_path: Path to the python script. Relative paths are resolved
            against the current working directory.

    Raises:
        ImportError: If an import spec cannot be created for the script.
    """
    script_path = Path(script_path)
    resolved_script = script_path.resolve()

    # Symlink the script to a temporary directory to avoid clashing with other modules
    shared_tmp_dir = Path(tempfile.gettempdir()) / "arctic_training_custom_module_symlinks"
    shared_tmp_dir.mkdir(exist_ok=True)

    # Generate the same unique name for a given script path across all processes
    path_hash = uuid.uuid5(uuid.NAMESPACE_URL, str(resolved_script)).hex[:8]
    unique_module_name = f"user_{script_path.stem}_{path_hash}"
    symlink_path = shared_tmp_dir / f"{unique_module_name}.py"

    try:
        # Link to the absolute path: a relative target would be interpreted
        # relative to `shared_tmp_dir` and produce a dangling symlink.
        symlink_path.symlink_to(resolved_script)
    except FileExistsError:
        # Another proc created the symlink first, use that one
        pass

    # Insert into path (once) so child procs can import it
    shared_tmp_dir_str = str(shared_tmp_dir)
    if shared_tmp_dir_str not in sys.path:
        sys.path.insert(0, shared_tmp_dir_str)

    spec = importlib.util.spec_from_file_location(unique_module_name, symlink_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load script from {symlink_path}")

    # Load user module
    module = importlib.util.module_from_spec(spec)
    sys.modules[unique_module_name] = module
    spec.loader.exec_module(module)


def get_config(config_file_or_dict: Union[str, Path, Dict]) -> BaseConfig:
    """Build the trainer config from a YAML file or a dict.

    Custom code (the ``code`` entry, default ``train.py``) is resolved relative
    to the config file's directory, or to the cwd for dict input. If the script
    exists it is imported first, so any classes it defines run their import-time
    code before the trainer class is looked up.

    Args:
        config_file_or_dict: Path (or path string) to a YAML config file, or a
            config dict. A dict input is shallow-copied before it is modified.

    Returns:
        An instance of the config class declared by the registered trainer.

    Raises:
        ValueError: If the YAML file does not contain a mapping.
        FileNotFoundError: If an explicitly configured ``code`` script is missing.
    """
    if isinstance(config_file_or_dict, dict):
        # Shallow copy so top-level keys set below don't leak into the caller's dict
        config_dict = config_file_or_dict.copy()
        config_dir = Path.cwd()
    else:
        config_file = Path(config_file_or_dict)
        with open(config_file, "r", encoding="utf-8") as f:
            config_dict = yaml.load(f, Loader=UniqueKeyLoader)
        # An empty file loads as None; a scalar/list document is not a valid config
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Config file {config_file} must contain a YAML mapping, got {type(config_dict).__name__}."
            )
        config_dir = config_file.parent

    trainer_type = config_dict.get("type", TRAINER_DEFAULT)
    config_dict["type"] = trainer_type

    script_path = Path(config_dict.get("code", CUSTOM_CODE_DEFAULT))
    if not script_path.is_absolute():
        script_path = config_dir / script_path
    script_path = script_path.resolve()

    if script_path.exists():
        config_dict["code"] = script_path
        load_user_module_from_path(script_path)
    elif config_dict.get("code") is not None:
        # User specified a script that doesn't exist
        raise FileNotFoundError(f"Cannot find script at {script_path}")

    trainer_cls = get_registered_trainer(trainer_type)
    config_cls = _get_class_attr_type_hints(trainer_cls, "config")[0]
    config = config_cls(**config_dict)
    return config