Trainer

ArcticTraining provides a flexible and extensible training framework that allows you to customize and create your own training workflows. At the core of this framework is the Trainer class, which orchestrates the training process by managing the model, optimizer, data loader, and other components.

The Trainer class is designed to be modular and extensible, allowing you to quickly swap in and out different building blocks to experiment with different training strategies. Here, we’ll walk through the key features of the Trainer class and show you how to create your own custom trainers.

class arctic_training.trainer.trainer.Trainer(config, mode='train')

Bases: ABC, CallbackMixin

Base Trainer class.

Parameters:
name: str

Name of the trainer, used for registering custom trainers. This name should be unique and is used in the training recipe YAMLs to identify which trainer to use.

data_factory: DataFactory

The data factory type(s) that the trainer can use, declared as a type hint. These should inherit from DataFactory. When several types are listed with typing.Union, the first is used as the default if the type is not explicitly set in the YAML config.

model_factory: ModelFactory

The model factory type(s) that the trainer can use, declared as a type hint. These should inherit from ModelFactory. When several types are listed with typing.Union, the first is used as the default if the type is not explicitly set in the YAML config.

checkpoint_engine: CheckpointEngine

The checkpoint engine type(s) that the trainer can use, declared as a type hint. These should inherit from CheckpointEngine. When several types are listed with typing.Union, the first is used as the default if the type is not explicitly set in the YAML config.

optimizer_factory: OptimizerFactory

The optimizer factory type(s) that the trainer can use, declared as a type hint. These should inherit from OptimizerFactory. When several types are listed with typing.Union, the first is used as the default if the type is not explicitly set in the YAML config.

scheduler_factory: SchedulerFactory

The scheduler factory type(s) that the trainer can use, declared as a type hint. These should inherit from SchedulerFactory. When several types are listed with typing.Union, the first is used as the default if the type is not explicitly set in the YAML config.

tokenizer_factory: TokenizerFactory

The tokenizer factory type(s) that the trainer can use, declared as a type hint. These should inherit from TokenizerFactory. When several types are listed with typing.Union, the first is used as the default if the type is not explicitly set in the YAML config.

callbacks: List[Tuple[str, Callable]] = [('post-loss', <function _log_loss_value>)]

A list of callbacks for the trainer. Each callback is a tuple of an event name, which determines when the callback runs, and a callable that implements the callback. Callback events for the trainer include pre- and post- hooks for init, train, epoch, step, and checkpoint.

config: TrainerConfig

The type of the config class that the trainer uses. This should be a subclass of TrainerConfig and add any trainer-specific fields.
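The callback mechanism can be pictured with a small stand-in. The sketch below uses a simplified, hypothetical decorator (`simple_callback_wrapper`, not ArcticTraining's `callback_wrapper`) that runs every `pre-<event>` callback, then the wrapped method, then every `post-<event>` callback. It assumes callbacks take the trainer as their only argument; the real callback signatures (for example, what a `post-loss` callback receives) may differ.

```python
import functools
from typing import Callable, List, Tuple


def simple_callback_wrapper(event: str) -> Callable:
    """Hypothetical decorator: run "pre-<event>" callbacks, the method, then "post-<event>" callbacks."""

    def decorator(method: Callable) -> Callable:
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            for when, fn in self.callbacks:
                if when == f"pre-{event}":
                    fn(self)
            result = method(self, *args, **kwargs)
            for when, fn in self.callbacks:
                if when == f"post-{event}":
                    fn(self)
            return result

        return wrapper

    return decorator


class ToyTrainer:
    # Subclasses override this class attribute, as with Trainer.callbacks.
    callbacks: List[Tuple[str, Callable]] = []

    def __init__(self) -> None:
        self.events: List[str] = []

    @simple_callback_wrapper("step")
    def step(self) -> None:
        self.events.append("step")


class LoggingTrainer(ToyTrainer):
    callbacks = [
        ("pre-step", lambda trainer: trainer.events.append("before step")),
        ("post-step", lambda trainer: trainer.events.append("after step")),
    ]


trainer = LoggingTrainer()
trainer.step()
print(trainer.events)  # ['before step', 'step', 'after step']
```

Because `callbacks` is a class attribute, a subclass can replace or extend the list without touching the decorated methods.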

property model_unwrapped

Return the original model before it was wrapped by DeepSpeed.

property epochs: tqdm.tqdm

Epochs iterator.

property train_batches: tqdm.tqdm

Training data iterator.

property eval_batches: tqdm.tqdm

Evaluation data iterator.

is_eval_log_iter()
Return type:

bool

property device: torch.device

Current device.

property training_horizon: int

Total number of training iterations.

abstract loss(batch)

Loss function for the trainer. This method should be implemented by the inheriting trainer class.

Return type:

Tensor

Parameters:

batch (Dict[str, torch.Tensor])

backward(loss)

Backward function for the trainer. This method is called after the loss method and is responsible for backpropagating the loss through the model.

Return type:

None

Parameters:

loss (torch.Tensor)

need_early_exit()

If training needs to exit early, set self.early_stop_reason and return True; otherwise, return False.

step(batch)

Step function for the trainer. Each batch of training data is passed to this method.

Return type:

None

Parameters:

batch (Dict[str, torch.Tensor])

epoch()

Epoch training loop. This method will be called for each epoch of training and iterates across batches of training data, calling the step method on each batch.

Return type:

None

train()

Main training loop. Calls the epoch method for each epoch of training.

Return type:

None

evaluate()

Evaluation loop. Measures the model’s performance on the evaluation dataset.

Return type:

None

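checkpoint()

Save a checkpoint with every checkpoint engine whose do_checkpoint flag is set. Nothing is saved if no training steps were run in the current run.

Return type: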

None

count_model_parameters()

Return a dictionary containing the “total” and “trainable” parameter counts.

count_model_params_in_original_model()

Count the total parameters in the model before it was sliced into MoE expert-parallel (EP) shards.

print_model_parameters_header()

Print statistics about the model that is about to be trained. The statistics are always printed, on rank 0 only.
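The need_early_exit() contract (set early_stop_reason, then return True) can be illustrated with a toy class. Everything here is hypothetical: `EarlyExitDemo` and its `max_steps_this_run` limit are made up for the example, and the loop only mirrors how epoch() calls need_early_exit() before each step and sets early_stop when it returns True. The reason string is phrased to follow the "Exiting training early because training ..." message printed by train().

```python
from typing import Optional


class EarlyExitDemo:
    """Toy stand-in for a trainer; only mimics the need_early_exit() contract."""

    def __init__(self, max_steps_this_run: int) -> None:
        self.max_steps_this_run = max_steps_this_run  # hypothetical stopping rule
        self.global_step_this_run = 0
        self.early_stop = False
        self.early_stop_reason: Optional[str] = None

    def need_early_exit(self) -> bool:
        # Contract: record why training must stop and return True, else return False.
        if self.global_step_this_run >= self.max_steps_this_run:
            self.early_stop_reason = f"reached {self.max_steps_this_run} steps this run"
            return True
        return False

    def run(self, num_batches: int) -> int:
        # Mirrors epoch(): the check runs before each step.
        for _ in range(num_batches):
            if self.need_early_exit():
                self.early_stop = True
                break
            self.global_step_this_run += 1
        return self.global_step_this_run


demo = EarlyExitDemo(max_steps_this_run=3)
steps = demo.run(num_batches=10)
print(steps, demo.early_stop, demo.early_stop_reason)  # 3 True reached 3 steps this run
```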

Attributes

Creating a custom trainer starts with inheriting from the base Trainer class and defining the name attribute. The name attribute is used to identify the trainer when registering it with ArcticTraining. Additionally, you can define custom types for config, data_factory, model_factory, checkpoint_engine, optimizer_factory, scheduler_factory, and tokenizer_factory to specify the default building blocks for each component.
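As an illustration of name-based registration, here is a minimal sketch that records every subclass in a dictionary keyed by its `name`, using `__init_subclass__`. This is an assumed mechanism, not necessarily how ArcticTraining registers trainers, and `BaseTrainer`, `MyTrainer`, and `trainer_from_config` are hypothetical. The `{"type": ...}` dictionary stands in for a recipe YAML that is assumed to select the trainer by name.

```python
from typing import Dict, Type


class BaseTrainer:
    name: str = ""
    registry: Dict[str, Type["BaseTrainer"]] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Every concrete trainer needs a unique, non-empty name.
        if not cls.name:
            raise ValueError(f"{cls.__name__} must define a non-empty `name`")
        if cls.name in BaseTrainer.registry:
            raise ValueError(f"trainer name {cls.name!r} is already registered")
        BaseTrainer.registry[cls.name] = cls


class MyTrainer(BaseTrainer):
    name = "my-trainer"


def trainer_from_config(config: Dict[str, str]) -> Type[BaseTrainer]:
    # Look up the trainer class by the name given in the (hypothetical) recipe.
    return BaseTrainer.registry[config["type"]]


print(trainer_from_config({"type": "my-trainer"}).__name__)  # MyTrainer
```

Registering at class-definition time means a duplicate or missing name fails as soon as the module is imported, rather than when a recipe is loaded.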

Specifying type hints for these attributes tells ArcticTraining which building blocks are compatible with your custom trainer. You may define multiple compatible building blocks by using typing.Union in the type hint. When multiple types are specified for one of these attributes, the first is used as the default when the type is not specified in the input config.
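The "first type is the default" rule can be sketched with the standard `typing.get_origin` and `typing.get_args` helpers. The stub factories and their `name` values are hypothetical, and matching a requested type by a `name` attribute is an assumption about how a config value might be resolved, not ArcticTraining's implementation. A single-member hint such as `Union[HFSchedulerFactory]` collapses to the class itself, which the sketch also handles.

```python
from typing import Optional, Tuple, Union, get_args, get_origin


class HFFactoryStub:
    name = "hf"  # hypothetical type name


class LigerFactoryStub:
    name = "liger"


class ExampleTrainer:
    model_factory: Union[HFFactoryStub, LigerFactoryStub]
    scheduler_factory: Union[HFFactoryStub]  # collapses to HFFactoryStub


def allowed_types(owner: type, attr: str) -> Tuple[type, ...]:
    hint = owner.__annotations__[attr]
    # A Union lists several compatible types; a plain class is a single type.
    return get_args(hint) if get_origin(hint) is Union else (hint,)


def resolve_type(owner: type, attr: str, requested: Optional[str] = None) -> type:
    options = allowed_types(owner, attr)
    if requested is None:
        return options[0]  # no type in the config: the first listed type is the default
    for option in options:
        if option.name == requested:
            return option
    raise ValueError(f"type {requested!r} is not allowed for {owner.__name__}.{attr}")


print(resolve_type(ExampleTrainer, "model_factory").__name__)           # HFFactoryStub
print(resolve_type(ExampleTrainer, "model_factory", "liger").__name__)  # LigerFactoryStub
print(resolve_type(ExampleTrainer, "scheduler_factory").__name__)       # HFFactoryStub
```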

Properties

The Trainer class provides several properties that can be used to access information about the state of the trainer at runtime. These include epochs, train_batches, device, training_horizon, and warmup_steps.

Properties should typically not be set by custom trainers, but can be used by other custom classes, like new checkpoint engines or model factories, to access information about the training process.

Methods

The Trainer class has several methods that divide the training loop into segments. At minimum, a new trainer must implement the abstract loss() method. However, any of the train(), epoch(), step(), or checkpoint() methods can be overridden to customize the training process.
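The template-method structure (train() calls epoch(), which calls step(), which calls loss()) can be shown with a stripped-down toy class. `MiniTrainer` and `SquaredErrorTrainer` are hypothetical and omit everything real about models, optimizers, and DeepSpeed; the point is that marking loss() with `abc.abstractmethod` forces subclasses to implement it while the loop methods are inherited unchanged.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class MiniTrainer(ABC):
    """Toy template-method trainer: train() -> epoch() -> step() -> loss()."""

    def __init__(self, batches: List[Dict[str, float]], epochs: int = 1) -> None:
        self.batches = batches
        self.epochs = epochs
        self.losses: List[float] = []

    @abstractmethod
    def loss(self, batch: Dict[str, float]) -> float:
        """Must be implemented by every concrete trainer."""

    def step(self, batch: Dict[str, float]) -> None:
        self.losses.append(self.loss(batch))

    def epoch(self) -> None:
        for batch in self.batches:
            self.step(batch)

    def train(self) -> None:
        for _ in range(self.epochs):
            self.epoch()


class SquaredErrorTrainer(MiniTrainer):
    # Only loss() is defined; train(), epoch(), and step() are inherited.
    def loss(self, batch: Dict[str, float]) -> float:
        return (batch["prediction"] - batch["target"]) ** 2


trainer = SquaredErrorTrainer(
    [{"prediction": 3.0, "target": 1.0}, {"prediction": 0.5, "target": 0.5}], epochs=2
)
trainer.train()
print(trainer.losses)  # [4.0, 0.0, 4.0, 0.0]
```

Instantiating `MiniTrainer` directly raises `TypeError`, because its abstract loss() has no implementation.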

Train

    @callback_wrapper("train")
    def train(self) -> None:
        """
        Main training loop. Calls the epoch method for each epoch of training.
        """

        self.print_model_parameters_header()

        try:

            # to be able to keep track of number of steps of this run inside step()
            self.global_step_at_start_this_run = self.global_step

            for epoch_idx in self.epochs:
                self.epoch_idx = epoch_idx
                self.epoch()
                if self.early_stop:
                    break
                self.checkpoint()
            self.training_finished = True
            if self.global_rank == 0:
                if self.early_stop:
                    print(f"*** Exiting training early because training {self.early_stop_reason}")
                else:
                    print("*** Training finished normally.")
            self.checkpoint()
        except Exception as e:
            logger.error(f"Training failed with error: {e}")
            raise
        finally:
            if self.config.mem_profiler is not None:
                torch.cuda.memory._dump_snapshot(self.config.mem_profiler_dir / f"{self.global_rank}.pickle")

            if self.wandb_experiment is not None:
                self.wandb_experiment.finish()
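The control flow of train() can be traced with a toy class that records each call. `TrainLoopTrace` and its `stop_during_epoch` knob are hypothetical; the trace only mirrors the ordering in the code above: an early stop skips that epoch's checkpoint, a final checkpoint() call follows in both cases once training_finished is set, and the `finally` cleanup always runs. Whether a given checkpoint() call actually writes anything is decided by each engine's do_checkpoint.

```python
from typing import List


class TrainLoopTrace:
    """Toy trace of train(): epochs, early stop, per-epoch and final checkpoints."""

    def __init__(self, num_epochs: int, stop_during_epoch: int = -1) -> None:
        self.num_epochs = num_epochs
        self.stop_during_epoch = stop_during_epoch  # hypothetical: epoch in which an early exit fires
        self.early_stop = False
        self.training_finished = False
        self.log: List[str] = []

    def epoch(self, epoch_idx: int) -> None:
        self.log.append(f"epoch {epoch_idx}")
        if epoch_idx == self.stop_during_epoch:
            self.early_stop = True

    def checkpoint(self) -> None:
        self.log.append("final checkpoint" if self.training_finished else "checkpoint")

    def train(self) -> None:
        try:
            for epoch_idx in range(self.num_epochs):
                self.epoch(epoch_idx)
                if self.early_stop:
                    break  # skips the end-of-epoch checkpoint
                self.checkpoint()
            self.training_finished = True
            self.checkpoint()  # reached after normal completion and after an early stop
        finally:
            self.log.append("cleanup")  # e.g. memory snapshot dump, finishing the wandb run


normal = TrainLoopTrace(num_epochs=2)
normal.train()
print(normal.log)   # ['epoch 0', 'checkpoint', 'epoch 1', 'checkpoint', 'final checkpoint', 'cleanup']

stopped = TrainLoopTrace(num_epochs=3, stop_during_epoch=1)
stopped.train()
print(stopped.log)  # ['epoch 0', 'checkpoint', 'epoch 1', 'final checkpoint', 'cleanup']
```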

Epoch

    @callback_wrapper("epoch")
    def epoch(self) -> None:
        """
        Epoch training loop. This method will be called for each epoch of
        training and iterates across batches of training data, calling the step
        method on each batch.
        """
        self.epoch_finished = False
        self.metrics.start_timer("iter")

        # enable memory allocation history, which will add tracebacks and event history to memory snapshots
        if self.config.mem_profiler == "step":
            torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

        batch_iterator = iter(self.train_batches)
        if self.is_resume:
            logger.info(f"Resumed from checkpoint at global step: {self.global_step}.")
            batches_to_skip = self.global_step % len(self.train_dataloader)
            logger.info(f"Advancing {batches_to_skip} batches.")
            for _ in range(batches_to_skip):
                next(batch_iterator)
            self.train_batch_idx += batches_to_skip
            self.is_resume = False

        for batch in batch_iterator:
            self.train_batch_idx += 1

            # Run the early exit checks before stepping to correctly deal with resume should the training not continue
            if self.need_early_exit():
                self.early_stop = True
                break

            self.gas_boundary = self.train_batch_idx % self.config.gradient_accumulation_steps == 0

            if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation:
                # deal correctly with packed samples under FA2/FA3 by calculating the TFLOPS of each seqlen separately
                sample_seqlens = batch.pop("packed_sample_seqlens")
            else:
                sample_seqlens = [
                    [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]
                    for idx in range(len(batch["input_ids"]))
                ]
            self.metrics.seqlens = sample_seqlens

            self.metrics.start_timer("step")
            self.step(batch)
            self.metrics.stop_timer("step")

            self.metrics.restart_timer("iter")

            if self.config.train_log_iter_interval != 0:
                self.metrics.print_summary()

            if self.gas_boundary:
                if (
                    self.global_rank == 0
                    and self.config.train_log_iter_interval != 0
                    and self.global_step % self.config.train_log_iter_interval == 0
                ):
                    metrics = {k: v for k, v in self.metrics.summary_dict.items()}
                    if self.ds_wall_clock_available:
                        ds_timers = self.model.get_wall_clock_timers()
                        metrics.update(ds_timers)

                    append_json_file(self.config.train_log_metrics_path, metrics)

                    # do not log the first train iteration to wandb, since it's a massive outlier
                    # on all performance metrics, which messes up the scale of the report
                    if self.wandb_experiment is not None and self.global_step > 1:
                        metrics = {k: v for k, v in metrics.items() if k not in ["iter"]}
                        self.wandb_experiment.log(metrics, step=self.global_step)

                if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0:
                    self.evaluate()

                    if self.is_eval_log_iter():
                        self.metrics.print_summary(prefix="eval")

                        if self.wandb_experiment is not None:
                            metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]}
                            self.wandb_experiment.log(metrics, step=self.global_step)

        self.metrics.stop_timer("iter")
        self.epoch_finished = True
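The resume skip and the gradient-accumulation boundary in epoch() reduce to a little integer arithmetic. The toy function below is a hypothetical simplification: a list of strings stands in for train_batches, train_batch_idx starts at 0, and the number of skipped batches is computed exactly as in the code, as `global_step % len(...)`.

```python
from typing import List, Tuple


def epoch_after_resume(
    batches: List[str], global_step: int, gradient_accumulation_steps: int
) -> List[Tuple[int, str, bool]]:
    """Toy model of epoch()'s resume skip and accumulation-boundary flag.

    Returns (train_batch_idx, batch, gas_boundary) for every batch that is stepped.
    """
    train_batch_idx = 0
    batch_iterator = iter(batches)

    # Resume: advance past the batches of this epoch consumed before the checkpoint.
    batches_to_skip = global_step % len(batches)
    for _ in range(batches_to_skip):
        next(batch_iterator)
    train_batch_idx += batches_to_skip

    stepped = []
    for batch in batch_iterator:
        train_batch_idx += 1
        gas_boundary = train_batch_idx % gradient_accumulation_steps == 0
        stepped.append((train_batch_idx, batch, gas_boundary))
    return stepped


batches = ["b0", "b1", "b2", "b3"]
print(epoch_after_resume(batches, global_step=6, gradient_accumulation_steps=1))
# [(3, 'b2', True), (4, 'b3', True)]
print(epoch_after_resume(batches, global_step=0, gradient_accumulation_steps=2))
# [(1, 'b0', False), (2, 'b1', True), (3, 'b2', False), (4, 'b3', True)]
```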

Step

    @callback_wrapper("step")
    def step(self, batch: Dict[str, torch.Tensor]) -> None:
        """
        Step function for the trainer. Each batch of training data is passed to
        this method.
        """

        self.model.train()

        with deepspeed.runtime.engine.autocast_if_enabled(self.model):
            loss = self.loss(batch)

        self.backward(loss)

        def maybe_item(v):
            return v.item() if torch.is_tensor(v) else v

        self.metrics.record("loss", maybe_item(loss))

        self.model.step()

        self.checkpoint()

        # DeepSpeed increments its global step after the step() call, so we use it as the golden truth
        self.global_step = self.model.global_steps
        self.global_step_this_run = self.global_step - self.global_step_at_start_this_run

Checkpoint

    @callback_wrapper("checkpoint")
    def checkpoint(self) -> None:
        if self.global_step_this_run == 0:
            logger.info("No steps were run this run, not saving the checkpoint")
            return

        for engine in self.checkpoint_engines:
            if engine.do_checkpoint:

                if engine.name == "huggingface" and self.use_arctic_moe:
                    if self.training_finished:
                        # export to the original moe mlp format/layout - this is slow but it's the end of the training so it's fine.
                        from arctic_training.model.moe.utils import remap_arctic_moe_to_orig_moe_mlp_params

                        logger.info("Exporting to the original MoE format before saving the checkpoint")
                        remap_arctic_moe_to_orig_moe_mlp_params(self.model)
                    else:
                        raise ValueError(
                            "Currently supporting saving to HF checkpoint for AMoE models only when the training is"
                            " finished, because conversion will be very slow. For interim checkpoints use `deepspeed`"
                            " type of the checkpoint as it'd be much faster to save to and resume from. "
                        )

                logger.info(f"Saving Checkpoint at global step: {self.global_step}.")
                engine.save(self.model)
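The decision logic in checkpoint() can be summarized in a pure-Python sketch. `EngineStub` and `checkpoint_plan` are hypothetical; the sketch returns which engine names would save instead of saving, and it leaves out the conversion to the original MoE layout that precedes a final Hugging Face save of an AMoE model.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EngineStub:
    name: str
    do_checkpoint: bool  # each real engine decides this for itself


def checkpoint_plan(
    engines: List[EngineStub], global_step_this_run: int, use_arctic_moe: bool, training_finished: bool
) -> List[str]:
    """Return the names of the engines that would save, following checkpoint()'s rules."""
    if global_step_this_run == 0:
        return []  # no steps were run this run: nothing to save
    saved = []
    for engine in engines:
        if not engine.do_checkpoint:
            continue
        if engine.name == "huggingface" and use_arctic_moe and not training_finished:
            raise ValueError("HF checkpoints of AMoE models are only supported once training is finished")
        saved.append(engine.name)
    return saved


engines = [EngineStub("deepspeed", True), EngineStub("huggingface", True)]
print(checkpoint_plan(engines, 0, use_arctic_moe=False, training_finished=False))  # []
print(checkpoint_plan(engines, 10, use_arctic_moe=True, training_finished=True))   # ['deepspeed', 'huggingface']
print(checkpoint_plan(engines[:1], 10, use_arctic_moe=True, training_finished=False))  # ['deepspeed']
```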

Supervised Fine-Tuning (SFT) Trainer

To help you get started with creating custom trainers, ArcticTraining includes a Supervised Fine-Tuning (SFT) trainer that demonstrates how to build a training pipeline from the base building blocks. The SFT trainer can in turn be used as a starting point and extended for creating your own custom trainers.

To create the SFT trainer, we subclass the Trainer and override the loss() method. We also define the necessary components described in Trainer Attributes. We use a custom data factory, SFTDataFactory, which we describe in greater detail in the Data Factory section. The remainder of the attributes use the base building blocks from ArcticTraining. For example, the model factory defaults to the HFModelFactory (because it is listed first in the model_factory attribute type hint), but this trainer can work with either HFModelFactory or LigerModelFactory.

class SFTTrainer(Trainer):
    name = "sft"
    data_factory: SFTDataFactory
    model_factory: Union[HFModelFactory, LigerModelFactory]
    checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine]
    optimizer_factory: Union[
        FusedAdamOptimizerFactory, FusedAdamMoEOptimizerFactory, CPUAdamOptimizerFactory, CPUAdamMoEOptimizerFactory
    ]
    scheduler_factory: Union[HFSchedulerFactory]
    tokenizer_factory: Union[HFTokenizerFactory]

    def loss(self, batch) -> torch.Tensor:
        batch = to_device(batch, self.device)

        if self.config.sequence_parallel_size == 1:
            # if model.type=liger is configured - this will use a much more efficient fused
            # logits+loss liger kernel - using significantly less gpu memory and a bit faster
            # compute (liger fused logits+loss kernel does not repeat forward during backward)
            outputs = self.model(**batch, use_cache=False)
            loss = outputs.loss
            return loss

        # Ulysses SP expectations:
        # 1. batch has `labels` replaced with `shift_labels` (which are already preshifted in
        #    DataLoader)
        # 2. this rank deals with a seqlen dimension shard so once the loss is calculated it needs
        #    to do a differentiable weighted loss average to get the grads right

        if "labels" in batch:
            raise ValueError(
                "found labels in batch - they shouldn't be there, instead shift_labels should be there - check"
                " that UlyssesSPDataLoaderAdapter has been applied to the original DataLoader object"
            )
        if "shift_labels" not in batch:
            raise ValueError(
                "shift_labels are missing from the batch - check that UlyssesSPDataLoaderAdapter has been"
                " applied to the original DataLoader object"
            )

        shift_labels = batch["shift_labels"]

        # We have two implementations of efficient tiled logits+loss computation.
        # 1. Liger fused cross-entropy is the fastest and most memory-efficient way: liger-kernel
        #    doesn't recompute forward inside backward; instead it computes the gradients in the
        #    forward path.
        # 2. But the Liger kernel isn't implemented for all HF Transformers models, so then we fall
        #    back onto our tiled logits+loss compute implementation, which is almost as efficient
        #    memory-wise but has more compute overhead, since backward re-runs forward. The total
        #    memory usage is very similar, but when pushing close to OOM the CUDA cache is flushed
        #    earlier than with Liger.
        if self.config.model.type == "liger":

            # letting liger do fused logits+loss calculation
            outputs = self.model(**batch, use_cache=False)
            loss = outputs.loss

            if loss is None:
                # XXX: not sure why this happens with SP>1 and eval enabled. I checked that shift_labels
                # contain valid non -100 tokens, and setting fused_linear_cross_entropy=False in
                # AutoLigerKernelForCausalLM.from_pretrained doesn't help. All works when eval is off.
                raise ValueError(
                    "Liger-Kernel failed to compute loss (returned None) - it's known to fail with eval enabled along"
                    " train steps when SP>1."
                )

        else:
            # Currently relying on an automatic num_shards derivation based on the goal that it'll
            # take approximately 1GB of fp32 logits in a shard, could make this configurable if
            # desired later. Less than 1GB doesn't seem to make much of an impact, but perhaps a
            # higher number will be more efficient as it'll run less shards.
            num_shards = "auto"
            if num_shards == "auto":
                # parameterize to about 1GB fp32 logits shards
                slice_size_in_gb = 1  # XXX: make configurable?
                bs, seqlen = shift_labels.shape
                vocab_size = self.model_unwrapped.config.vocab_size
                logits_numel = bs * seqlen * vocab_size
                size_in_gb = logits_numel * 4 / 2**30  # fp32
                # the SP shard's seqlen may not be divisible by the derived number of chunked loss
                # shards, so we round up and allow the last chunk to be shorter than the rest
                num_shards = math.ceil(size_in_gb / slice_size_in_gb)

            model_with_head = self.model_unwrapped
            outputs = model_with_head.model(**batch, use_cache=False)
            hidden_states = outputs.last_hidden_state
            compute_params = [model_with_head.lm_head.weight]
            seqlen = shift_labels.shape[1]
            mask = None
            output_reduction = "sum"

            # since -100s shift_labels are ignored we have to perform a weighted average on each
            # loss slice as each slice may contribute a different number of non- -100 labels
            def fused_logits_loss_fn(model_with_head=None, hidden_states=None, shift_labels=None):
                vocab_size = model_with_head.config.vocab_size
                logits = model_with_head.lm_head(hidden_states)
                if (shift_labels == -100).all():
                    # fake loss calculation, since CE will return nan, but grads will be set
                    # a normal loss_fn upcasts logits to float so match it
                    loss_sum = (logits.sum() * 0.0).float()
                else:
                    good_items = ((shift_labels != -100).squeeze()).sum()
                    loss = model_with_head.loss_function(
                        logits=logits, labels=None, vocab_size=vocab_size, shift_labels=shift_labels
                    )
                    loss_sum = loss * good_items
                return loss_sum

            total_loss_sum = TiledFusedLogitsLoss.apply(
                fused_logits_loss_fn,
                model_with_head,
                hidden_states,
                shift_labels,
                mask,
                num_shards,
                compute_params,
                output_reduction,
            )
            total_good_items = (shift_labels != -100).squeeze().sum()
            loss = total_loss_sum / max(total_good_items, 1)

        # differentiable weighted per-shard-loss aggregation across ranks
        losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=self.sp_group)
        good_tokens = ((shift_labels != -100).view(-1)).sum()
        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group)
        total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size))
        total_good_tokens = sum(good_tokens_per_rank)
        loss = total_loss / total_good_tokens

        return loss
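The arithmetic behind the SP loss path can be checked in isolation. The two functions below are hypothetical plain-Python restatements: `auto_num_shards` repeats the "about 1 GB of fp32 logits per shard" rule, and `token_weighted_mean` shows why per-shard and per-rank losses are weighted by their number of non -100 labels rather than averaged directly. The real code operates on tensors and uses `torch.distributed.nn.functional.all_gather` so that the combined loss stays differentiable; the vocabulary size in the example is illustrative.

```python
import math
from typing import List


def auto_num_shards(bs: int, seqlen: int, vocab_size: int, slice_size_in_gb: float = 1.0) -> int:
    """Number of logits shards so that each holds roughly slice_size_in_gb of fp32 logits."""
    logits_numel = bs * seqlen * vocab_size
    size_in_gb = logits_numel * 4 / 2**30  # 4 bytes per fp32 element
    return math.ceil(size_in_gb / slice_size_in_gb)


def token_weighted_mean(losses: List[float], good_tokens: List[int]) -> float:
    """Combine per-shard (or per-rank) mean losses into the mean over all non -100 tokens."""
    total_tokens = sum(good_tokens)
    return sum(loss * n for loss, n in zip(losses, good_tokens)) / max(total_tokens, 1)


# bs=1, a 32k-token sequence shard, and an illustrative 128,256-entry vocabulary:
print(auto_num_shards(1, 32_768, 128_256))  # 16 (about 15.66 GB of fp32 logits)

# Rank 0 saw 3 valid tokens with mean loss 2.0; rank 1 saw 1 valid token with mean loss 1.0.
print(token_weighted_mean([2.0, 1.0], [3, 1]))  # 1.75, whereas a plain mean would give 1.5
```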