# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from abc import ABC
from abc import abstractmethod
from functools import cached_property
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import deepspeed
import numpy as np
import torch
import torch.cuda
import torch.distributed.nn
import wandb
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter
from devtools import debug
from tqdm import tqdm
from transformers import set_seed
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from wandb.sdk.wandb_run import Run as WandbRun
from arctic_training.callback.logging import post_loss_log_cb
from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.checkpoint.engine import CheckpointEngine
from arctic_training.config.trainer import TrainerConfig
from arctic_training.data.factory import DataFactory
from arctic_training.data.utils import OverfitOneBatchDataLoader
from arctic_training.logging import logger
from arctic_training.metrics import Metrics
from arctic_training.model.factory import ModelFactory
from arctic_training.model.tiled_compute import enable_tiled_mlp_compute
from arctic_training.optimizer.factory import OptimizerFactory
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method
from arctic_training.scheduler.factory import SchedulerFactory
from arctic_training.tokenizer.factory import TokenizerFactory
from arctic_training.utils import append_json_file
[docs]
class Trainer(ABC, CallbackMixin, metaclass=RegistryMeta):
    """
    Base class that every concrete trainer inherits from.

    The class attributes declared below are checked by ``_validate_subclass``.
    """

    name: str
    """
    Registration name of the trainer. It must be unique, because training
    recipe YAMLs refer to it to select which trainer to run.
    """

    config: TrainerConfig
    """
    Config class used by the trainer. Expected to be a subclass of
    TrainerConfig that adds any trainer-specific fields.
    """

    # NOTE(review): the factory/engine descriptions below talk about "a list of
    # valid types" whose first item is the default, while the annotations name
    # a single type - confirm which of the two is current.

    data_factory: DataFactory
    """
    Valid data factory types for this trainer. They should inherit from
    DataFactory. The first item is the default when the YAML config does not
    set the type explicitly.
    """

    model_factory: ModelFactory
    """
    Valid model factory types for this trainer. They should inherit from
    ModelFactory. The first item is the default when the YAML config does not
    set the type explicitly.
    """

    checkpoint_engine: CheckpointEngine
    """
    Valid checkpoint engine types for this trainer. They should inherit from
    CheckpointEngine. The first item is the default when the YAML config does
    not set the type explicitly.
    """

    optimizer_factory: OptimizerFactory
    """
    Valid optimizer factory types for this trainer. They should inherit from
    OptimizerFactory. The first item is the default when the YAML config does
    not set the type explicitly.
    """

    scheduler_factory: SchedulerFactory
    """
    Valid scheduler factory types for this trainer. They should inherit from
    SchedulerFactory. The first item is the default when the YAML config does
    not set the type explicitly.
    """

    tokenizer_factory: TokenizerFactory
    """
    Valid tokenizer factory types for this trainer. They should inherit from
    TokenizerFactory. The first item is the default when the YAML config does
    not set the type explicitly.
    """

    callbacks: List[Tuple[str, Callable]] = [post_loss_log_cb]
    """
    Callbacks attached to the trainer. Each one is a tuple of a string naming
    the event the callback hooks into and the callable implementing it. Trainer
    events include `pre-` and `post-` for `init`, `train`, `epoch`, `step`, and
    `checkpoint`.
    """
@classmethod
def _validate_subclass(cls) -> None:
    """Run the registry checks on the attributes and methods a trainer subclass must define."""
    _validate_class_attribute_set(cls, "name")

    # (attribute name, type it is validated against) - checked in this order
    typed_attributes = (
        ("config", TrainerConfig),
        ("data_factory", DataFactory),
        ("model_factory", ModelFactory),
        ("checkpoint_engine", CheckpointEngine),
        ("optimizer_factory", OptimizerFactory),
        ("scheduler_factory", SchedulerFactory),
        ("tokenizer_factory", TokenizerFactory),
    )
    for attribute, expected_type in typed_attributes:
        _validate_class_attribute_type(cls, attribute, expected_type)

    # (method name, argument names it is validated against) - checked in this order
    required_methods = (
        ("loss", ["self", "batch"]),
        ("step", ["self", "batch"]),
        ("epoch", ["self"]),
        ("train", ["self"]),
        ("checkpoint", ["self"]),
    )
    for method, argument_names in required_methods:
        _validate_class_method(cls, method, argument_names)
