Checkpoint Engine
The checkpoint engine in ArcticTraining allows you to save the model in the
middle of training and/or after training has completed. Custom checkpoint
engines can be created by subclassing the base CheckpointEngine class and
implementing its load() and save() methods.
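As a rough illustration of that contract, the sketch below uses a stand-in abstract base class (CheckpointEngineSketch, a hypothetical name) instead of importing arctic_training, so it runs without the library installed. It assumes only what is stated here: the base class is an ABC, it is constructed with a trainer and a config, and subclasses must implement load() and save(). MyCheckpointEngine and IncompleteEngine are illustrative names, not part of ArcticTraining.

```python
from abc import ABC, abstractmethod


# Hypothetical stand-in for arctic_training.checkpoint.engine.CheckpointEngine;
# the real class also mixes in CallbackMixin and provides many more properties.
class CheckpointEngineSketch(ABC):
    name: str

    def __init__(self, trainer, config):
        self.trainer = trainer
        self.config = config

    @abstractmethod
    def load(self, model):
        """Load weights from an existing checkpoint into an initialized model."""

    @abstractmethod
    def save(self, model):
        """Write the model to the checkpoint directory."""


class MyCheckpointEngine(CheckpointEngineSketch):
    name = "my-engine"  # hypothetical registry name

    def load(self, model):
        return model  # a real engine would read weights from disk here

    def save(self, model):
        pass  # a real engine would write weights to disk here


class IncompleteEngine(CheckpointEngineSketch):
    name = "incomplete"

    def load(self, model):
        return model
    # save() is missing, so this class is still abstract.


engine = MyCheckpointEngine(trainer=None, config=None)
try:
    IncompleteEngine(trainer=None, config=None)
except TypeError as err:
    print("cannot instantiate:", err)
```

Because the base is an ABC, forgetting either method is caught when the engine is instantiated rather than later in the training loop.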
- class arctic_training.checkpoint.engine.CheckpointEngine(trainer, config)
Bases: ABC, CallbackMixin
Base class for all checkpoint engines.
- Parameters:
trainer (Trainer)
config (CheckpointConfig)
- name: str
The name of the checkpoint engine. This is used to identify the checkpoint engine in the registry.
- config: CheckpointConfig
The configuration class for the checkpoint engine. This is used to validate the configuration passed to the engine.
- property global_rank: int
- property world_size: int
- property device: torch.device
- property epoch_finished: bool
- property training_finished: bool
- property do_checkpoint: bool
Checks the current state of the trainer and determines if we are at a checkpoint boundary.
- property checkpoint_dir: Path
Returns the directory where the checkpoint will be saved.
- property latest_checkpoint_exists: bool
Checks if the latest checkpoint exists.
Attributes
Similar to the *Factory classes of ArcticTraining, the CheckpointEngine
class requires only that the name attribute be defined and that the
config attribute be given a type hint. The name attribute is
used to identify the engine when registering it with ArcticTraining, and the
config type hint is used to validate the config object passed to
the engine.
Properties
A CheckpointEngine has several attributes that can be used to access information
about the trainer and distributed state at runtime, including
device, trainer,
world_size, and
global_rank. Additionally, the base
CheckpointEngine provides properties that are helpful when
building new checkpoint engines, such as
do_checkpoint (which checks whether the trainer is at a
checkpoint boundary) and checkpoint_dir (which specifies the
directory where the checkpoint should be saved).
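To show how do_checkpoint and checkpoint_dir can fit into save(), here is a standalone sketch. The rule it uses (save every N steps, or when training finishes) and the global_step_<N> directory layout are assumptions made for illustration; this section does not say how the base class actually computes either property. FakeTrainer and IntervalCheckpointSketch are hypothetical stand-ins.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass
class FakeTrainer:
    # Hypothetical trainer state; the real Trainer exposes far more.
    global_step: int
    training_finished: bool = False


class IntervalCheckpointSketch:
    """Hypothetical do_checkpoint / checkpoint_dir logic, not ArcticTraining's."""

    def __init__(self, trainer, output_dir, save_every_n_steps, global_rank=0):
        self.trainer = trainer
        self.output_dir = Path(output_dir)
        self.save_every_n_steps = save_every_n_steps
        self.global_rank = global_rank

    @property
    def do_checkpoint(self) -> bool:
        # Assumed rule: always save at the end of training, otherwise every N steps.
        if self.trainer.training_finished:
            return True
        n = self.save_every_n_steps
        step = self.trainer.global_step
        return n > 0 and step > 0 and step % n == 0

    @property
    def checkpoint_dir(self) -> Path:
        # Assumed layout: one subdirectory per saved step.
        return self.output_dir / f"global_step_{self.trainer.global_step}"

    def save(self, model) -> None:
        if not self.do_checkpoint:
            return  # not at a checkpoint boundary
        if self.global_rank == 0:
            # Only rank 0 creates the directory in this sketch.
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
            # ... write model files into self.checkpoint_dir ...


trainer = FakeTrainer(global_step=10)
engine = IntervalCheckpointSketch(trainer, "checkpoints", save_every_n_steps=5)
print(engine.do_checkpoint, engine.checkpoint_dir)
```

Guarding save() with do_checkpoint keeps the engine safe to call on every step, with the boundary logic living in one property.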
Methods
A CheckpointEngine subclass must define just two methods:
load() and save(). The
load() method should accept an initialized model and
load the model weights from an existing checkpoint. The
save() method should save the model to a checkpoint
directory.
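A minimal round trip through the load()/save() contract might look like the following. The sketch swaps in a toy dict-of-lists model and JSON files so it runs with only the standard library; a real engine would work with torch modules and a proper tensor format. ToyModel, JsonCheckpointEngine, and the model.json filename are all hypothetical.

```python
import json
import tempfile
from pathlib import Path


class ToyModel:
    """Hypothetical stand-in for a model: a dict of named weight lists."""

    def __init__(self):
        self.weights = {"w": [0.0, 0.0], "b": [0.0]}

    def state_dict(self):
        return {k: list(v) for k, v in self.weights.items()}

    def load_state_dict(self, state):
        if state.keys() != self.weights.keys():
            raise KeyError("checkpoint keys do not match model")
        self.weights = {k: list(v) for k, v in state.items()}


class JsonCheckpointEngine:
    """Sketch of the load()/save() contract; not an ArcticTraining class."""

    name = "json"  # hypothetical

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = Path(checkpoint_dir)

    def save(self, model) -> None:
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        with open(self.checkpoint_dir / "model.json", "w") as f:
            json.dump(model.state_dict(), f)

    def load(self, model):
        # The model is already initialized; only its weights are replaced.
        with open(self.checkpoint_dir / "model.json") as f:
            model.load_state_dict(json.load(f))
        return model


with tempfile.TemporaryDirectory() as tmp:
    engine = JsonCheckpointEngine(Path(tmp) / "ckpt")
    trained = ToyModel()
    trained.weights["w"] = [1.5, -2.0]
    engine.save(trained)
    restored = engine.load(ToyModel())
    print(restored.weights)  # {'w': [1.5, -2.0], 'b': [0.0]}
```

Having load() fill an existing model, rather than construct one, leaves model construction (architecture, device placement) to the trainer.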
HuggingFace and DeepSpeed Checkpoint Engines
While a custom checkpoint engine can be created by subclassing
CheckpointEngine, ArcticTraining includes two CheckpointEngine
implementations that can be used out of the box: HFCheckpointEngine
and DSCheckpointEngine.
The HFCheckpointEngine saves the model in the HuggingFace Hub format
using safetensors files. These checkpoints do not save the optimizer state
and therefore cannot be used to resume training. As a
result, the load() method raises an error if you
attempt to load a model from this type of checkpoint. This checkpoint engine is
useful for saving the model at the end of training for use with inference
libraries such as vLLM.
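The sketch below mimics that behavior with a hypothetical weights-only engine: save() writes the weights and nothing else, and load() refuses to run because there is no optimizer state to restore. It does not use safetensors or any HuggingFace API; WeightsOnlyEngineSketch and weights.json are illustrative names, and the point is only why a weights-only format cannot support resuming.

```python
import json
import tempfile
from pathlib import Path


class WeightsOnlyEngineSketch:
    """Hypothetical analogue of an inference-export engine: weights only."""

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = Path(checkpoint_dir)

    def save(self, weights: dict) -> None:
        # Only the weights are written; optimizer state and the step counter
        # are discarded, so training cannot be resumed from these files.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        (self.checkpoint_dir / "weights.json").write_text(json.dumps(weights))

    def load(self, model):
        raise NotImplementedError(
            "weights-only checkpoints cannot be used to resume training"
        )


with tempfile.TemporaryDirectory() as tmp:
    engine = WeightsOnlyEngineSketch(Path(tmp) / "final")
    engine.save({"w": [1.0, 2.0]})
    saved = (Path(tmp) / "final" / "weights.json").exists()
    try:
        engine.load(model=None)
    except NotImplementedError as err:
        print("load refused:", err)
```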
The DSCheckpointEngine uses the checkpointing capabilities of
the DeepSpeed library. These checkpoints save the optimizer state and
can be used to resume training, which makes this checkpoint engine useful for
saving training progress during the training loop.
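For contrast, a resumable checkpoint has to capture more than the weights. DeepSpeed's engine exposes save_checkpoint() and load_checkpoint() for this purpose; the sketch below does not call DeepSpeed. It only illustrates, with hypothetical names (ResumableEngineSketch, state.json) and JSON files, the minimum state needed to pick training back up: weights, optimizer state, and the step counter.

```python
import json
import tempfile
from pathlib import Path


class ResumableEngineSketch:
    """Hypothetical resumable checkpoint: weights + optimizer state + step."""

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = Path(checkpoint_dir)

    def save(self, weights, optimizer_state, global_step):
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        payload = {
            "weights": weights,
            "optimizer": optimizer_state,  # e.g. momentum buffers
            "global_step": global_step,
        }
        (self.checkpoint_dir / "state.json").write_text(json.dumps(payload))

    def load(self):
        payload = json.loads((self.checkpoint_dir / "state.json").read_text())
        return payload["weights"], payload["optimizer"], payload["global_step"]


with tempfile.TemporaryDirectory() as tmp:
    engine = ResumableEngineSketch(Path(tmp) / "global_step_100")
    engine.save({"w": [0.5]}, {"momentum": {"w": [0.1]}}, global_step=100)
    weights, opt_state, step = engine.load()
    print(step, opt_state)  # 100 {'momentum': {'w': [0.1]}}
```

Restoring the step counter matters as much as the optimizer state: without it, step-based schedules and checkpoint intervals would restart from zero.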