Checkpoint Engine

The checkpoint engine in ArcticTraining lets you save the model during training, after training has completed, or both. To create a custom checkpoint engine, subclass the base CheckpointEngine class and implement its load() and save() methods.

class arctic_training.checkpoint.engine.CheckpointEngine(trainer, config)[source]

Bases: ABC, CallbackMixin

Base class for all checkpoint engines.

Parameters:
name: str

The name of the checkpoint engine. This is used to identify the checkpoint engine in the registry.

config: CheckpointConfig

The configuration class for the checkpoint engine. This is used to validate the configuration passed to the engine.

property trainer: Trainer
property global_rank: int
property world_size: int
property device: torch.device
property epoch_finished: bool
property training_finished: bool
property do_checkpoint: bool

Checks the current state of the trainer and determines if we are at a checkpoint boundary.

property checkpoint_dir: Path

Returns the directory where the checkpoint will be saved.

property latest_checkpoint_exists: bool

Checks if the latest checkpoint exists.

abstract load(model)[source]

Loads the model weights from a checkpoint when training is resumed.

Return type:

Any

Parameters:

model (Any)

abstract save(model)[source]

Saves the model weights to a checkpoint.

Return type:

None

Parameters:

model (Any)

Attributes

Like the *Factory classes of ArcticTraining, a CheckpointEngine subclass only needs to define the name attribute and a type hint for the config attribute. The name attribute identifies the engine when it is registered with ArcticTraining, and the config type hint is used to validate the config object passed to the engine.
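As a minimal sketch of the two required attributes, the example below declares a name and a config type hint on a subclass. To stay runnable without ArcticTraining installed, it uses stand-in classes (StandInCheckpointEngine, StandInCheckpointConfig) that only mimic the documented (trainer, config) constructor. In real code you would subclass arctic_training.checkpoint.engine.CheckpointEngine instead. The names MyCheckpointEngine, MyCheckpointConfig, output_dir, and compress are hypothetical.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass
class StandInCheckpointConfig:
    """Hypothetical stand-in for CheckpointConfig; the field is illustrative."""

    output_dir: str = "checkpoints"


class StandInCheckpointEngine(ABC):
    """Stand-in for the documented base class (not the real implementation)."""

    name: str
    config: StandInCheckpointConfig

    def __init__(self, trainer: Any, config: StandInCheckpointConfig) -> None:
        self._trainer = trainer
        self.config = config

    @abstractmethod
    def load(self, model: Any) -> Any: ...

    @abstractmethod
    def save(self, model: Any) -> None: ...


@dataclass
class MyCheckpointConfig(StandInCheckpointConfig):
    compress: bool = False  # hypothetical engine-specific option


class MyCheckpointEngine(StandInCheckpointEngine):
    # The name identifies the engine in the registry.
    name = "my-engine"
    # The type hint tells the framework which config class to validate against.
    config: MyCheckpointConfig

    def load(self, model: Any) -> Any:
        return model  # placeholder body for this sketch

    def save(self, model: Any) -> None:
        return None  # placeholder body for this sketch
```

Because the config class is carried by the type hint rather than a constructor argument, the subclass stays declarative: swapping the config schema is a one-line change.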

Properties

A CheckpointEngine exposes several properties for accessing information about the trainer and the distributed state at runtime, including device, trainer, world_size, and global_rank. The base CheckpointEngine also provides properties that are helpful when building new checkpoint engines, such as do_checkpoint (which checks whether a checkpoint should be saved at the current point in training) and checkpoint_dir (which gives the directory where the checkpoint should be saved).
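To illustrate how do_checkpoint and checkpoint_dir might be used together, here is a small sketch built on stand-in classes. The boundary rule (save every N steps and at the end of training), the save_every_n_steps option, and the global_step_<N> directory naming are all assumptions made for this example. The actual rule and layout are defined by ArcticTraining and its CheckpointConfig.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import List


@dataclass
class StandInTrainer:
    """Hypothetical trainer state; only the fields this sketch needs."""

    global_step: int = 0
    training_finished: bool = False


class StandInEngine:
    """Stand-in engine showing one plausible checkpoint-boundary rule."""

    def __init__(self, trainer: StandInTrainer, output_dir: Path, save_every_n_steps: int) -> None:
        self.trainer = trainer
        self.output_dir = output_dir
        self.save_every_n_steps = save_every_n_steps  # assumed config option

    @property
    def do_checkpoint(self) -> bool:
        # Assumed rule: always save at the end, otherwise every N steps.
        if self.trainer.training_finished:
            return True
        step = self.trainer.global_step
        return step > 0 and step % self.save_every_n_steps == 0

    @property
    def checkpoint_dir(self) -> Path:
        # Assumed layout: one sub-directory per saved step.
        return self.output_dir / f"global_step_{self.trainer.global_step}"


def steps_that_checkpoint(total_steps: int, every: int) -> List[str]:
    """Run a fake training loop and collect the directories that would be written."""
    trainer = StandInTrainer()
    engine = StandInEngine(trainer, Path("out"), every)
    saved = []
    for step in range(1, total_steps + 1):
        trainer.global_step = step
        trainer.training_finished = step == total_steps
        if engine.do_checkpoint:
            saved.append(engine.checkpoint_dir.name)
    return saved
```

With five steps and a save interval of two, this sketch would write checkpoints at steps 2 and 4, plus a final one at step 5 when training finishes.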

Methods

A CheckpointEngine has just two methods that must be defined: load() and save(). The load() method should accept an initialized model and load the model weights into it from an existing checkpoint. The save() method should save the model to the checkpoint directory.
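The load() and save() contract can be sketched with a toy engine that round-trips weights through a JSON file. This is an illustration under simplifying assumptions, not ArcticTraining's implementation: the model is a plain dict, the base class is a stand-in that only provides checkpoint_dir, and JSONCheckpointEngine and weights.json are hypothetical names. A real engine would subclass CheckpointEngine and handle actual model state.

```python
import json
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List


class StandInCheckpointEngine(ABC):
    """Stand-in base exposing only checkpoint_dir plus the two abstract methods."""

    def __init__(self, checkpoint_dir: Path) -> None:
        self._checkpoint_dir = checkpoint_dir

    @property
    def checkpoint_dir(self) -> Path:
        return self._checkpoint_dir

    @abstractmethod
    def load(self, model): ...

    @abstractmethod
    def save(self, model): ...


class JSONCheckpointEngine(StandInCheckpointEngine):
    name = "json-sketch"  # hypothetical registry name

    @property
    def weights_file(self) -> Path:
        return self.checkpoint_dir / "weights.json"

    def save(self, model: Dict[str, List[float]]) -> None:
        # Create the checkpoint directory on demand, then write the weights.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.weights_file.write_text(json.dumps(model))

    def load(self, model: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Overwrite the initialized model's weights with the saved ones.
        model.update(json.loads(self.weights_file.read_text()))
        return model


def roundtrip_demo() -> Dict[str, List[float]]:
    """Save one set of weights, then load them into a freshly initialized model."""
    with tempfile.TemporaryDirectory() as tmp:
        engine = JSONCheckpointEngine(Path(tmp) / "ckpt")
        engine.save({"layer.weight": [1.0, 2.0]})
        fresh_model = {"layer.weight": [0.0, 0.0]}
        return engine.load(fresh_model)
```

Note that load() receives an already initialized model and fills in its weights, rather than constructing a model itself, which matches the description above.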

HuggingFace and DeepSpeed Checkpoint Engines

While a custom checkpoint engine can be created from the CheckpointEngine base class, ArcticTraining includes two CheckpointEngine implementations that can be used out of the box: HFCheckpointEngine and DSCheckpointEngine.

The HFCheckpointEngine saves the model in the HuggingFace Hub style, using safetensors files. These checkpoints do not save the optimizer state and thus cannot be used to resume training. As a result, the load() method raises an error if we attempt to load a model from this style of checkpoint. This checkpoint engine is useful for saving the model at the end of training for use with inference libraries like vLLM.

The DSCheckpointEngine uses the checkpoint capabilities of the DeepSpeed library. This style of checkpoint saves the optimizer state and can be used to resume training. This checkpoint engine is useful for saving training progress during the training loop.