def __init__(self, config: TrainerConfig, mode: str = "train") -> None:
    """
    Build all training components from ``config``.

    The construction order is significant and is: seeds, tokenizer, data,
    checkpoint engines, Ulysses SP registration, model, optimizer, scheduler,
    ``deepspeed.initialize``, SP dataloader wrapping, checkpoint loading,
    metrics and finally wandb (rank 0 only).

    Args:
        config: The trainer configuration.
        mode: With ``"process-data"`` initialization returns right after the
            dataloaders were created; nothing past that point (checkpoint
            engines, model, optimizer, ...) is set up.
    """
    logger.info(f"Initializing Trainer with config:\n{debug.format(config)}")
    self.config = config
    self.epoch_idx = 0
    self.train_batch_idx = 0
    self.global_step = 0
    self.global_step_this_run = 0
    self.global_step_at_start_this_run = 0
    self.early_stop = False
    self.early_stop_reason = ""
    self.world_size = config.world_size
    self.global_rank = config.global_rank
    self.epoch_finished = False
    self.training_finished = False
    self.wandb_experiment: Optional[WandbRun] = None
    self.is_resume = False  # Track if we resumed from ckpt
    self.wandb_run_id = None

    self._set_seeds(self.config.seed)

    if self.config.mem_profiler == "e2e":
        torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

    tokenizer_factory = self.config.tokenizer.factory(self)
    self.tokenizer = tokenizer_factory()

    data_factory = self.config.data.factory(self)
    self.train_dataloader, self.eval_dataloader = data_factory()
    if mode == "process-data":
        return

    if self.config.overfit_first_batch:
        self.train_dataloader = OverfitOneBatchDataLoader(self.train_dataloader)

    # checkpointing and resume
    self.checkpoint_engines = [engine(self) for engine in self.config.checkpoint_engines]
    for engine in self.checkpoint_engines:
        # currently only deepspeed engine supports resume from intermediate checkpoint
        if engine.name == "deepspeed" and engine.config.auto_resume and engine.latest_checkpoint_exists:
            self.is_resume = True

    # XXX: We can abstract this section further with AT-specific wrapper, but
    # UlyssesSPAttentionHF should not have any AT-specific objects / assumptions
    mpu = UlyssesSPAttentionHF.register_with_transformers(
        model_name_or_path=self.config.model.name_or_path,
        core_attn_implementation=self.config.model.attn_implementation,
        sequence_parallel_size=self.config.sequence_parallel_size,
        micro_batch_size=self.config.micro_batch_size,
        seq_length=self.config.data.max_length,
        seq_length_is_variable=True,
    )

    # Important: this is most likely not beneficial under seqlen=64k
    if self.config.activation_checkpoint_cpu_offload:
        # activation_checkpointing_cpu_offload becomes very beneficial at very long seqlen
        # e.g., llama 8b at 800k (100k effective per gpu) will save 24GB per gpu:
        # ((100_000*4096)*2*32/2**30), but for short sequences the offload will just slow things
        # down,
        #
        # XXX: could parameterize or run a few lengths to see at which threshold it becomes
        # beneficial - a user might still want this on even at shorter seqlen if they don't
        # mind slower performance. discussing adding this functionality to pytorch core
        # (https://pytorch.slack.com/archives/C3PDTEV8E/p1745274102600729)
        from arctic_training.monkey_patches import monkey_patch_checkpoint_function_with_cpu_offload

        monkey_patch_checkpoint_function_with_cpu_offload()

    # MLP tiling - has to happen before model is instantiated
    if self.config.tiled_mlp_compute:
        enable_tiled_mlp_compute(self.config.model.name_or_path)

    # the object is deliberately bound to a local even though it is never read again
    dschf = HfDeepSpeedConfig(self.config.deepspeed)  # noqa: F841
    model_factory = self.config.model.factory(self)
    self.model = model_factory()

    # prevent causal mask from being created in HF Transformers - it's a huge `[bs, seqlen, seqlen]` tensor
    # XXX: This should also benefit a single gpu use case when SDPA is used - so perhaps remove the SP>1 check?
    if self.config.sequence_parallel_size > 1 and self.config.model.attn_implementation not in [
        "flash_attention_2",
        "flash_attention_3",
    ]:
        import transformers.masking_utils

        transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", lambda *args, **kwargs: None)

    optimizer_factory = self.config.optimizer.factory(self)
    self.optimizer = optimizer_factory()

    scheduler_factory = self.config.scheduler.factory(self)
    self.scheduler = scheduler_factory()

    self.model, *_ = deepspeed.initialize(
        model=self.model,
        optimizer=self.optimizer,
        args=self.config,
        lr_scheduler=self.scheduler,
        config=self.config.deepspeed,
        mpu=mpu,
    )
    self.ds_wall_clock_available = hasattr(self.model, "get_wall_clock_timers")

    if self.config.sequence_parallel_size > 1:
        # deepspeed.initialize needs to run first
        from deepspeed.utils import groups

        # set SP-trainer attributes to be used later
        self.sp_group = groups._get_sequence_parallel_group()
        self.sp_world_size = groups._get_sequence_parallel_world_size()
        self.sp_rank = groups._get_sequence_parallel_rank()

        # wrap the DL with Ulysses one
        self.train_dataloader = UlyssesSPDataLoaderAdapter(
            self.train_dataloader,
            sp_rank=self.sp_rank,
            sp_group=self.sp_group,
            sp_world_size=self.sp_world_size,
            device=self.device,
        )
        if self.eval_dataloader is not None:
            self.eval_dataloader = UlyssesSPDataLoaderAdapter(
                self.eval_dataloader,
                sp_rank=self.sp_rank,
                sp_group=self.sp_group,
                sp_world_size=self.sp_world_size,
                device=self.device,
            )

    for engine in self.checkpoint_engines:
        if engine.config.auto_resume:
            engine.load(self.model)

    self.metrics = Metrics(self)

    if self.global_rank == 0 and self.config.wandb.enable:
        # in order for resume to continue the same wandb run we need to re-use a run_id from the previous run
        # (presumably restored into `self.wandb_run_id` by a checkpoint engine's load() above - confirm)
        if self.wandb_run_id is None:
            self.wandb_run_id = wandb.util.generate_id()

        # Note: wandb.init() is not type annotated so we need to use type: ignore
        self.wandb_experiment = wandb.init(  # type: ignore
            id=self.wandb_run_id,
            # re-using the id is not enough on its own: with `resume` unset wandb always starts a new
            # run. "allow" continues the run when the id already exists and creates it otherwise.
            resume="allow",
            entity=self.config.wandb.entity,
            project=self.config.wandb.project,
            name=self.config.wandb.name,
            config=self.config.model_dump(),
            # do not put `wandb` in the root of the repo as it conflicts with wandb package
            dir=f"{self.config.logger.output_dir}/wandb",
        )
