# Source code for arctic_training.trainer.trainer

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import random
from abc import ABC
from abc import abstractmethod
from functools import cached_property
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import deepspeed
import numpy as np
import torch
import torch.cuda
import torch.distributed.nn
import wandb
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter
from devtools import debug
from tqdm import tqdm
from transformers import set_seed
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from wandb.sdk.wandb_run import Run as WandbRun

from arctic_training.callback.logging import post_loss_log_cb
from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.checkpoint.engine import CheckpointEngine
from arctic_training.config.trainer import TrainerConfig
from arctic_training.data.factory import DataFactory
from arctic_training.data.utils import OverfitOneBatchDataLoader
from arctic_training.logging import logger
from arctic_training.metrics import Metrics
from arctic_training.model.factory import ModelFactory
from arctic_training.model.tiled_compute import enable_tiled_mlp_compute
from arctic_training.optimizer.factory import OptimizerFactory
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method
from arctic_training.scheduler.factory import SchedulerFactory
from arctic_training.tokenizer.factory import TokenizerFactory
from arctic_training.utils import append_json_file


class Trainer(ABC, CallbackMixin, metaclass=RegistryMeta):
    """Base Trainer class.

    On construction it builds the tokenizer, data loaders, model, optimizer and
    scheduler from ``config``, wraps them with ``deepspeed.initialize`` and
    optionally loads auto-resume checkpoints; ``train()`` then drives the
    epoch/step loop. Subclasses must set ``name`` and implement ``loss``.
    """

    name: str
    """
    Name of the trainer used for registering custom trainers. This name
    should be unique and is used in the training recipe YAMLs to identify which
    trainer to be used.
    """

    config: TrainerConfig
    """
    The type of the config class that the trainer uses. This should be a
    subclass of TrainerConfig and add any trainer-specific fields.
    """

    # NOTE(review): each factory/engine attribute below is annotated as a single
    # type, while its docstring describes "a List of valid ... types"; confirm
    # which form `_validate_class_attribute_type` actually accepts.
    data_factory: DataFactory
    """
    A List of valid data factory types that the trainer can use. These should
    inherit from DataFactory. The first item in the list will be used as the
    default if the type is not explicitly set in the YAML config.
    """

    model_factory: ModelFactory
    """
    A List of valid model factory types that the trainer can use. These should
    inherit from ModelFactory. The first item in the list will be used as the
    default if the type is not explicitly set in the YAML config.
    """

    checkpoint_engine: CheckpointEngine
    """
    A List of valid checkpoint engine types that the trainer can use. These
    should inherit from CheckpointEngine. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    optimizer_factory: OptimizerFactory
    """
    A List of valid optimizer factory types that the trainer can use. These
    should inherit from OptimizerFactory. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    scheduler_factory: SchedulerFactory
    """
    A List of valid scheduler factory types that the trainer can use. These
    should inherit from SchedulerFactory. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    tokenizer_factory: TokenizerFactory
    """
    A List of valid tokenizer factory types that the trainer can use. These
    should inherit from TokenizerFactory. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    # Class-level default; the list object is shared by subclasses that do not
    # override `callbacks`.
    callbacks: List[Tuple[str, Callable]] = [
        post_loss_log_cb,
    ]
    """
    A list of callbacks for the trainer. Callbacks are specified as tuples of a
    string indicating where the callback should be placed and a callable that
    implements the callback. Callback events for the trainer include `pre-` and
    `post-` for `init`, `train`, `epoch`, `step`, and `checkpoint`. In this class
    `loss`, `backward`, `evaluate` are wrapped with `callback_wrapper` as well.
    """

    @classmethod
    def _validate_subclass(cls) -> None:
        """Check that a subclass sets `name`, uses the expected attribute types
        and defines the core methods with the expected signatures.

        Presumably invoked by `RegistryMeta` when a subclass is registered --
        confirm against `arctic_training.registry`.
        """
        _validate_class_attribute_set(cls, "name")
        _validate_class_attribute_type(cls, "config", TrainerConfig)
        _validate_class_attribute_type(cls, "data_factory", DataFactory)
        _validate_class_attribute_type(cls, "model_factory", ModelFactory)
        _validate_class_attribute_type(cls, "checkpoint_engine", CheckpointEngine)
        _validate_class_attribute_type(cls, "optimizer_factory", OptimizerFactory)
        _validate_class_attribute_type(cls, "scheduler_factory", SchedulerFactory)
        _validate_class_attribute_type(cls, "tokenizer_factory", TokenizerFactory)
        _validate_class_method(cls, "loss", ["self", "batch"])
        _validate_class_method(cls, "step", ["self", "batch"])
        _validate_class_method(cls, "epoch", ["self"])
        _validate_class_method(cls, "train", ["self"])
        _validate_class_method(cls, "checkpoint", ["self"])

    def __init__(self, config: TrainerConfig, mode: str = "train") -> None:
        """
        Build all training components described by ``config``.

        Args:
            config: The trainer configuration.
            mode: ``"process-data"`` stops right after the tokenizer and the
                train/eval data loaders are created (no model, optimizer or
                DeepSpeed engine). Any other value builds the full training
                stack.
        """
        logger.info(f"Initializing Trainer with config:\n{debug.format(config)}")
        self.config = config
        # Training-progress counters; `global_step` is later synced from the DeepSpeed engine in step().
        self.epoch_idx = 0
        self.train_batch_idx = 0
        self.global_step = 0
        self.global_step_this_run = 0
        self.global_step_at_start_this_run = 0
        self.early_stop = False
        self.early_stop_reason = ""
        self.world_size = config.world_size
        self.global_rank = config.global_rank
        self.epoch_finished = False
        self.training_finished = False
        self.wandb_experiment: Optional[WandbRun] = None
        self.is_resume = False  # Track if we resumed from ckpt
        # May be filled in by a checkpoint engine's load() so a resumed run reuses the same wandb run.
        self.wandb_run_id = None

        self._set_seeds(self.config.seed)

        if self.config.mem_profiler == "e2e":
            torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

        tokenizer_factory = self.config.tokenizer.factory(self)
        self.tokenizer = tokenizer_factory()

        data_factory = self.config.data.factory(self)
        self.train_dataloader, self.eval_dataloader = data_factory()
        # Data-processing-only mode: nothing below (model, optimizer, DeepSpeed) is needed.
        if mode == "process-data":
            return

        if self.config.overfit_first_batch:
            self.train_dataloader = OverfitOneBatchDataLoader(self.train_dataloader)

        # checkpointing and resume
        self.checkpoint_engines = [engine(self) for engine in self.config.checkpoint_engines]
        for engine in self.checkpoint_engines:
            # currently only deepspeed engine supports resume from intermediate checkpoint
            if engine.name == "deepspeed" and engine.config.auto_resume and engine.latest_checkpoint_exists:
                self.is_resume = True

        # XXX: We can abstract this section further with AT-specific wrapper, but
        # UlyssesSPAttentionHF should not have any AT-specific objects / assumptions
        # The returned `mpu` is passed to deepspeed.initialize() below.
        mpu = UlyssesSPAttentionHF.register_with_transformers(
            model_name_or_path=self.config.model.name_or_path,
            core_attn_implementation=self.config.model.attn_implementation,
            sequence_parallel_size=self.config.sequence_parallel_size,
            micro_batch_size=self.config.micro_batch_size,
            seq_length=self.config.data.max_length,
            seq_length_is_variable=True,
        )

        # Important: this is most likely not beneficial under seqlen=64k
        if self.config.activation_checkpoint_cpu_offload:
            # activation_checkpointing_cpu_offload becomes very beneficial at very long seqlen
            # e.g., llama 8b at 800k (100k effective per gpu) will save 24GB per gpu:
            # ((100_000*4096)*2*32/2**30), but for short sequences the offload will just slow things
            # down,
            #
            # XXX: could parameterize or run a few lengths to see at which threshold it becomes
            # beneficial - a user might still want this on even at shorter seqlen if they don't
            # mind slower performance. discussing adding this functionality to pytorch core
            # (https://pytorch.slack.com/archives/C3PDTEV8E/p1745274102600729)
            from arctic_training.monkey_patches import monkey_patch_checkpoint_function_with_cpu_offload

            monkey_patch_checkpoint_function_with_cpu_offload()

        # MLP tiling - has to happen before model is instantiated
        if self.config.tiled_mlp_compute:
            enable_tiled_mlp_compute(self.config.model.name_or_path)

        # The unused local keeps the HfDeepSpeedConfig object alive while the model is
        # created; the transformers non-Trainer DeepSpeed integration relies on it existing
        # during model construction (presumably for ZeRO-3 init). The F841 suppression is deliberate.
        dschf = HfDeepSpeedConfig(self.config.deepspeed)  # noqa: F841
        model_factory = self.config.model.factory(self)
        self.model = model_factory()

        # prevent causal mask from being created in HF Transformers - it's a huge `[bs, seqlen, seqlen]` tensor
        # XXX: This should also benefit a single gpu use case when SDPA is used - so perhaps remove the SP>1 check?
        if self.config.sequence_parallel_size > 1 and self.config.model.attn_implementation not in [
            "flash_attention_2",
            "flash_attention_3",
        ]:
            import transformers.masking_utils

            # Register a no-op mask builder under "sdpa" so no mask tensor is materialized.
            transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", lambda *args, **kwargs: None)

        optimizer_factory = self.config.optimizer.factory(self)
        self.optimizer = optimizer_factory()

        scheduler_factory = self.config.scheduler.factory(self)
        self.scheduler = scheduler_factory()

        # From here on `self.model` is the DeepSpeed engine; see `model_unwrapped` for the raw module.
        self.model, *_ = deepspeed.initialize(
            model=self.model,
            optimizer=self.optimizer,
            args=self.config,
            lr_scheduler=self.scheduler,
            config=self.config.deepspeed,
            mpu=mpu,
        )

        # Older/other engines may not expose wall-clock timers; epoch() checks this flag before reading them.
        self.ds_wall_clock_available = hasattr(self.model, "get_wall_clock_timers")

        if self.config.sequence_parallel_size > 1:
            # deepspeed.initialize needs to run first
            from deepspeed.utils import groups

            # set SP-trainer attributes to be used later
            self.sp_group = groups._get_sequence_parallel_group()
            self.sp_world_size = groups._get_sequence_parallel_world_size()
            self.sp_rank = groups._get_sequence_parallel_rank()

            # wrap the DL with Ulysses one
            self.train_dataloader = UlyssesSPDataLoaderAdapter(
                self.train_dataloader,
                sp_rank=self.sp_rank,
                sp_group=self.sp_group,
                sp_world_size=self.sp_world_size,
                device=self.device,
            )
            if self.eval_dataloader is not None:
                self.eval_dataloader = UlyssesSPDataLoaderAdapter(
                    self.eval_dataloader,
                    sp_rank=self.sp_rank,
                    sp_group=self.sp_group,
                    sp_world_size=self.sp_world_size,
                    device=self.device,
                )

        # Every engine with auto_resume enabled gets a chance to load state into the DeepSpeed engine.
        for engine in self.checkpoint_engines:
            if engine.config.auto_resume:
                engine.load(self.model)

        self.metrics = Metrics(self)

        # Only rank 0 talks to wandb.
        if self.global_rank == 0 and self.config.wandb.enable:
            # in order for resume to continue the same wandb run we need to re-use a run_id from the previous run
            if self.wandb_run_id is None:
                self.wandb_run_id = wandb.util.generate_id()
            # Note: wandb.init() is not type annotated so we need to use type: ignore
            self.wandb_experiment = wandb.init(  # type: ignore
                id=self.wandb_run_id,
                entity=self.config.wandb.entity,
                project=self.config.wandb.project,
                name=self.config.wandb.name,
                config=self.config.model_dump(),
                # do not put `wandb` in the root of the repo as it conflicts with wandb package
                dir=f"{self.config.logger.output_dir}/wandb",
            )

    def _set_seeds(self, seed: int) -> None:
        """Seed torch, numpy, Python's `random` and transformers (via `set_seed`) with ``seed``."""
        logger.info(f"Setting random seeds to {seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        set_seed(seed)

    @property
    def model_unwrapped(self):
        """Return the original model before it was wrapped by deepspeed"""
        if hasattr(self.model, "module"):
            return self.model.module
        else:
            return self.model

    @property
    def epochs(self) -> tqdm:
        """Epochs iterator.

        Yields epoch indices from ``self.epoch_idx`` (non-zero after resume) up to
        the total. With ``train_iters`` set, the total is
        ``ceil(train_iters * gradient_accumulation_steps / len(train_dataloader))``.
        The progress bar is shown only on rank 0 and only when
        ``train_log_iter_interval`` is 0.
        """
        total_epochs = self.config.epochs
        if self.config.train_iters:
            total_epochs = math.ceil(
                self.config.train_iters * self.config.gradient_accumulation_steps / len(self.train_dataloader)
            )

        return tqdm(
            range(self.epoch_idx, total_epochs),
            desc="Epochs",
            unit="epoch",
            disable=(self.global_rank != 0) or (self.config.train_log_iter_interval != 0),
        )

    @property
    def train_batches(self) -> tqdm:
        """Training data iterator (progress bar shown only on rank 0 when ``train_log_iter_interval`` is 0)."""
        return tqdm(
            self.train_dataloader,
            desc="Train Batches",
            unit="batch",
            disable=(self.global_rank != 0) or (self.config.train_log_iter_interval != 0),
        )

    @property
    def eval_batches(self) -> tqdm:
        """Evaluation data iterator (progress bar shown only on rank 0 when ``is_eval_log_iter()`` is true)."""
        return tqdm(
            self.eval_dataloader,
            desc="Eval Batches",
            unit="batch",
            disable=self.global_rank != 0 or not self.is_eval_log_iter(),
        )
[docs] def is_eval_log_iter(self) -> bool: return self.global_step // self.config.eval_interval % self.config.eval_log_iter_interval == 0
@cached_property def device(self) -> torch.device: """Current device.""" return torch.device(get_accelerator().device_name(self.config.local_rank)) @property def training_horizon(self) -> int: """Total number of training iterations.""" if self.train_dataloader is None: raise ValueError("Train dataloader not initialized.") if self.config.train_iters: return self.config.train_iters # XXX: this was incorrect for GAS return self.config.epochs * len(self.train_dataloader) # // self.config.gradient_accumulation_steps
[docs] @callback_wrapper("loss") @abstractmethod def loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Loss function for the trainer. This method should be implemented by the inheriting trainer class. """ raise NotImplementedError("Loss method must be implemented by the trainer.")
[docs] @callback_wrapper("backward") def backward(self, loss: torch.Tensor) -> None: """ Backward function for the trainer. This method is called after the loss method and is responsible for backpropagating the loss through the model. """ self.model.backward(loss)
[docs] def need_early_exit(self): """ If we need to exit early, set `self.early_stop_reason` and return True Otherwise return False """ # exit conditions in the order of likelyhood if ( self.config.exit_iteration_this_run > 0 and self.config.exit_iteration_this_run == self.global_step_this_run ): self.early_stop_reason = f"reached exit_iteration_this_run of {self.global_step_this_run}" return True elif self.config.exit_iteration > 0 and self.config.exit_iteration == self.global_step: self.early_stop_reason = f"reached exit_iteration of {self.global_step}" return True elif self.config.kill_switch_path.exists(): self.early_stop_reason = f"detected kill switch {self.config.kill_switch_path}" return True elif self.global_step >= self.training_horizon: self.early_stop_reason = f"reached training_horizon of {self.global_step}" return True return False
[docs] @callback_wrapper("step") def step(self, batch: Dict[str, torch.Tensor]) -> None: """ Step function for the trainer. Each batch of training data is passed to this method. """ self.model.train() with deepspeed.runtime.engine.autocast_if_enabled(self.model): loss = self.loss(batch) self.backward(loss) def maybe_item(v): return v.item() if torch.is_tensor(v) else v self.metrics.record("loss", maybe_item(loss)) self.model.step() self.checkpoint() # DeepSpeed increments its global step after the step() call, so we use it as the golden truth self.global_step = self.model.global_steps self.global_step_this_run = self.global_step - self.global_step_at_start_this_run
[docs] @callback_wrapper("epoch") def epoch(self) -> None: """ Epoch training loop. This method will be called for each epoch of training and iterates across batches of training data, calling the step method on each batch. """ self.epoch_finished = False self.metrics.start_timer("iter") # enable memory allocation history, which will add tracebacks and event history to memory snapshots if self.config.mem_profiler == "step": torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries) batch_iterator = iter(self.train_batches) if self.is_resume: logger.info(f"Resumed from checkpoint at global step: {self.global_step}.") batches_to_skip = self.global_step % len(self.train_dataloader) logger.info(f"Advancing {batches_to_skip} batches.") for _ in range(batches_to_skip): next(batch_iterator) self.train_batch_idx += batches_to_skip self.is_resume = False for batch in batch_iterator: self.train_batch_idx += 1 # Run the early exit checks before stepping to correctly deal with resume should the training not continue if self.need_early_exit(): self.early_stop = True break self.gas_boundary = self.train_batch_idx % self.config.gradient_accumulation_steps == 0 if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation: # deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflos separately sample_seqlens = batch.pop("packed_sample_seqlens") else: sample_seqlens = [ [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size] for idx in range(len(batch["input_ids"])) ] self.metrics.seqlens = sample_seqlens self.metrics.start_timer("step") self.step(batch) self.metrics.stop_timer("step") self.metrics.restart_timer("iter") if self.config.train_log_iter_interval != 0: self.metrics.print_summary() if self.gas_boundary: if ( self.global_rank == 0 and self.config.train_log_iter_interval != 0 and self.global_step % self.config.train_log_iter_interval == 0 ): metrics = {k: v for k, v in 
self.metrics.summary_dict.items()} if self.ds_wall_clock_available: ds_timers = self.model.get_wall_clock_timers() metrics.update(ds_timers) append_json_file(self.config.train_log_metrics_path, metrics) # do not log the first train iteration to wandb, since it's a massive outlier # on all performance metrics, which messes up the scale of the report if self.wandb_experiment is not None and self.global_step > 1: metrics = {k: v for k, v in metrics.items() if k not in ["iter"]} self.wandb_experiment.log(metrics, step=self.global_step) if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0: self.evaluate() if self.is_eval_log_iter(): self.metrics.print_summary(prefix="eval") if self.wandb_experiment is not None: metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]} self.wandb_experiment.log(metrics, step=self.global_step) self.metrics.stop_timer("iter") self.epoch_finished = True
[docs] @callback_wrapper("train") def train(self) -> None: """ Main training loop. Calls the epoch method for each epoch of training. """ try: # to be able to keep track of number of steps of this run inside step() self.global_step_at_start_this_run = self.global_step for epoch_idx in self.epochs: self.epoch_idx = epoch_idx self.epoch() if self.early_stop: break self.checkpoint() self.training_finished = True if self.global_rank == 0: if self.early_stop: print(f"*** Exiting training early because training {self.early_stop_reason}") else: print("*** Training finished normally.") self.checkpoint() except Exception as e: logger.error(f"Training failed with error: {e}") # logger.info(f"{self._trainer_state}") raise (e) finally: if self.config.mem_profiler is not None: torch.cuda.memory._dump_snapshot(self.config.mem_profiler_dir / f"{self.global_rank}.pickle") if self.wandb_experiment is not None: self.wandb_experiment.finish()
[docs] @callback_wrapper("evaluate") def evaluate(self) -> None: """ Evaluation loop. Measures the model's performance on the evaluation dataset. """ self.model.eval() with torch.no_grad(): losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches] self.metrics.record("loss/eval", losses) # type: ignore
[docs] @callback_wrapper("checkpoint") def checkpoint(self) -> None: if self.global_step_this_run == 0: logger.info("No steps were run this run, not saving the checkpoint") return for engine in self.checkpoint_engines: if engine.do_checkpoint: logger.info(f"Saving Checkpoint at global step: {self.global_step}.") engine.save(self.model)