# Source code for arctic_training.config.trainer

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import sys
import tempfile
import uuid
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal
from typing import Union
from typing import cast

import yaml
from pydantic import Field
from pydantic import ValidationInfo
from pydantic import field_validator
from pydantic import model_validator
from typing_extensions import Self

from arctic_training.config.base import BaseConfig
from arctic_training.config.checkpoint import CheckpointConfig
from arctic_training.config.data import DataConfig
from arctic_training.config.enums import DType
from arctic_training.config.logger import LoggerConfig
from arctic_training.config.model import ModelConfig
from arctic_training.config.optimizer import OptimizerConfig
from arctic_training.config.scheduler import SchedulerConfig
from arctic_training.config.tokenizer import TokenizerConfig
from arctic_training.config.utils import HumanInt
from arctic_training.config.utils import UniqueKeyLoader
from arctic_training.config.utils import parse_human_val
from arctic_training.config.wandb import WandBConfig
from arctic_training.registry import _get_class_attr_type_hints
from arctic_training.registry import get_registered_checkpoint_engine
from arctic_training.registry import get_registered_data_factory
from arctic_training.registry import get_registered_model_factory
from arctic_training.registry import get_registered_optimizer_factory
from arctic_training.registry import get_registered_scheduler_factory
from arctic_training.registry import get_registered_tokenizer_factory
from arctic_training.registry import get_registered_trainer

if TYPE_CHECKING:
    from arctic_training.checkpoint.engine import CheckpointEngine

# Trainer `type` used when a config does not name one explicitly.
TRAINER_DEFAULT: str = "sft"
# Custom-code script looked for (relative to the config file's directory, or the
# current working directory for dict configs) when `code` is not specified.
CUSTOM_CODE_DEFAULT: Path = Path("train.py")