def _set_seeds(self, seed: int) -> None:
    """Seed torch, numpy, Python's ``random`` and transformers with the same value."""
    logger.info(f"Setting random seeds to {seed}")
    # applied in this exact order
    seeders = (torch.manual_seed, np.random.seed, random.seed, set_seed)
    for apply_seed in seeders:
        apply_seed(seed)
@property
def model_unwrapped(self):
    """Give back ``self.model.module`` when the model has that attribute, else ``self.model`` itself."""
    # a wrapped model (e.g. by deepspeed) exposes the original one as `.module`
    return getattr(self.model, "module", self.model)
@property
def epochs(self) -> tqdm:
    """
    Progress-bar iterator over the epoch indices still to be run.

    It starts at ``self.epoch_idx``. When ``train_iters`` is set, the end is the
    number of epochs needed to cover ``train_iters * gradient_accumulation_steps``
    batches; otherwise it is ``config.epochs``.
    """
    cfg = self.config
    if cfg.train_iters:
        batches_needed = cfg.train_iters * cfg.gradient_accumulation_steps
        last_epoch = math.ceil(batches_needed / len(self.train_dataloader))
    else:
        last_epoch = cfg.epochs

    # the bar is shown only on rank 0, and only when per-iteration train logging is off
    hide_bar = (self.global_rank != 0) or (cfg.train_log_iter_interval != 0)
    remaining = range(self.epoch_idx, last_epoch)
    return tqdm(remaining, desc="Epochs", unit="epoch", disable=hide_bar)
