Trainer
ArcticTraining provides a flexible and extensible training framework that allows you to customize and create your own training workflows. At the core of this framework is the Trainer class, which orchestrates the training process by managing the model, optimizer, data loader, and other components.
The Trainer class is designed to be modular and extensible, allowing you to quickly swap in and out different building blocks to experiment with different training strategies. Here, we’ll walk through the key features of the Trainer class and show you how to create your own custom trainers.
- class arctic_training.trainer.trainer.Trainer(config, mode='train')
  Bases: ABC, CallbackMixin
  Base Trainer class.
  Parameters:
    config (TrainerConfig)
    mode (str)
- name: str
  Name of the trainer, used for registering custom trainers. This name should be unique; it is used in the training recipe YAMLs to identify which trainer to use.
- data_factory: DataFactory
  The data factory type(s) the trainer can use, given as a type hint (typing.Union lists several). These should inherit from DataFactory. The first type listed is used as the default if the type is not explicitly set in the YAML config.
- model_factory: ModelFactory
  The model factory type(s) the trainer can use, given as a type hint (typing.Union lists several). These should inherit from ModelFactory. The first type listed is used as the default if the type is not explicitly set in the YAML config.
- checkpoint_engine: CheckpointEngine
  The checkpoint engine type(s) the trainer can use, given as a type hint (typing.Union lists several). These should inherit from CheckpointEngine. The first type listed is used as the default if the type is not explicitly set in the YAML config.
- optimizer_factory: OptimizerFactory
  The optimizer factory type(s) the trainer can use, given as a type hint (typing.Union lists several). These should inherit from OptimizerFactory. The first type listed is used as the default if the type is not explicitly set in the YAML config.
- scheduler_factory: SchedulerFactory
  The scheduler factory type(s) the trainer can use, given as a type hint (typing.Union lists several). These should inherit from SchedulerFactory. The first type listed is used as the default if the type is not explicitly set in the YAML config.
- tokenizer_factory: TokenizerFactory
  The tokenizer factory type(s) the trainer can use, given as a type hint (typing.Union lists several). These should inherit from TokenizerFactory. The first type listed is used as the default if the type is not explicitly set in the YAML config.
- callbacks: List[Tuple[str, Callable]] = [('post-loss', <function _log_loss_value>)]
  A list of callbacks for the trainer. Each callback is a tuple of a string naming the event at which it runs and a callable that implements it. Trainer callback events include the pre- and post- events for init, train, epoch, step, and checkpoint; the default callback, for example, runs at post-loss.
- config: TrainerConfig
  The config class the trainer uses. This should be a subclass of TrainerConfig that adds any trainer-specific fields.
- property model_unwrapped
  Return the original model before it was wrapped by DeepSpeed.
- property epochs: tqdm.tqdm
  Epochs iterator.
- property train_batches: tqdm.tqdm
  Training data iterator.
- property device: torch.device
  Current device.
- property training_horizon: int
  Total number of training iterations.
- abstract loss(batch)
  Loss function for the trainer. This method must be implemented by the inheriting trainer class.
  Parameters: batch (Dict[str, torch.Tensor])
  Return type: Tensor
- backward(loss)
  Backward function for the trainer. This method is called after the loss method and is responsible for backpropagating the loss through the model.
  Parameters: loss (torch.Tensor)
  Return type: None
- step(batch)
  Step function for the trainer. Each batch of training data is passed to this method.
  Parameters: batch (Dict[str, torch.Tensor])
  Return type: None
- epoch()
  Epoch training loop. This method is called for each epoch of training; it iterates across batches of training data, calling the step method on each batch.
  Return type: None
Attributes
Creating a custom trainer starts with inheriting from the base Trainer class and defining the name attribute. The name attribute is used to identify the trainer when registering it with ArcticTraining. Additionally, you can define custom types for config, data_factory, model_factory, checkpoint_engine, optimizer_factory, scheduler_factory, and tokenizer_factory to specify the default factories for each component.
Specifying the type hints for these attributes tells ArcticTraining which building blocks are compatible with your custom trainer. You may define multiple compatible building blocks by using typing.Union in the type hint. When multiple types are specified for one of these attributes, the first is used as the default when the type is not specified in the input config.
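To illustrate how the first member of a Union can act as the default, here is a minimal sketch built on the standard typing introspection helpers. The stand-in classes (DataFactory, SFTDataFactory, OtherDataFactory, MyTrainer) are hypothetical, and matching a requested type by class name is a simplification for illustration; this is not ArcticTraining's actual resolution code.

```python
from typing import Optional, Tuple, Union, get_args, get_origin, get_type_hints


class DataFactory:
    """Hypothetical stand-in for ArcticTraining's DataFactory base class."""


class SFTDataFactory(DataFactory):
    pass


class OtherDataFactory(DataFactory):
    pass


class MyTrainer:
    name = "my-trainer"
    data_factory: Union[SFTDataFactory, OtherDataFactory]


def compatible_types(cls: type, attr: str) -> Tuple[type, ...]:
    """Types allowed by the annotation on `attr`, in declaration order."""
    hint = get_type_hints(cls)[attr]
    if get_origin(hint) is Union:
        return get_args(hint)
    return (hint,)  # a single type, including Union[X], which collapses to X


def resolve_type(cls: type, attr: str, requested: Optional[str] = None) -> type:
    """Return the requested type, or the first compatible one when none is given."""
    options = compatible_types(cls, attr)
    if requested is None:
        return options[0]
    for option in options:
        if option.__name__ == requested:  # simplification: match on class name
            return option
    raise ValueError(f"{requested!r} is not compatible with {cls.__name__}.{attr}")


print(resolve_type(MyTrainer, "data_factory").__name__)  # SFTDataFactory
```

Note that Union[X] with a single member collapses to X itself, so an annotation such as Union[HFSchedulerFactory] behaves exactly like a plain type hint.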
Properties
The Trainer class provides several properties that can be used to access information about the state of the trainer at runtime. These include epochs, train_batches, device, training_horizon, and warmup_steps.
Properties should typically not be set by custom trainers, but can be used by other custom classes, like new checkpoint engines or model factories, to access information about the training process.
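The read-only pattern can be sketched in plain Python: a hypothetical StubTrainer exposes training_horizon as a property with no setter (here the value is simply passed in, since how the real Trainer derives it is outside this sketch), and a hypothetical ProgressLogger component only reads trainer state. Neither class is part of ArcticTraining.

```python
class StubTrainer:
    """Hypothetical stand-in exposing state the way Trainer properties do."""

    def __init__(self, training_horizon: int):
        self._training_horizon = training_horizon
        self.global_step = 0  # a plain attribute, updated by the training loop

    @property
    def training_horizon(self) -> int:
        """Total number of training iterations (read-only: no setter)."""
        return self._training_horizon


class ProgressLogger:
    """Hypothetical custom component that reads, but never sets, trainer state."""

    def __init__(self, trainer: StubTrainer):
        self.trainer = trainer

    def fraction_done(self) -> float:
        return self.trainer.global_step / self.trainer.training_horizon


trainer = StubTrainer(training_horizon=200)
trainer.global_step = 50
print(ProgressLogger(trainer).fraction_done())  # 0.25
```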
Methods
The Trainer class has several methods that divide the training loop into segments. At minimum, a new trainer must implement the loss() method. However, any of the train(), epoch(), step(), or checkpoint() methods can be overridden to customize the training process.
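Because loss() is the only abstract method, the smallest useful trainer overrides just that one. The toy sketch below mirrors the train() → epoch() → step() → loss()/backward() layering in plain Python so it runs without ArcticTraining; MiniTrainer and MeanSquaredTrainer are hypothetical, and a real custom trainer would subclass Trainer (or SFTTrainer) instead.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class MiniTrainer(ABC):
    """Toy stand-in for Trainer's method layering (not the real class)."""

    def __init__(self, batches: List[Dict[str, List[float]]], epochs: int = 1):
        self.batches = batches
        self.epochs = epochs
        self.losses: List[float] = []

    @abstractmethod
    def loss(self, batch: Dict[str, List[float]]) -> float:
        """Subclasses must implement this, just as with Trainer.loss()."""

    def backward(self, loss: float) -> None:
        pass  # a real trainer backpropagates the loss here

    def step(self, batch: Dict[str, List[float]]) -> None:
        loss = self.loss(batch)
        self.backward(loss)
        self.losses.append(loss)

    def epoch(self) -> None:
        for batch in self.batches:
            self.step(batch)

    def train(self) -> None:
        for _ in range(self.epochs):
            self.epoch()


class MeanSquaredTrainer(MiniTrainer):
    name = "mse-demo"  # hypothetical trainer name

    def loss(self, batch: Dict[str, List[float]]) -> float:
        preds, targets = batch["preds"], batch["targets"]
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


demo = MeanSquaredTrainer([{"preds": [1.0, 2.0], "targets": [1.0, 0.0]}], epochs=2)
demo.train()
print(demo.losses)  # [2.0, 2.0]
```

Trying to instantiate MiniTrainer directly raises TypeError, which is the same guarantee the abstract loss() gives the real base class.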
Train
@callback_wrapper("train")
def train(self) -> None:
    """
    Main training loop. Calls the epoch method for each epoch of training.
    """
    try:
        for epoch_idx in self.epochs:
            self.epoch_idx = epoch_idx
            self.epoch()
            if self.early_stop:
                break
            self.checkpoint()
        self.training_finished = True
        logger.info("Training finished.")
        self.checkpoint()
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        # logger.info(f"{self._trainer_state}")
        raise
    finally:
        if self.config.mem_profiler is not None:
            torch.cuda.memory._dump_snapshot(self.config.mem_profiler_dir / f"{self.global_rank}.pickle")
        if self.wandb_experiment is not None:
            self.wandb_experiment.finish()
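The @callback_wrapper("train") decorator on train() runs the trainer's registered callbacks around the method. Below is a minimal sketch of how such a decorator could dispatch pre- and post- callbacks, assuming callbacks are (event, callable) tuples as described for the callbacks attribute. Passing the trainer to pre- callbacks, and the trainer plus the method's return value to post- callbacks, is an assumed convention for this sketch, not necessarily ArcticTraining's actual signature; Demo is hypothetical.

```python
import functools
from typing import Any, Callable


def callback_wrapper(name: str) -> Callable:
    """Run callbacks registered for "pre-<name>" / "post-<name>" around a method.

    Simplified sketch: pre- callbacks get the trainer, post- callbacks get the
    trainer and the method's return value (an assumed convention).
    """

    def decorator(method: Callable) -> Callable:
        @functools.wraps(method)
        def wrapped(self, *args: Any, **kwargs: Any) -> Any:
            for event, fn in self.callbacks:
                if event == f"pre-{name}":
                    fn(self)
            result = method(self, *args, **kwargs)
            for event, fn in self.callbacks:
                if event == f"post-{name}":
                    fn(self, result)
            return result

        return wrapped

    return decorator


class Demo:
    """Hypothetical object with a callbacks list shaped like Trainer.callbacks."""

    def __init__(self):
        self.log = []
        self.callbacks = [
            ("pre-step", lambda trainer: trainer.log.append("pre-step")),
            ("post-loss", lambda trainer, value: trainer.log.append(f"loss={value}")),
        ]

    @callback_wrapper("loss")
    def loss(self, batch):
        return sum(batch)

    @callback_wrapper("step")
    def step(self, batch):
        self.log.append("step")
        self.loss(batch)


demo = Demo()
demo.step([1, 2, 3])
print(demo.log)  # ['pre-step', 'step', 'loss=6']
```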
Epoch
@callback_wrapper("epoch")
def epoch(self) -> None:
    """
    Epoch training loop. This method will be called for each epoch of
    training and iterates across batches of training data, calling the step
    method on each batch.
    """
    self.epoch_finished = False
    self.metrics.start_timer("iter")
    # enable memory allocation history, which will add tracebacks and event history to memory snapshots
    if self.config.mem_profiler == "step":
        torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)
    for batch in self.train_batches:
        self.train_batch_idx += 1
        if (
            self.config.gradient_accumulation_steps == 1
            or self.train_batch_idx % self.config.gradient_accumulation_steps == 0
        ):
            self.gas_boundary = True
        else:
            self.gas_boundary = False
        if "packed_sample_seqlens" in batch and self.config.model.attn_implementation == "flash_attention_2":
            # deal correctly with packed samples under FA2 by calculating each seqlen's tflops separately
            sample_seqlens = batch.pop("packed_sample_seqlens")
        else:
            sample_seqlens = [
                [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]
                for idx in range(len(batch["input_ids"]))
            ]
        self.metrics.seqlens = sample_seqlens
        self.metrics.start_timer("step")
        self.step(batch)
        self.metrics.stop_timer("step")
        self.metrics.restart_timer("iter")
        if (
            self.config.train_log_iter_interval != 0
            and self.train_batch_idx % self.config.train_log_iter_interval == 0
        ):
            self.metrics.print_summary()
            if self.global_rank == 0 and self.gas_boundary:
                metrics = {k: v for k, v in self.metrics.summary_dict.items()}
                append_json_file(self.config.train_log_metrics_path, metrics)
                # first iter is a massive outlier for many fields - so skip it in wandb
                if self.wandb_experiment is not None and self.train_batch_idx > 1:
                    metrics.pop("iter")  # not needed for wandb
                    self.wandb_experiment.log(metrics, step=self.model.global_steps)
        if self.config.kill_switch_path.exists():
            self.early_stop = True
        if self.early_stop:
            break
    self.metrics.stop_timer("iter")
    self.epoch_finished = True
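The per-batch bookkeeping in epoch() can be traced without a GPU. This sketch reproduces the gas_boundary rule for a run of batches, assuming train_batch_idx starts at 0 and is incremented before the check (so the first batch has index 1). The second helper also assumes that the rank-0 JSON write is nested under the train_log_iter_interval check; the function names are hypothetical.

```python
from typing import List


def gas_boundary_flags(num_batches: int, gradient_accumulation_steps: int) -> List[bool]:
    """gas_boundary value for each batch of a run.

    Assumes train_batch_idx starts at 0 and is incremented before the check,
    so the first batch is index 1.
    """
    return [
        gradient_accumulation_steps == 1 or idx % gradient_accumulation_steps == 0
        for idx in range(1, num_batches + 1)
    ]


def json_logged_batches(
    num_batches: int, gradient_accumulation_steps: int, train_log_iter_interval: int
) -> List[int]:
    """Batch indices at which rank 0 appends metrics to the JSON log.

    Assumes the JSON write is nested under the train_log_iter_interval check.
    """
    if train_log_iter_interval == 0:
        return []  # an interval of 0 disables the periodic summary
    flags = gas_boundary_flags(num_batches, gradient_accumulation_steps)
    return [
        idx
        for idx in range(1, num_batches + 1)
        if idx % train_log_iter_interval == 0 and flags[idx - 1]
    ]


print(gas_boundary_flags(6, 3))      # [False, False, True, False, False, True]
print(json_logged_batches(8, 4, 2))  # [4, 8]
```

Under that assumption, gradient_accumulation_steps=4 with train_log_iter_interval=3 writes no JSON rows in the first 8 batches, because neither logged batch (3 or 6) falls on an accumulation boundary; an interval that is a multiple of the accumulation steps avoids the mismatch.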
Step
@callback_wrapper("step")
def step(self, batch: Dict[str, torch.Tensor]) -> None:
    """
    Step function for the trainer. Each batch of training data is passed to
    this method.
    """
    self.model.train()
    loss = self.loss(batch)
    self.backward(loss)
    def maybe_item(v):
        return v.item() if torch.is_tensor(v) else v
    self.metrics.record("loss", maybe_item(loss))
    self.model.step()
    # use deepspeed global step as golden truth
    self.global_step = self.model.global_steps
    if self.global_step >= self.training_horizon:
        self.early_stop = True
    self.checkpoint()
    if self.config.exit_iteration > 0 and self.config.exit_iteration == self.global_step:
        self.early_stop = True
        logger.info(f"Hit exit iteration of {self.global_step}, ending training")
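step() sets early_stop under two conditions: the global step reaching training_horizon, or the global step equalling a positive exit_iteration. The pure-Python sketch below isolates that rule (the function names are hypothetical, and global_step stands in for the DeepSpeed counter that the real step() reads). Because the exit check uses equality, a run whose global step is already past exit_iteration never triggers it.

```python
from typing import Optional


def sets_early_stop(global_step: int, training_horizon: int, exit_iteration: int) -> bool:
    """True when step() would set early_stop after reaching global_step."""
    if global_step >= training_horizon:
        return True
    # exit_iteration <= 0 disables the check; note the equality comparison
    return exit_iteration > 0 and exit_iteration == global_step


def first_stopping_step(training_horizon: int, exit_iteration: int) -> Optional[int]:
    """Simulate global_step advancing by one per optimizer step."""
    for global_step in range(1, max(training_horizon, 1) + 1):
        if sets_early_stop(global_step, training_horizon, exit_iteration):
            return global_step
    return None


print(first_stopping_step(100, 0))   # 100
print(first_stopping_step(100, 30))  # 30
```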
Checkpoint
@callback_wrapper("checkpoint")
def checkpoint(self) -> None:
    for engine in self.checkpoint_engines:
        if engine.do_checkpoint:
            logger.info(f"Saving Checkpoint at global step: {self.global_step}.")
            engine.save(self.model)
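checkpoint() leaves the save decision to each engine's do_checkpoint attribute. Here is a sketch of a hypothetical interval-based engine that saves every save_every steps and once more when training finishes. IntervalCheckpointEngine, save_every, and FakeTrainer are illustrative names, not ArcticTraining APIs, and the real CheckpointEngine base class has its own constructor and configuration. Because train() calls checkpoint() at the end of each epoch and again after training finishes, while step() also calls it after every step, the sketch skips a second save for the same step.

```python
from typing import List


class IntervalCheckpointEngine:
    """Hypothetical engine: save every save_every steps and once at the end."""

    def __init__(self, trainer, save_every: int):
        self.trainer = trainer
        self.save_every = save_every
        self.saved_at: List[int] = []
        self._last_saved_step = 0  # step 0 is never saved

    @property
    def do_checkpoint(self) -> bool:
        step = self.trainer.global_step
        if step == self._last_saved_step:
            return False  # checkpoint() can fire more than once for the same step
        if self.trainer.training_finished:
            return True
        return self.save_every > 0 and step % self.save_every == 0

    def save(self, model) -> None:
        # A real engine would serialize `model` here; the sketch only records the step.
        self._last_saved_step = self.trainer.global_step
        self.saved_at.append(self.trainer.global_step)


class FakeTrainer:
    """Just enough trainer state to drive the checkpoint() loop."""

    def __init__(self):
        self.global_step = 0
        self.training_finished = False
        self.model = object()
        self.checkpoint_engines: List[IntervalCheckpointEngine] = []

    def checkpoint(self) -> None:
        for engine in self.checkpoint_engines:
            if engine.do_checkpoint:
                engine.save(self.model)


trainer = FakeTrainer()
engine = IntervalCheckpointEngine(trainer, save_every=4)
trainer.checkpoint_engines.append(engine)
for step in range(1, 11):
    trainer.global_step = step
    trainer.checkpoint()  # called from step()
trainer.checkpoint()  # e.g. end of epoch, same step: no duplicate save
trainer.training_finished = True
trainer.checkpoint()  # final save after training finishes
print(engine.saved_at)  # [4, 8, 10]
```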
Supervised Fine-Tuning (SFT) Trainer
To help you get started with creating custom trainers, ArcticTraining includes a Supervised Fine-Tuning (SFT) trainer that demonstrates how to build a training pipeline from the base building blocks. The SFT trainer can in turn be used as a starting point and extended for creating your own custom trainers.
To create the SFT trainer, we subclass Trainer and override the loss() method. We also define the necessary components described in Trainer Attributes. We use a custom data factory, SFTDataFactory, which we describe in greater detail in the Data Factory section. The remainder of the attributes use the base building blocks from ArcticTraining. For example, the model factory defaults to HFModelFactory (because it is listed first in the model_factory attribute type hint), but this trainer can work with either HFModelFactory or LigerModelFactory.
class SFTTrainer(Trainer):
    name = "sft"
    data_factory: SFTDataFactory
    model_factory: Union[HFModelFactory, LigerModelFactory]
    checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine]
    optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory]
    scheduler_factory: Union[HFSchedulerFactory]
    tokenizer_factory: Union[HFTokenizerFactory]

    def loss(self, batch) -> torch.Tensor:
        batch = to_device(batch, self.device)
        if self.config.sequence_parallel_size == 1:
            # If model.type=liger is configured, this uses a much more efficient fused
            # logits+loss Liger kernel, using significantly less GPU memory and slightly faster
            # compute (the Liger fused logits+loss kernel does not repeat forward during backward).
            outputs = self.model(**batch, use_cache=False)
            loss = outputs.loss
            return loss
        # Ulysses SP expectations:
        # 1. batch has `labels` replaced with `shift_labels` (which are already pre-shifted in
        #    the DataLoader)
        # 2. this rank handles a shard of the seqlen dimension, so once the loss is calculated it
        #    needs to do a differentiable weighted loss average to get the grads right
        if "labels" in batch:
            raise ValueError(
                "found labels in batch - they shouldn't be there, instead shift_labels should be there - check"
                " that UlyssesAttentionHFDataLoaderWrapper has been applied to the original DataLoader object"
            )
        if "shift_labels" not in batch:
            raise ValueError(
                "shift_labels are missing from the batch - check that UlyssesAttentionHFDataLoaderWrapper has been"
                " applied to the original DataLoader object"
            )
        shift_labels = batch["shift_labels"]
        # We have 2 implementations of efficient tiled logits+loss computation.
        # 1. Liger fused cross-entropy is the fastest/most memory-efficient way: liger-kernel
        #    doesn't recompute forward inside backward; instead it computes the gradients in the
        #    forward path.
        # 2. But the Liger kernel isn't implemented for all HF Transformers models, so we then fall
        #    back onto our tiled logits+loss compute implementation, which is almost as efficient
        #    memory-wise but has more compute overhead because backward re-runs forward. The
        #    total memory usage is very similar, but the CUDA cache flushes earlier than with Liger
        #    when pushing close to OOM.
        if self.config.model.type == "liger":
            # letting liger do fused logits+loss calculation
            outputs = self.model(**batch, use_cache=False)
            loss = outputs.loss
        else:
            # Currently relying on an automatic num_shards derivation based on the goal that each
            # shard holds approximately 1GB of fp32 logits; this could be made configurable later
            # if desired. Less than 1GB doesn't seem to make much of an impact, but perhaps a
            # higher number will be more efficient as it'll run fewer shards.
            num_shards = "auto"
            if num_shards == "auto":
                # parameterize to about 1GB fp32 logits shards
                slice_size_in_gb = 1  # XXX: make configurable?
                bs, seqlen = shift_labels.shape
                vocab_size = self.model_unwrapped.config.vocab_size
                logits_numel = bs * seqlen * vocab_size
                size_in_gb = logits_numel * 4 / 2**30  # fp32
                # the SP shard's seqlen may not be divisible by the derived number of chunked
                # loss shards, so we use the ceiling and allow the last chunk to be shorter
                # than the rest
                num_shards = math.ceil(size_in_gb / slice_size_in_gb)
                # print(f"derived {num_shards} shards for size {size_in_gb}GB")
            model_with_head = self.model_unwrapped
            outputs = model_with_head.model(**batch, use_cache=False)
            hidden_states = outputs.last_hidden_state
            kwargs_to_shard = dict(
                hidden_states=hidden_states,
                shift_labels=shift_labels,
            )
            kwargs_to_pass = dict(model_with_head=model_with_head, vocab_size=self.model_unwrapped.config.vocab_size)
            grad_requiring_tensor_key = "hidden_states"
            compute_params = [model_with_head.lm_head.weight]
            seqlen = shift_labels.shape[1]
            # since -100 shift_labels are ignored, we have to perform a weighted average on each
            # loss slice, as each slice may contribute a different number of non -100 labels
            def fused_logits_loss_fn(
                model_with_head=None, hidden_states=None, labels=None, shift_labels=None, vocab_size=0
            ):
                logits = model_with_head.lm_head(hidden_states)
                if all((shift_labels == -100).squeeze()):
                    # fake loss calculation: CE would return nan, but grads still need to be set;
                    # a normal loss_fn upcasts logits to float, so match it
                    loss_sum = (logits.sum() * 0.0).float()
                else:
                    good_items = sum((shift_labels != -100).squeeze())
                    loss = model_with_head.loss_function(
                        logits=logits, labels=labels, vocab_size=vocab_size, shift_labels=shift_labels
                    )
                    loss_sum = loss * good_items
                return loss_sum
            total_loss_sum = sequence_tiled_compute(
                fused_logits_loss_fn,
                seqlen,
                num_shards,
                kwargs_to_shard,
                kwargs_to_pass,
                grad_requiring_tensor_key,
                compute_params,
                output_unshard_dimension=0,  # loss is a scalar
                output_reduction="sum",
            )
            total_good_items = sum((shift_labels != -100).squeeze())
            loss = total_loss_sum / total_good_items
        # differentiable weighted per-shard-loss aggregation across ranks
        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=self.sp_group)
        good_tokens = sum((shift_labels != -100).view(-1))
        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group)
        total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size))
        total_good_tokens = sum(good_tokens_per_rank)
        loss = total_loss / total_good_tokens
        return loss
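Two pieces of arithmetic in SFTTrainer.loss() can be checked with plain numbers. The automatic shard count sizes each tile to about 1GB of fp32 logits (4 bytes per value), and the cross-rank step weights each rank's mean loss by its count of non -100 tokens, because a plain average over ranks would over-weight ranks that hold few valid tokens. The sketch below reproduces only the arithmetic with hypothetical helper names and made-up token losses; the real code uses torch.distributed.nn.functional.all_gather so that gradients flow through the aggregation.

```python
import math
from typing import Sequence


def auto_num_shards(bs: int, seqlen: int, vocab_size: int, slice_size_in_gb: float = 1) -> int:
    """Shard count so each tile holds about slice_size_in_gb of fp32 logits."""
    size_in_gb = bs * seqlen * vocab_size * 4 / 2**30  # 4 bytes per fp32 logit
    return math.ceil(size_in_gb / slice_size_in_gb)


def weighted_mean_loss(losses_per_rank: Sequence[float], good_tokens_per_rank: Sequence[int]) -> float:
    """Combine per-rank mean losses into the mean over all non -100 tokens."""
    total_loss = sum(loss * n for loss, n in zip(losses_per_rank, good_tokens_per_rank))
    return total_loss / sum(good_tokens_per_rank)


# Per-token losses on two SP ranks (made-up numbers).
rank_token_losses = [[1.0, 3.0], [4.0]]
per_rank_means = [sum(t) / len(t) for t in rank_token_losses]  # [2.0, 4.0]
good_tokens = [len(t) for t in rank_token_losses]  # [2, 1]
global_mean = sum(sum(t) for t in rank_token_losses) / sum(good_tokens)

print(auto_num_shards(1, 8192, 128256))  # 4
print(math.isclose(weighted_mean_loss(per_rank_means, good_tokens), global_mean))  # True
```

The weighted result equals the mean over every valid token (8/3 here), whereas averaging the two rank means directly would give 3.0. The same reasoning explains why fused_logits_loss_fn multiplies each tile's mean loss by its good_items before the tiles are summed.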