class TrainerConfig(BaseConfig):
    """Base Trainer Configuration."""

    type: str = TRAINER_DEFAULT
    """ Trainer type. """

    code: Path = CUSTOM_CODE_DEFAULT
    """ Path to the python script containing custom trainer implementation. """

    skip_validation: bool = False
    """ Skips validation of types for subconfigs and registered classes. """

    model: ModelConfig
    """ Model configuration. """

    tokenizer: TokenizerConfig = Field(default_factory=TokenizerConfig)
    """ Tokenizer configuration. """

    data: DataConfig
    """ Train and eval data configuration. """

    logger: LoggerConfig = Field(default_factory=LoggerConfig)
    """ Logger configuration. """

    wandb: WandBConfig = Field(default_factory=WandBConfig)
    """ Weights and Biases configuration. """

    scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
    """ Scheduler configuration. """

    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
    """ Optimizer configuration. """

    deepspeed: Dict[str, Any] = {}
    """ DeepSpeed config dict. Will be automatically filled if not provided by the user. """

    epochs: int = Field(default=1, ge=0)
    """ Number of epochs to train. """

    loss_log_interval: HumanInt = Field(default=1, ge=0)
    """ Number of steps between logging loss. """

    train_log_iter_interval: Literal[0, 1] = 1
    """ Iters between training metric log outputs. `0` is off, only intervals of `1` currently supported. """

    # XXX: fixme: the default output dir is broken
    # train_log_metrics_path: Path = Field(
    #     default_factory=lambda data: data["logger"].output_dir / "train-log-metrics.jsonl"
    # )
    # """ .jsonl path to log precise metrics according to the `train_log_iter_interval` schedule. Defaults to `logger.output_dir/train-log-metrics.jsonl` """
    train_log_metrics_path: Path = Path("train-log-metrics.jsonl")
    """ .jsonl path to log precise metrics according to the `train_log_iter_interval` schedule.
    Defaults to `./train-log-metrics.jsonl` """

    gradient_accumulation_steps: int = Field(default=1, ge=1)
    """ Number of gradient accumulation steps. """

    micro_batch_size: int = Field(default=1, ge=1)
    """ Micro batch size per GPU. """

    sequence_parallel_size: int = Field(default=1, ge=1)
    """ Sequence Parallelism Degree. Disabled if set to 1 """

    expert_parallel_size: int = Field(default=1, ge=1)
    """ Expert Parallelism Degree. Disabled if set to 1 """

    # do we need arctic_moe if expert_parallel_size > 1?
    arctic_moe: Union[bool, Literal["auto"]] = False
    """ Whether to enable arctic moe. Values True/False/"auto" - "auto" will try to detect if the model is an MoE model """

    enable_arctic_moe_custom_optimization: bool = False
    """ Whether to enable custom optimization for arctic moe. This includes custom alltoall and custom kernels for MoE layers. """

    enable_routing_replay: bool = False
    """ Whether to enable routing replay for MoE layers. This can help with stability when training MoE models. """

    activation_checkpoint_cpu_offload: bool = False
    """ Offload activation checkpoint tensors to cpu. Enables a much longer sequence length. It is not very beneficial if sequence length is <64k """

    tiled_mlp_compute: bool = False
    """ Tile the MLP computation to save GPU memory. Currently only limited architectures supported, but can be expanded to more. """

    seed: int = Field(default=42, ge=0)
    """ Random seed value for numpy, python.random, torch, and transformers. """

    checkpoint: List[CheckpointConfig] = []
    """ Checkpoint configurations. Multiple checkpoint engines may be used together. """

    train_iters: HumanInt = Field(default=0, ge=0)
    """ Maximum number of training iterations. """

    eval_interval: HumanInt = Field(default=0, ge=0)
    """ Number of iterations between evaluations. If 0, no evaluation is performed. """

    eval_log_iter_interval: HumanInt = Field(default=1, ge=0)
    """ Iters between eval metric log outputs. `0` is off. """

    exit_iteration: int = Field(default=0, ge=0)
    """ Do not continue training after specified iteration count even if there is still data and epochs to run (useful for debugging and tests). """

    exit_iteration_this_run: int = Field(default=0, ge=0)
    """ Force exit of training after specified iteration count in this run (but will restart running until `exit_iteration` or running out of data/epochs after resume (useful for debugging and tests). """

    min_iterations: HumanInt = Field(default=0, ge=0)
    """ When >0, the training dataset will be replicated until there is enough data to run this many iterations. """

    overfit_first_batch: bool = False
    """ Train only on repetitions of the first training batch. Useful for development. """

    mem_profiler: Literal[None, "step", "e2e"] = None
    """ Enable memory profiling. """

    mem_profiler_dir: Path = Field(default_factory=lambda data: data["logger"].output_dir / "mem-prof")
    """ Path to save memory profiling results. Defaults to `logger.output_dir/mem-prof`. """

    mem_profiler_max_entries: HumanInt = Field(default=100_000, ge=1)
    """ Maximum number of entries to store in the memory profiler. """

    kill_switch_path: Path = Path("/tmp/at_kill_switch")
    """ Path to a file that can be used to trigger a graceful shutdown mid-training (sets early exit to True). """

    # NOTE: pydantic runs "after" model validators in definition order; the order
    # below is significant (e.g. `set_tokenizer` precedes `build_deepspeed_config`).

    @model_validator(mode="after")
    def set_max_length(self) -> Self:
        """Default `data.max_length` to the model's `max_position_embeddings` when not set by the user.

        Raises:
            ValueError: If the model config has no `max_position_embeddings`.
        """
        if "max_length" not in self.data.model_fields_set:
            from transformers import AutoConfig

            model_config = AutoConfig.from_pretrained(self.model.name_or_path)
            if not hasattr(model_config, "max_position_embeddings"):
                raise ValueError(
                    f"Model config for {self.model.name_or_path} does not have a `max_position_embeddings` settings."
                    " Set `data.max_length` in your config."
                )
            self.data.max_length = model_config.max_position_embeddings
        return self

    @model_validator(mode="after")
    def init_dist(self) -> Self:
        """Bind this process to its local accelerator device and initialize DeepSpeed distributed."""
        import deepspeed
        from deepspeed.accelerator import get_accelerator

        get_accelerator().set_device(self.local_rank)
        deepspeed.init_distributed()
        return self

    @property
    def checkpoint_engines(self) -> List[partial["CheckpointEngine"]]:
        """One partial per configured checkpoint: the registered engine class bound to its config."""
        return [
            partial(get_registered_checkpoint_engine(checkpoint.type), config=checkpoint)
            for checkpoint in self.checkpoint
        ]

    @property
    def zero_3_enabled(self) -> bool:
        """True when the DeepSpeed config requests ZeRO stage 3."""
        return self.deepspeed.get("zero_optimization", {}).get("stage", 0) == 3

    @staticmethod
    def _get_subconfig_object(
        v: Union[Dict, BaseConfig],
        info: ValidationInfo,
        get_class_fn: Callable,
        attr_name: str,
    ) -> BaseConfig:
        """Build the concrete sub-config object for one trainer attribute.

        The trainer class selected by the already-validated `type` field lists,
        via the type hints of `attr_name`, which factory classes it accepts; the
        first hint is the default. A `type` entry in `v` selects a registered
        factory explicitly.

        Args:
            v: Raw user value (a dict) or an already-constructed config object.
            info: Pydantic validation info; `info.data` holds previously validated fields.
            get_class_fn: Registry lookup mapping a `type` string to a factory class.
            attr_name: Trainer attribute whose type hints list the compatible factories.

        Returns:
            An instance of the resolved factory's config class.

        Raises:
            ValueError: If the resolved factory is not supported by the trainer
                and `skip_validation` is not set.
        """
        # Get the trainer class as it will tell us which types of factory
        # classes (and thus configs) are default/compatible
        trainer_type = info.data["type"]
        trainer_cls = get_registered_trainer(trainer_type)

        # Get type hints for this factory class. This is a list of compatible
        # classes for the given attribute field.
        attribute_type_hints = _get_class_attr_type_hints(trainer_cls, attr_name)

        # Convert to a dictionary as default values are the base config classes
        # and we likely need to use a different class based on the trainer type
        # or user requested `type` field value.
        if isinstance(v, dict):
            # Copy so that setting `type` below does not mutate the caller's dict.
            config_dict = dict(v)
        else:
            # Must exclude computed fields to avoid validation errors
            config_dict = v.model_dump(exclude={"local_rank", "global_rank", "world_size"})

        # Determine which attribute class to use (e.g., for `model`:
        # HFModelFactory, LigerModelFactory, etc.)
        if config_dict.get("type", ""):
            # User explicitly specified the type
            attr_cls = get_class_fn(config_dict["type"])
        else:
            # User did not specify the type, use the first (maybe only) hint as default type
            attr_cls = attribute_type_hints[0]

        # Check that the requested/resolved type is compatible with the trainer
        if not info.data.get("skip_validation") and attr_cls not in attribute_type_hints:
            raise ValueError(
                f"{attr_cls.__name__} is not supported for {attr_name} in"
                f" {trainer_cls.__name__}. Supported types are"
                f" {[cls.__name__ for cls in attribute_type_hints]}."
            )

        # Make sure the `type` field is set in the config dict
        config_dict["type"] = attr_cls.name

        # Get the config class for the factory class and create the config
        config_cls = _get_class_attr_type_hints(attr_cls, "config")[0]
        return config_cls(**config_dict)

    @staticmethod
    def _to_list(v: Union[Any, List[Any]]) -> List[Any]:
        """Wrap a non-list value in a single-element list; return lists unchanged."""
        if not isinstance(v, list):
            return [v]
        return v

    @field_validator("checkpoint", mode="before")
    @classmethod
    def init_checkpoint_configs(
        cls,
        v: Union[Union[Dict, CheckpointConfig], List[Union[Dict, CheckpointConfig]]],
        info: ValidationInfo,
    ) -> List[CheckpointConfig]:
        """Accept one or many checkpoint configs and resolve each to its engine's config class."""
        return [
            cast(
                CheckpointConfig,
                cls._get_subconfig_object(
                    v=sub_v,
                    info=info,
                    get_class_fn=get_registered_checkpoint_engine,
                    attr_name="checkpoint_engine",
                ),
            )
            for sub_v in cls._to_list(v)
        ]

    @field_validator("data", mode="before")
    @classmethod
    def init_data_config(cls, v: Union[Dict, DataConfig], info: ValidationInfo) -> DataConfig:
        """Resolve `data` to the config class of the selected data factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_data_factory,
            attr_name="data_factory",
        )
        return cast(DataConfig, subconfig)

    @field_validator("model", mode="before")
    @classmethod
    def init_model_config(cls, v: Union[Dict, ModelConfig], info: ValidationInfo) -> ModelConfig:
        """Resolve `model` to the config class of the selected model factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_model_factory,
            attr_name="model_factory",
        )
        return cast(ModelConfig, subconfig)

    @field_validator("optimizer", mode="before")
    @classmethod
    def init_optimizer_config(cls, v: Union[Dict, OptimizerConfig], info: ValidationInfo) -> OptimizerConfig:
        """Resolve `optimizer` to the config class of the selected optimizer factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_optimizer_factory,
            attr_name="optimizer_factory",
        )
        return cast(OptimizerConfig, subconfig)

    @field_validator("scheduler", mode="before")
    @classmethod
    def init_scheduler_config(cls, v: Union[Dict, SchedulerConfig], info: ValidationInfo) -> SchedulerConfig:
        """Resolve `scheduler` to the config class of the selected scheduler factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_scheduler_factory,
            attr_name="scheduler_factory",
        )
        return cast(SchedulerConfig, subconfig)

    @field_validator("tokenizer", mode="before")
    @classmethod
    def init_tokenizer_config(cls, v: Union[Dict, TokenizerConfig], info: ValidationInfo) -> TokenizerConfig:
        """Resolve `tokenizer` to the config class of the selected tokenizer factory."""
        subconfig = cls._get_subconfig_object(
            v=v,
            info=info,
            get_class_fn=get_registered_tokenizer_factory,
            attr_name="tokenizer_factory",
        )
        return cast(TokenizerConfig, subconfig)

    @model_validator(mode="after")
    def validate_eval_interval(self) -> Self:
        """Require `eval_interval` > 0 exactly when an eval dataset is configured.

        Raises:
            ValueError: On a mismatch between `eval_interval` and eval data presence.
        """
        # Explicit raises rather than `assert`: asserts are stripped under
        # `python -O`, which would silently disable this user-facing check.
        # Pydantic wraps ValueError in a ValidationError just like AssertionError.
        has_eval_data = bool(self.data.eval_sources) or self.data.train_eval_split[1] > 0.0
        if has_eval_data and not self.eval_interval > 0:
            raise ValueError("`eval_interval` must be set if eval dataset is provided.")
        if self.eval_interval > 0 and not has_eval_data:
            raise ValueError("`eval_interval` must be set only if eval dataset is provided.")
        return self

    @model_validator(mode="after")
    def set_tokenizer(self) -> Self:
        """Default the tokenizer path to the model path when none is given."""
        if not self.tokenizer.name_or_path:
            self.tokenizer.name_or_path = self.model.name_or_path
        return self

    @field_validator("logger", mode="after")
    @classmethod
    def initialize_logger(cls, v: LoggerConfig) -> LoggerConfig:
        """Configure logging as soon as the logger config is validated."""
        from arctic_training.logging import setup_logger

        setup_logger(v)
        return v

    @field_validator("deepspeed", mode="before")
    @classmethod
    def coerce_deepspeed_human_friendly_values(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of the DeepSpeed dict with human-friendly values parsed, recursing into nested dicts."""
        # Allow human friendly values for deepspeed config. This is a workaround
        # until we upstream this feature to the DeepSpeed pydantic configs.
        def coerce_dict_values(config_dict: Dict[str, Any]) -> Dict[str, Any]:
            coerced_dict: Dict[str, Any] = {}
            for key, value in config_dict.items():
                if isinstance(value, dict):
                    coerced_dict[key] = coerce_dict_values(value)
                else:
                    # Best effort: values parse_human_val rejects are kept verbatim.
                    try:
                        coerced_dict[key] = parse_human_val(value)
                    except Exception:
                        coerced_dict[key] = value
            return coerced_dict

        return coerce_dict_values(v)

    @model_validator(mode="after")
    def build_deepspeed_config(self) -> Self:
        """Fill batch-size settings and defaults into `self.deepspeed` in place.

        Raises:
            ValueError: If the model config exposes neither `hidden_size` nor `hidden_sizes`.
        """
        ds_config = self.deepspeed
        ds_config["train_micro_batch_size_per_gpu"] = self.micro_batch_size
        # DeepSpeed's `train_batch_size` is an integer. The effective data-parallel
        # degree here is world_size / sequence_parallel_size; floor division keeps
        # the value an int (identical to true division whenever it divides evenly).
        ds_config["train_batch_size"] = (
            self.micro_batch_size * self.gradient_accumulation_steps * self.world_size
        ) // self.sequence_parallel_size
        ds_config["gradient_accumulation_steps"] = self.gradient_accumulation_steps
        ds_config["sequence_parallel_size"] = self.sequence_parallel_size
        ds_config["steps_per_print"] = ds_config.get("steps_per_print", 10)

        from transformers import AutoConfig

        model_config = AutoConfig.from_pretrained(self.model.name_or_path)
        if hasattr(model_config, "hidden_size"):
            hidden_size = model_config.hidden_size
        elif hasattr(model_config, "hidden_sizes"):
            # if there are many hidden sizes pick the largest one
            hidden_size = max(model_config.hidden_sizes)
        else:
            raise ValueError(
                "Can find neither `model_config.hidden_size` nor `model_config.hidden_sizes`, in the "
                f" {self.model.name_or_path}'s config"
            )

        # the following defaults come from the Deepspeed team recommendation;
        # they apply only when the user supplied no `zero_optimization` section.
        ds_config["zero_optimization"] = ds_config.get(
            "zero_optimization",
            {
                "stage": 2,
                "stage3_param_persistence_threshold": 10 * hidden_size,
                "stage3_max_live_parameters": 2 * hidden_size * hidden_size,
                "stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size),
                "reduce_bucket_size": hidden_size * hidden_size,
            },
        )

        # Enable mixed precision from the model dtype unless the user configured it.
        if "bfloat16" not in ds_config:
            if self.model.dtype == DType.BF16:
                ds_config["bfloat16"] = {"enabled": True}

        if "fp16" not in ds_config:
            if self.model.dtype == DType.FP16:
                ds_config["fp16"] = {"enabled": True}

        ds_config["gradient_clipping"] = ds_config.get("gradient_clipping", 1.0)
        ds_config["prescale_gradients"] = ds_config.get("prescale_gradients", False)
        ds_config["wall_clock_breakdown"] = ds_config.get("wall_clock_breakdown", False)

        return self

    @model_validator(mode="after")
    def validate_single_checkpoint_resume(self) -> Self:
        """Allow at most one checkpoint config with `auto_resume` enabled.

        Raises:
            ValueError: If more than one checkpoint sets `auto_resume`.
        """
        if sum(c.auto_resume for c in self.checkpoint) > 1:
            raise ValueError("Only one checkpoint can auto resume.")
        return self

    @model_validator(mode="after")
    def train_log_metrics_path_prep(self) -> Self:
        """On local rank 0, create the metrics file and its parent directories if missing."""
        if self.local_rank == 0:
            self.train_log_metrics_path.parent.mkdir(parents=True, exist_ok=True)
            # Opening in append mode creates the file without truncating it; the
            # context manager closes the handle immediately instead of leaking it.
            with self.train_log_metrics_path.open(mode="a"):
                pass
        return self

    @model_validator(mode="after")
    def mem_profiler_mkdir(self) -> Self:
        """Create the memory-profiler output directory when profiling is enabled."""
        if self.mem_profiler is not None:
            self.mem_profiler_dir.mkdir(parents=True, exist_ok=True)
        return self

    @model_validator(mode="after")
    def validate_sft_sample_packing(self) -> Self:
        """Require `micro_batch_size == 1` when the data config enables `pack_samples`.

        Raises:
            ValueError: If sample packing is on and `micro_batch_size` is not 1.
        """
        if getattr(self.data, "pack_samples", False) and self.micro_batch_size != 1:
            raise ValueError("`micro_batch_size` must be 1 if `pack_samples` is enabled.")
        return self
def load_user_module_from_path(script_path: Path) -> None:
    """Import a user's custom-code script so that its definitions are executed.

    The script's directory is symlinked into a shared temp directory under a
    name derived from the directory's absolute path. Every process loading the
    same script therefore agrees on the module name. A second symlink named
    after that unique module points at the script. Both the shared temp dir and
    the symlinked directory are appended to ``sys.path``: the first lets child
    processes import the module by name, the second lets user modules import
    each other. The module is registered in ``sys.modules`` and executed.

    Args:
        script_path: Path to the python script to load.

    Raises:
        ImportError: If no import spec/loader can be created for the script.
    """
    # Resolve the directory once and use it both for the stable name and as the
    # symlink target: a relative target would be interpreted relative to the
    # symlink's own location inside the temp dir and would dangle.
    script_dir = script_path.parent.resolve()
    tmp_root = Path(tempfile.gettempdir())
    shared_tmp_dir = tmp_root / "arctic_training_custom_module_symlinks"
    shared_tmp_dir.mkdir(exist_ok=True)

    # Generate the same unique name for a given script directory across all processes
    unique_dir_name = f"user_dir_{uuid.uuid5(uuid.NAMESPACE_URL, str(script_dir)).hex[:8]}"
    symlink_dir_path = shared_tmp_dir / unique_dir_name
    try:
        symlink_dir_path.symlink_to(script_dir)
    except FileExistsError:
        # Another proc created the symlink first, use that one
        pass

    # Now load the specific script from the symlinked directory
    script_name = script_path.stem
    unique_module_name = f"{unique_dir_name}_{script_name}"
    symlinked_script_path = symlink_dir_path / script_path.name

    # Create a symlink in the shared directory with the unique module name
    # so that child processes can import it by name
    unique_module_file = shared_tmp_dir / f"{unique_module_name}.py"
    try:
        unique_module_file.symlink_to(symlinked_script_path)
    except FileExistsError:
        # Another proc created the symlink first, use that one
        pass

    # Add both the shared temp dir and the symlinked directory to sys.path
    # - shared_tmp_dir: so child processes can import the uniquely named module
    # - symlink_dir_path: so user modules can import from each other
    for path_str in (str(shared_tmp_dir), str(symlink_dir_path)):
        if path_str not in sys.path:
            sys.path.append(path_str)

    spec = importlib.util.spec_from_file_location(unique_module_name, symlinked_script_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load script from {symlinked_script_path}")

    # Load user module (registered before execution, per the importlib recipe)
    module = importlib.util.module_from_spec(spec)
    sys.modules[unique_module_name] = module
    spec.loader.exec_module(module)


def get_config(config_file_or_dict: Union[Path, str, Dict]) -> BaseConfig:
    """Build the trainer config from a YAML file path or an in-memory dict.

    A custom-code script (the `code` entry, default `train.py`) is resolved
    relative to the config file's directory, or to the current working
    directory for dict input. If it exists, it is loaded before the trainer
    class is looked up.

    Args:
        config_file_or_dict: Path (``Path`` or ``str``) to a YAML config, or a
            config dict. The dict is shallow-copied; top-level keys are not
            modified.

    Returns:
        An instance of the registered trainer's config class.

    Raises:
        ValueError: If the YAML file does not contain a top-level mapping.
        FileNotFoundError: If an explicitly given `code` script does not exist.
    """
    if isinstance(config_file_or_dict, dict):
        config_dict = config_file_or_dict.copy()
        config_dir = Path.cwd()
    else:
        # Accept plain string paths as well as Path objects.
        config_path = Path(config_file_or_dict)
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = yaml.load(f, Loader=UniqueKeyLoader)
        # An empty file loads as None; a list/scalar is not a usable config either.
        if not isinstance(config_dict, dict):
            raise ValueError(f"Config file {config_path} must contain a YAML mapping at the top level.")
        config_dir = config_path.parent

    trainer_type = config_dict.get("type", TRAINER_DEFAULT)
    config_dict["type"] = trainer_type

    script_path = Path(config_dict.get("code", CUSTOM_CODE_DEFAULT))
    if not script_path.is_absolute():
        script_path = config_dir / script_path
    script_path = script_path.resolve()
    if script_path.exists():
        config_dict["code"] = script_path
        load_user_module_from_path(script_path)
    elif config_dict.get("code") is not None:
        # User specified a script that doesn't exist
        raise FileNotFoundError(f"Cannot find script at {script_path}")

    trainer_cls = get_registered_trainer(trainer_type)
    config_cls = _get_class_attr_type_hints(trainer_cls, "config")[0]
    return config_cls(**config_dict)