Trainer
ArcticTraining provides a flexible and extensible training framework that allows you to customize and create your own training workflows. At the core of this framework is the Trainer class, which orchestrates the training process by managing the model, optimizer, data loader, and other components.
The Trainer class is designed to be modular and extensible, allowing you to quickly swap in and out different building blocks to experiment with different training strategies. Here, we’ll walk through the key features of the Trainer class and show you how to create your own custom trainers.
- class arctic_training.trainer.trainer.Trainer(config, mode='train')
Bases: ABC, CallbackMixin
Base Trainer class.
- Parameters:
config (TrainerConfig)
mode (str)
- name: str
Name of the trainer, used for registering custom trainers. This name should be unique and is used in the training recipe YAMLs to identify which trainer to use.
- data_factory: DataFactory
A list of valid data factory types that the trainer can use. These should inherit from DataFactory. The first item in the list is used as the default if the type is not explicitly set in the YAML config.
- model_factory: ModelFactory
A list of valid model factory types that the trainer can use. These should inherit from ModelFactory. The first item in the list is used as the default if the type is not explicitly set in the YAML config.
- checkpoint_engine: CheckpointEngine
A list of valid checkpoint engine types that the trainer can use. These should inherit from CheckpointEngine. The first item in the list is used as the default if the type is not explicitly set in the YAML config.
- optimizer_factory: OptimizerFactory
A list of valid optimizer factory types that the trainer can use. These should inherit from OptimizerFactory. The first item in the list is used as the default if the type is not explicitly set in the YAML config.
- scheduler_factory: SchedulerFactory
A list of valid scheduler factory types that the trainer can use. These should inherit from SchedulerFactory. The first item in the list is used as the default if the type is not explicitly set in the YAML config.
- tokenizer_factory: TokenizerFactory
A list of valid tokenizer factory types that the trainer can use. These should inherit from TokenizerFactory. The first item in the list is used as the default if the type is not explicitly set in the YAML config.
- callbacks: List[Tuple[str, Callable]] = [('post-loss', <function _log_loss_value>)]
A list of callbacks for the trainer. Each callback is a tuple of a string indicating where the callback should be placed and a callable that implements the callback. Callback events for the trainer include pre- and post- for init, train, epoch, step, and checkpoint.
- config: TrainerConfig
The type of the config class that the trainer uses. This should be a subclass of TrainerConfig that adds any trainer-specific fields.
- property model_unwrapped
The original model, before it was wrapped by DeepSpeed.
- property epochs: tqdm.tqdm
Epochs iterator.
- property train_batches: tqdm.tqdm
Training data iterator.
- property eval_batches: tqdm.tqdm
Evaluation data iterator.
- property device: torch.device
Current device.
- property training_horizon: int
Total number of training iterations.
- abstract loss(batch)
Loss function for the trainer. This method should be implemented by the inheriting trainer class.
- Parameters:
batch (Dict[str, torch.Tensor])
- Return type:
Tensor
- backward(loss)
Backward function for the trainer. This method is called after the loss method and is responsible for backpropagating the loss through the model.
- Parameters:
loss (torch.Tensor)
- Return type:
None
- need_early_exit()
If training needs to exit early, set self.early_stop_reason and return True. Otherwise, return False.
- step(batch)
Step function for the trainer. Each batch of training data is passed to this method.
- Parameters:
batch (Dict[str, torch.Tensor])
- Return type:
None
- epoch()
Epoch training loop. This method is called for each epoch of training and iterates across batches of training data, calling the step method on each batch.
- Return type:
None
- train()
Main training loop. Calls the epoch method for each epoch of training.
- Return type:
None
- evaluate()
Evaluation loop. Measures the model’s performance on the evaluation dataset.
- Return type:
None
- count_model_parameters()
Returns a dictionary containing the “total” and “trainable” parameter counts.
Attributes
Creating a custom trainer starts with inheriting from the base
Trainer class and defining the name attribute. The
name attribute is used to identify the trainer when registering it with
ArcticTraining. Additionally, you can define custom types for
config, data_factory,
model_factory, checkpoint_engine,
optimizer_factory, scheduler_factory, and
tokenizer_factory to specify the default factories for each
component.
Specifying the type hints for these attributes tells ArcticTraining which building blocks are compatible with your custom trainer. You may define multiple compatible building blocks by using typing.Union in the type hint. When multiple types are specified for one of these attributes, the first is used as the default when type is not specified in the input config.
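The "first type in the Union is the default" rule can be illustrated with the standard typing helpers. The sketch below is not ArcticTraining's implementation: the factory classes are empty stand-ins, their name values ("huggingface", "liger") and the helpers compatible_types and resolve are hypothetical, and the only assumption is that the compatible building blocks are read from the class-level type hints in declaration order.

```python
from typing import Union, get_args, get_origin, get_type_hints


class HFModelFactory:  # stand-in for illustration only
    name = "huggingface"  # assumed registry name


class LigerModelFactory:  # stand-in for illustration only
    name = "liger"  # assumed registry name


class MyTrainer:  # a real trainer would subclass Trainer
    name = "my-trainer"
    model_factory: Union[HFModelFactory, LigerModelFactory]


def compatible_types(trainer_cls, attr):
    """Return the classes named in the type hint of `attr`, in declaration order."""
    hint = get_type_hints(trainer_cls)[attr]
    if get_origin(hint) is Union:
        return list(get_args(hint))
    return [hint]  # a single type (Union[X] collapses to X)


def resolve(trainer_cls, attr, requested=None):
    """Pick the class whose `name` matches `requested`, else the first one listed."""
    candidates = compatible_types(trainer_cls, attr)
    if requested is None:
        return candidates[0]
    for cls in candidates:
        if cls.name == requested:
            return cls
    raise ValueError(f"{requested!r} is not compatible with {trainer_cls.name}")
```

With no type requested, resolve(MyTrainer, "model_factory") returns HFModelFactory because it is listed first; requesting "liger" selects LigerModelFactory, and any other name is rejected.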
Properties
The Trainer class provides several properties that can be used to access
information about the state of the trainer at runtime. These include
epochs, train_batches,
device, training_horizon, and
warmup_steps.
Properties should typically not be set by custom trainers, but can be used by other custom classes, like new checkpoint engines or model factories, to access information about the training process.
Methods
The Trainer class has several methods that divide the training loop into
segments. At minimum, a new trainer must implement the loss()
method. However, any of the train(), epoch(),
step(), or checkpoint() methods can be
overridden to customize the training process.
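The division of labour between these methods can be sketched with a toy base class. MiniTrainer below is a hypothetical stand-in, not the real Trainer: it has no model, optimizer, or callbacks, and only mirrors the train() -> epoch() -> step() -> loss() call chain, with loss() as the single abstract method.

```python
from abc import ABC, abstractmethod


class MiniTrainer(ABC):
    """Toy stand-in for the base Trainer call chain."""

    def __init__(self, batches, num_epochs):
        self.batches = batches
        self.num_epochs = num_epochs
        self.global_step = 0
        self.losses = []

    @abstractmethod
    def loss(self, batch):
        """The only method a subclass is required to implement."""

    def step(self, batch):
        # One training step: compute the loss and advance the step counter.
        self.losses.append(self.loss(batch))
        self.global_step += 1

    def epoch(self):
        # One pass over the training batches.
        for batch in self.batches:
            self.step(batch)

    def train(self):
        # Main loop: one epoch() call per epoch.
        for _ in range(self.num_epochs):
            self.epoch()


class MeanSquareTrainer(MiniTrainer):
    """Minimal subclass: implements loss() and inherits the loop."""

    def loss(self, batch):
        return sum(x * x for x in batch) / len(batch)


class LoggingTrainer(MeanSquareTrainer):
    """Overrides step() to add behaviour while reusing the inherited logic."""

    def step(self, batch):
        super().step(batch)
        print(f"step {self.global_step}: loss={self.losses[-1]:.3f}")
```

MeanSquareTrainer shows the minimum required of a subclass, while LoggingTrainer shows the optional route of overriding one segment of the loop and delegating the rest to the parent.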
Train
@callback_wrapper("train")
def train(self) -> None:
    """
    Main training loop. Calls the epoch method for each epoch of training.
    """
    self.print_model_parameters_header()

    try:
        # to be able to keep track of number of steps of this run inside step()
        self.global_step_at_start_this_run = self.global_step

        for epoch_idx in self.epochs:
            self.epoch_idx = epoch_idx
            self.epoch()
            if self.early_stop:
                break
            self.checkpoint()
        self.training_finished = True
        if self.global_rank == 0:
            if self.early_stop:
                print(f"*** Exiting training early because training {self.early_stop_reason}")
            else:
                print("*** Training finished normally.")
        self.checkpoint()
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        # logger.info(f"{self._trainer_state}")
        raise (e)
    finally:
        if self.config.mem_profiler is not None:
            torch.cuda.memory._dump_snapshot(self.config.mem_profiler_dir / f"{self.global_rank}.pickle")

        if self.wandb_experiment is not None:
            self.wandb_experiment.finish()
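The train() method above is decorated with callback_wrapper("train"), which ties into the callbacks attribute and its pre- and post- events. A minimal sketch of how such a decorator could dispatch callbacks follows. It is an assumption for illustration, not the library's code: toy_callback_wrapper and ToyTrainer are hypothetical names, and callbacks here receive only the trainer instance, which may differ from the real callback signature.

```python
import functools


def toy_callback_wrapper(event):
    """Hypothetical decorator: run pre-<event> callbacks, the method, then post-<event> callbacks."""

    def decorator(method):
        @functools.wraps(method)
        def wrapped(self, *args, **kwargs):
            for where, fn in self.callbacks:
                if where == f"pre-{event}":
                    fn(self)
            result = method(self, *args, **kwargs)
            for where, fn in self.callbacks:
                if where == f"post-{event}":
                    fn(self)
            return result

        return wrapped

    return decorator


class ToyTrainer:
    # Same shape as the documented attribute: (placement string, callable) tuples.
    callbacks = [
        ("pre-train", lambda self: self.events.append("pre-train")),
        ("post-train", lambda self: self.events.append("post-train")),
    ]

    def __init__(self):
        self.events = []

    @toy_callback_wrapper("train")
    def train(self):
        self.events.append("train")
```

Calling ToyTrainer().train() records the pre-train callback, then the method body, then the post-train callback, in that order.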
Epoch
@callback_wrapper("epoch")
def epoch(self) -> None:
    """
    Epoch training loop. This method will be called for each epoch of
    training and iterates across batches of training data, calling the step
    method on each batch.
    """
    self.epoch_finished = False
    self.metrics.start_timer("iter")

    # enable memory allocation history, which will add tracebacks and event history to memory snapshots
    if self.config.mem_profiler == "step":
        torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

    batch_iterator = iter(self.train_batches)
    if self.is_resume:
        logger.info(f"Resumed from checkpoint at global step: {self.global_step}.")
        batches_to_skip = self.global_step % len(self.train_dataloader)
        logger.info(f"Advancing {batches_to_skip} batches.")
        for _ in range(batches_to_skip):
            next(batch_iterator)
        self.train_batch_idx += batches_to_skip
        self.is_resume = False

    for batch in batch_iterator:
        self.train_batch_idx += 1

        # Run the early exit checks before stepping to correctly deal with resume should the training not continue
        if self.need_early_exit():
            self.early_stop = True
            break

        self.gas_boundary = self.train_batch_idx % self.config.gradient_accumulation_steps == 0

        if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation:
            # deal correctly with packed samples under FA2/FA3, by calculating each seqlen's TFLOPs separately
            sample_seqlens = batch.pop("packed_sample_seqlens")
        else:
            sample_seqlens = [
                [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]
                for idx in range(len(batch["input_ids"]))
            ]
        self.metrics.seqlens = sample_seqlens

        self.metrics.start_timer("step")
        self.step(batch)
        self.metrics.stop_timer("step")

        self.metrics.restart_timer("iter")

        if self.config.train_log_iter_interval != 0:
            self.metrics.print_summary()

        if self.gas_boundary:
            if (
                self.global_rank == 0
                and self.config.train_log_iter_interval != 0
                and self.global_step % self.config.train_log_iter_interval == 0
            ):
                metrics = {k: v for k, v in self.metrics.summary_dict.items()}
                if self.ds_wall_clock_available:
                    ds_timers = self.model.get_wall_clock_timers()
                    metrics.update(ds_timers)

                append_json_file(self.config.train_log_metrics_path, metrics)

                # do not log the first train iteration to wandb, since it's a massive outlier
                # on all performance metrics, which messes up the scale of the report
                if self.wandb_experiment is not None and self.global_step > 1:
                    metrics = {k: v for k, v in metrics.items() if k not in ["iter"]}
                    self.wandb_experiment.log(metrics, step=self.global_step)

            if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0:
                self.evaluate()

                if self.is_eval_log_iter():
                    self.metrics.print_summary(prefix="eval")
                    if self.wandb_experiment is not None:
                        metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]}
                        self.wandb_experiment.log(metrics, step=self.global_step)

    self.metrics.stop_timer("iter")
    self.epoch_finished = True
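Two pieces of arithmetic in epoch() above are easy to check in isolation: how many batches are skipped when resuming mid-epoch, and when a batch falls on a gradient accumulation boundary. The helper names below are hypothetical, and the resume sketch assumes a gradient accumulation of 1 so that one global step corresponds to one batch.

```python
def batches_to_skip(global_step, batches_per_epoch):
    """Batches already consumed in the current epoch (mirrors global_step % len(dataloader))."""
    return global_step % batches_per_epoch


def resume_epoch(batches, global_step):
    """Advance past the already-seen batches and return the ones still to train on."""
    iterator = iter(batches)
    for _ in range(batches_to_skip(global_step, len(batches))):
        next(iterator)
    return list(iterator)


def is_gas_boundary(train_batch_idx, gradient_accumulation_steps):
    """True when the 1-based batch index completes an accumulation window."""
    return train_batch_idx % gradient_accumulation_steps == 0


# Resuming at global step 7 with 5 batches per epoch: one full epoch is done,
# and the first 2 batches of the second epoch are skipped.
remaining = resume_epoch(["a", "b", "c", "d", "e"], global_step=7)

# With 4 accumulation steps, batches 4 and 8 are the boundaries.
boundaries = [idx for idx in range(1, 9) if is_gas_boundary(idx, 4)]
```

Because the batch index is incremented before the boundary test, the index is effectively 1-based, so the boundary falls on the last micro-batch of each accumulation window.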
Step
@callback_wrapper("step")
def step(self, batch: Dict[str, torch.Tensor]) -> None:
    """
    Step function for the trainer. Each batch of training data is passed to
    this method.
    """
    self.model.train()

    with deepspeed.runtime.engine.autocast_if_enabled(self.model):
        loss = self.loss(batch)

    self.backward(loss)

    def maybe_item(v):
        return v.item() if torch.is_tensor(v) else v

    self.metrics.record("loss", maybe_item(loss))

    # if needing to debug AMoE-EP grads
    # from deepspeed.utils import safe_get_full_grad
    #
    # if hasattr(self.model_unwrapped.model.layers[1].mlp, "router_gate"):
    #     pr0("------------------------->8------------- grads ------------->8----------",
    #         force=True)
    #     pr0(
    #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].mlp.router_gate))=}",
    #         force=True,
    #     )
    #     pr0(
    #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].mlp.expert_gate_up))=}",
    #         force=True,
    #     )
    #     pr0(
    #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].mlp.expert_down))=}",
    #         force=True,
    #     )
    #     pr0(
    #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].post_attention_layernorm.weight))=}",
    #         force=True,
    #     )
    #     pr0("------------------------->8------------- grads end --------->8----------",
    #         force=True)
    #     exit()

    self.model.step()

    self.checkpoint()

    # DeepSpeed increments its global step after the step() call, so we use it as the golden truth
    self.global_step = self.model.global_steps
    self.global_step_this_run = self.global_step - self.global_step_at_start_this_run
Checkpoint
@callback_wrapper("checkpoint")
def checkpoint(self) -> None:
    if self.global_step_this_run == 0:
        logger.info("No steps were run this run, not saving the checkpoint")
        return

    for engine in self.checkpoint_engines:
        if engine.do_checkpoint:
            if engine.name == "huggingface" and self.use_arctic_moe:
                if self.training_finished:
                    # export to the original moe mlp format/layout - this is slow but it's the end of the training so it's fine.
                    from arctic_training.model.moe.utils import remap_arctic_moe_to_orig_moe_mlp_params

                    logger.info("Exporting to the original MoE format before saving the checkpoint")
                    remap_arctic_moe_to_orig_moe_mlp_params(self.model)
                else:
                    raise ValueError(
                        "Currently supporting saving to HF checkpoint for AMoE models only when the training is"
                        " finished, because conversion will be very slow. For interim checkpoints use `deepspeed`"
                        " type of the checkpoint as it'd be much faster to save to and resume from. "
                    )
            logger.info(f"Saving Checkpoint at global step: {self.global_step}.")
            engine.save(self.model)
Supervised Fine-Tuning (SFT) Trainer
To help you get started with creating custom trainers, ArcticTraining includes a Supervised Fine-Tuning (SFT) trainer that demonstrates how to build a training pipeline from the base building blocks. The SFT trainer can in turn be used as a starting point and extended for creating your own custom trainers.
To create the SFT trainer, we subclass the Trainer and override the
loss() method. We also define the necessary components described
in Trainer Attributes. We use a custom data factory,
SFTDataFactory, which we describe in greater detail in the Data
Factory section. The remainder of the attributes use the base building
blocks from ArcticTraining. For example, the model factory defaults to the
HFModelFactory (because it is listed first in the model_factory attribute
type hint), but this trainer can work with either HFModelFactory or
LigerModelFactory.
class SFTTrainer(Trainer):
    name = "sft"
    data_factory: SFTDataFactory
    model_factory: Union[HFModelFactory, LigerModelFactory]
    checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine]
    optimizer_factory: Union[
        FusedAdamOptimizerFactory, FusedAdamMoEOptimizerFactory, CPUAdamOptimizerFactory, CPUAdamMoEOptimizerFactory
    ]
    scheduler_factory: Union[HFSchedulerFactory]
    tokenizer_factory: Union[HFTokenizerFactory]

    def loss(self, batch) -> torch.Tensor:
        batch = to_device(batch, self.device)

        if self.config.sequence_parallel_size == 1:
            # if model.type=liger is configured - this will use a much more efficient fused
            # logits+loss liger kernel - using significantly less gpu memory and a bit faster
            # compute (liger fused logits+loss kernel does not repeat forward during backward)
            outputs = self.model(**batch, use_cache=False)
            loss = outputs.loss
            return loss

        # Ulysses SP expectations:
        # 1. batch has `labels` replaced with `shift_labels` (which are already preshifted in
        #    DataLoader)
        # 2. this rank deals with a seqlen dimension shard so once the loss is calculated it needs
        #    to do a differentiable weighted loss average to get the grads right
        if "labels" in batch:
            raise ValueError(
                "found labels in batch - they shouldn't be there, instead shift_labels should be there - check"
                " that UlyssesSPDataLoaderAdapter has been applied to the original DataLoader object"
            )
        if "shift_labels" not in batch:
            raise ValueError(
                "shift_labels are missing from the batch - check that UlyssesSPDataLoaderAdapter has been"
                " applied to the original DataLoader object"
            )

        shift_labels = batch["shift_labels"]

        # We have 2 implementations of efficient tiled logits+loss computation.
        # 1. Liger fused cross-entropy is the fastest/most memory efficient way - liger-kernel
        #    doesn't recompute forward inside backward, instead it computes the gradients in the
        #    forward path.
        # 2. But liger kernel isn't implemented for all HF Transformers models, so then we fall
        #    back onto our tiled logits+loss compute implementation that is almost as efficient
        #    memory-wise, but which has more compute overhead because backward re-runs forward. The
        #    total memory usage is very similar, but cuda cache flushes earlier if pushing close to
        #    OOM than liger.
        if self.config.model.type == "liger":
            # letting liger do fused logits+loss calculation
            outputs = self.model(**batch, use_cache=False)
            loss = outputs.loss
            if loss is None:
                # XXX: not sure why this happens with SP>1 and eval-enabled, I checked shift_labels contain valid non -100 tokens - disabling fused_linear_cross_entropy=False in AutoLigerKernelForCausalLM.from_pretrained doesn't help. all works when eval is off.
                raise ValueError(
                    "Liger-Kernel failed to compute loss (returned None) - it's known to fail with eval enabled along"
                    " train steps when SP>1."
                )
        else:
            # Currently relying on an automatic num_shards derivation based on the goal that it'll
            # take approximately 1GB of fp32 logits in a shard, could make this configurable if
            # desired later. Less than 1GB doesn't seem to make much of an impact, but perhaps a
            # higher number will be more efficient as it'll run fewer shards.
            num_shards = "auto"
            if num_shards == "auto":
                # parameterize to about 1GB fp32 logits shards
                slice_size_in_gb = 1  # XXX: make configurable?
                bs, seqlen = shift_labels.shape
                vocab_size = self.model_unwrapped.config.vocab_size
                logits_numel = bs * seqlen * vocab_size
                size_in_gb = logits_numel * 4 / 2**30  # fp32
                # the SP shard's seqlen may well not be divisible by the derived number
                # of chunked loss shards, so we round up and allow the last chunk to
                # be shorter than the rest
                num_shards = math.ceil(size_in_gb / slice_size_in_gb)
                # print(f"derived {num_shards} shards for size {size_in_gb}GB")

            model_with_head = self.model_unwrapped
            outputs = model_with_head.model(**batch, use_cache=False)
            hidden_states = outputs.last_hidden_state
            compute_params = [model_with_head.lm_head.weight]
            seqlen = shift_labels.shape[1]
            mask = None
            output_reduction = "sum"

            # since -100s shift_labels are ignored we have to perform a weighted average on each
            # loss slice as each slice may contribute a different number of non- -100 labels
            def fused_logits_loss_fn(model_with_head=None, hidden_states=None, shift_labels=None):
                vocab_size = model_with_head.config.vocab_size
                logits = model_with_head.lm_head(hidden_states)
                if all((shift_labels == -100).squeeze()):
                    # fake loss calculation, since CE will return nan, but grads will be set
                    # a normal loss_fn upcasts logits to float so match it
                    loss_sum = (logits.sum() * 0.0).float()
                else:
                    good_items = ((shift_labels != -100).squeeze()).sum()
                    loss = model_with_head.loss_function(
                        logits=logits, labels=None, vocab_size=vocab_size, shift_labels=shift_labels
                    )
                    loss_sum = loss * good_items
                return loss_sum

            total_loss_sum = TiledFusedLogitsLoss.apply(
                fused_logits_loss_fn,
                model_with_head,
                hidden_states,
                shift_labels,
                mask,
                num_shards,
                compute_params,
                output_reduction,
            )
            total_good_items = (shift_labels != -100).squeeze().sum()
            loss = total_loss_sum / max(total_good_items, 1)

        # differentiable weighted per-shard-loss aggregation across ranks
        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=self.sp_group)
        good_tokens = ((shift_labels != -100).view(-1)).sum()
        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group)
        total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size))
        total_good_tokens = sum(good_tokens_per_rank)
        loss = total_loss / total_good_tokens

        return loss
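The two calculations in the sequence-parallel path of loss() can be checked with plain numbers. The sketch below uses no tensors or distributed calls: derive_num_shards and token_weighted_mean are hypothetical helper names, the shapes and vocabulary size are made-up example values, and the per-rank losses stand in for what all_gather would return.

```python
import math


def derive_num_shards(batch_size, seqlen, vocab_size, slice_size_in_gb=1):
    """Number of loss shards so that each holds roughly slice_size_in_gb of fp32 logits."""
    logits_numel = batch_size * seqlen * vocab_size
    size_in_gb = logits_numel * 4 / 2**30  # 4 bytes per fp32 element
    # Round up, so the last shard may be shorter than the rest.
    return math.ceil(size_in_gb / slice_size_in_gb)


def token_weighted_mean(losses_per_rank, good_tokens_per_rank):
    """Average per-rank mean losses, weighting each by its count of non -100 labels."""
    total_loss = sum(loss * n for loss, n in zip(losses_per_rank, good_tokens_per_rank))
    return total_loss / sum(good_tokens_per_rank)


# Example values (assumed): 1 sample, a 32768-token shard, a 128256-entry vocabulary.
# The fp32 logits would take 15.65625 GB, which rounds up to 16 shards.
num_shards = derive_num_shards(batch_size=1, seqlen=32768, vocab_size=128256)

# Rank 0 has 2 valid tokens with mean loss 2.0; rank 1 has 6 with mean loss 4.0.
weighted = token_weighted_mean([2.0, 4.0], [2, 6])  # (2*2.0 + 6*4.0) / 8
unweighted = (2.0 + 4.0) / 2
```

The weighted result (3.5) matches the mean over all 8 valid tokens, whereas the plain average of the two rank losses (3.0) would over-weight the rank with fewer valid labels. This is why the aggregation multiplies each rank's loss by its valid-token count before dividing by the total.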