@property
def train_batches(self) -> tqdm:
    """Progress-bar iterator over ``self.train_dataloader``."""
    # the bar is shown only on rank 0, and only when per-iteration train logging is off
    hide_bar = (self.global_rank != 0) or (self.config.train_log_iter_interval != 0)
    return tqdm(self.train_dataloader, desc="Train Batches", unit="batch", disable=hide_bar)
@property
def eval_batches(self) -> tqdm:
    """Progress-bar iterator over ``self.eval_dataloader``."""
    # the bar is shown only on rank 0 and only during an eval round that gets logged;
    # `is_eval_log_iter()` is not consulted at all on other ranks
    hide_bar = self.global_rank != 0 or not self.is_eval_log_iter()
    return tqdm(self.eval_dataloader, desc="Eval Batches", unit="batch", disable=hide_bar)
[docs]
def is_eval_log_iter(self) -> bool:
    """
    Tell whether the current evaluation round is one whose results are logged.

    Evaluation rounds are numbered ``global_step // eval_interval`` and every
    ``eval_log_iter_interval``-th round is a logging round.

    Returns:
        ``True`` on a logging round, ``False`` otherwise. Also ``False`` when
        ``eval_interval`` or ``eval_log_iter_interval`` is ``0``: ``0`` is
        treated as "off", in line with how the trainer treats the other
        interval settings, instead of failing with ``ZeroDivisionError``.
    """
    eval_interval = self.config.eval_interval
    log_every = self.config.eval_log_iter_interval
    if eval_interval == 0 or log_every == 0:
        return False
    return (self.global_step // eval_interval) % log_every == 0
@cached_property
def device(self) -> torch.device:
    """The torch device for ``config.local_rank``, as named by the accelerator (computed once)."""
    device_name = get_accelerator().device_name(self.config.local_rank)
    return torch.device(device_name)
@property
def training_horizon(self) -> int:
    """
    Total number of training iterations.

    ``config.train_iters`` wins when it is set; otherwise the value is
    ``config.epochs`` times the length of the train dataloader.

    Raises:
        ValueError: if the train dataloader is ``None``.
    """
    if self.train_dataloader is None:
        raise ValueError("Train dataloader not initialized.")

    # XXX: this was incorrect for GAS
    # (a former `// self.config.gradient_accumulation_steps` on the epoch-based value is disabled)
    return self.config.train_iters or self.config.epochs * len(self.train_dataloader)
[docs]
@callback_wrapper("loss")
@abstractmethod
def loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Compute the loss for one batch.

    Abstract: every inheriting trainer class has to provide its own
    implementation. The base version only raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("Loss method must be implemented by the trainer.")
[docs]
@callback_wrapper("backward")
def backward(self, loss: torch.Tensor) -> None:
    """
    Backpropagate ``loss`` through the model.

    Runs right after the loss method; the work is delegated to the model's
    own ``backward``.
    """
    self.model.backward(loss)
[docs]
def need_early_exit(self):
    """
    Decide whether training has to stop now.

    When it does, ``self.early_stop_reason`` is set to a description of the
    condition that fired and ``True`` is returned. Otherwise ``False`` is
    returned and the reason is left untouched.
    """
    cfg = self.config

    # the exit conditions are checked in order of likelihood; the first that holds wins

    run_limit = cfg.exit_iteration_this_run
    if run_limit > 0 and run_limit == self.global_step_this_run:
        self.early_stop_reason = f"reached exit_iteration_this_run of {self.global_step_this_run}"
        return True

    if cfg.exit_iteration > 0 and cfg.exit_iteration == self.global_step:
        self.early_stop_reason = f"reached exit_iteration of {self.global_step}"
        return True

    if cfg.kill_switch_path.exists():
        self.early_stop_reason = f"detected kill switch {self.config.kill_switch_path}"
        return True

    if self.global_step >= self.training_horizon:
        self.early_stop_reason = f"reached training_horizon of {self.global_step}"
        return True

    return False
[docs]
@callback_wrapper("step")
def step(self, batch: Dict[str, torch.Tensor]) -> None:
    """
    Run one training step on ``batch``: loss, backward, loss bookkeeping,
    model step, checkpoint, and finally refresh of the step counters.
    """
    self.model.train()

    # NOTE(review): only the loss computation is placed under autocast here, with
    # backward outside of it - confirm this nesting against the original file.
    with deepspeed.runtime.engine.autocast_if_enabled(self.model):
        loss = self.loss(batch)

    self.backward(loss)

    # the loss is recorded as a plain number when it is a tensor, and as-is otherwise
    loss_value = loss.item() if torch.is_tensor(loss) else loss
    self.metrics.record("loss", loss_value)

    self.model.step()

    self.checkpoint()

    # DeepSpeed increments its global step after the step() call, so we use it as the golden truth
    self.global_step = self.model.global_steps
    steps_before_this_run = self.global_step_at_start_this_run
    self.global_step_this_run = self.global_step - steps_before_this_run
[docs]
@callback_wrapper("epoch")
def epoch(self) -> None:
    """
    Epoch training loop. This method will be called for each epoch of
    training and iterates across batches of training data, calling the step
    method on each batch.

    In the first epoch after a resume, the batches already consumed before
    the checkpoint was written are skipped. After each step, and only on a
    gradient accumulation boundary, train metrics are logged (rank 0) and
    evaluation is run when it is due. The loop ends early, with
    ``self.early_stop`` set, as soon as ``need_early_exit()`` reports so.
    """
    self.epoch_finished = False
    self.metrics.start_timer("iter")

    # enable memory allocation history, which will add tracebacks and event history to memory snapshots
    if self.config.mem_profiler == "step":
        torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

    gas = self.config.gradient_accumulation_steps

    batch_iterator = iter(self.train_batches)
    if self.is_resume:
        logger.info(f"Resumed from checkpoint at global step: {self.global_step}.")
        # `global_step` counts optimizer steps (it mirrors the DeepSpeed counter, see step()) and
        # each of them consumes `gas` batches, so the position inside the epoch is measured in
        # batches, not in steps.
        batches_to_skip = (self.global_step * gas) % len(self.train_dataloader)
        logger.info(f"Advancing {batches_to_skip} batches.")
        for _ in range(batches_to_skip):
            next(batch_iterator)
        self.train_batch_idx += batches_to_skip
        self.is_resume = False

    for batch in batch_iterator:
        self.train_batch_idx += 1

        # Run the early exit checks before stepping to correctly deal with resume should the training not continue
        if self.need_early_exit():
            self.early_stop = True
            break

        self.gas_boundary = self.train_batch_idx % gas == 0

        if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation:
            # deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflops separately
            sample_seqlens = batch.pop("packed_sample_seqlens")
        else:
            # one single-entry list per sample: its length scaled by the sequence parallel size
            sample_seqlens = [
                [len(sample_ids) * self.config.sequence_parallel_size] for sample_ids in batch["input_ids"]
            ]
        self.metrics.seqlens = sample_seqlens

        self.metrics.start_timer("step")
        self.step(batch)
        self.metrics.stop_timer("step")

        self.metrics.restart_timer("iter")

        if self.config.train_log_iter_interval != 0:
            self.metrics.print_summary()

        # NOTE(review): logging and evaluation are both nested under the GAS boundary check
        # here - confirm this nesting against the original file.
        if self.gas_boundary:
            if (
                self.global_rank == 0
                and self.config.train_log_iter_interval != 0
                and self.global_step % self.config.train_log_iter_interval == 0
            ):
                metrics = dict(self.metrics.summary_dict.items())
                if self.ds_wall_clock_available:
                    ds_timers = self.model.get_wall_clock_timers()
                    metrics.update(ds_timers)

                append_json_file(self.config.train_log_metrics_path, metrics)

                # do not log the first train iteration to wandb, since it's a massive outlier
                # on all performance metrics, which messes up the scale of the report
                if self.wandb_experiment is not None and self.global_step > 1:
                    metrics = {k: v for k, v in metrics.items() if k != "iter"}
                    self.wandb_experiment.log(metrics, step=self.global_step)

            if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0:
                self.evaluate()

                if self.is_eval_log_iter():
                    self.metrics.print_summary(prefix="eval")

                if self.wandb_experiment is not None:
                    metrics = {"loss/eval": self.metrics.summary_dict["loss/eval"]}
                    self.wandb_experiment.log(metrics, step=self.global_step)

    self.metrics.stop_timer("iter")
    self.epoch_finished = True
[docs]
@callback_wrapper("train")
def train(self) -> None:
    """
    Main training loop: run ``epoch()`` for every remaining epoch.

    A checkpoint is attempted after each epoch that was not stopped early,
    and once more after the loop. Any exception is logged and re-raised.
    Whatever the outcome, the memory snapshot is dumped (when the memory
    profiler is configured) and the wandb run is finished (when there is one).
    """
    try:
        # to be able to keep track of number of steps of this run inside step()
        self.global_step_at_start_this_run = self.global_step

        for current_epoch in self.epochs:
            self.epoch_idx = current_epoch
            self.epoch()
            if self.early_stop:
                break
            self.checkpoint()

        self.training_finished = True

        if self.global_rank == 0:
            outcome = (
                f"*** Exiting training early because training {self.early_stop_reason}"
                if self.early_stop
                else "*** Training finished normally."
            )
            print(outcome)

        # NOTE(review): this final checkpoint call is placed at the top level of the `try`,
        # i.e. it runs on every rank - confirm this nesting against the original file.
        self.checkpoint()
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        # logger.info(f"{self._trainer_state}")
        raise e
    finally:
        if self.config.mem_profiler is not None:
            snapshot_file = self.config.mem_profiler_dir / f"{self.global_rank}.pickle"
            torch.cuda.memory._dump_snapshot(snapshot_file)
        if self.wandb_experiment is not None:
            self.wandb_experiment.finish()
[docs]
@callback_wrapper("evaluate")
def evaluate(self) -> None:
    """
    Evaluation loop: compute the loss of every evaluation batch without
    gradients and record the list of values under ``loss/eval``.
    """
    self.model.eval()

    eval_losses = []
    with torch.no_grad():
        for eval_batch in self.eval_batches:
            batch_loss = self.loss(eval_batch)
            eval_losses.append(batch_loss.item())

    self.metrics.record("loss/eval", eval_losses)  # type: ignore
[docs]
@callback_wrapper("checkpoint")
def checkpoint(self) -> None:
    """
    Save the model through every checkpoint engine that reports it is due.

    Nothing is saved while no step has been run in the current run.
    """
    nothing_trained_yet = self.global_step_this_run == 0
    if nothing_trained_yet:
        logger.info("No steps were run this run, not saving the checkpoint")
        return

    for engine in self.checkpoint_engines:
        if not engine.do_checkpoint:
            continue
        logger.info(f"Saving Checkpoint at global step: {self.global_step}.")
        engine.save(self.model)