Source code for arctic_training.checkpoint.engine

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any

import torch

from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.config.checkpoint import CheckpointConfig
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method

if TYPE_CHECKING:
    from arctic_training.trainer.trainer import Trainer


class CheckpointEngine(ABC, CallbackMixin, metaclass=RegistryMeta):
    """Abstract base that every checkpoint engine derives from.

    Subclasses must set ``name``, declare a ``config`` type derived from
    :class:`CheckpointConfig`, and implement :meth:`load` and :meth:`save`.
    """

    name: str
    """
    Identifier of the checkpoint engine; it is the key under which the engine
    is looked up in the registry.
    """

    config: CheckpointConfig
    """
    Configuration class associated with the checkpoint engine; the
    configuration handed to the engine is validated against it.
    """

    @classmethod
    def _validate_subclass(cls) -> None:
        """Check that a subclass declares the attributes and methods required of an engine."""
        _validate_class_attribute_set(cls, "name")
        _validate_class_attribute_type(cls, "config", CheckpointConfig)
        # ``load`` is validated before ``save``; both must accept (self, model).
        for method_name in ("load", "save"):
            _validate_class_method(cls, method_name, ["self", "model"])

    def __init__(self, trainer: "Trainer", config: CheckpointConfig) -> None:
        """Store the owning trainer and this engine's configuration."""
        self._trainer = trainer
        self.config = config

    @property
    def trainer(self) -> "Trainer":
        """The trainer this engine was constructed with."""
        return self._trainer

    @property
    def global_rank(self) -> int:
        """``global_rank`` as reported by the trainer."""
        return self.trainer.global_rank

    @property
    def world_size(self) -> int:
        """``world_size`` as reported by the trainer."""
        return self.trainer.world_size

    @property
    def device(self) -> torch.device:
        """``device`` as reported by the trainer."""
        return self.trainer.device

    @property
    def epoch_finished(self) -> bool:
        """``epoch_finished`` as reported by the trainer."""
        return self.trainer.epoch_finished

    @property
    def training_finished(self) -> bool:
        """``training_finished`` as reported by the trainer."""
        return self.trainer.training_finished

    @property
    def do_checkpoint(self) -> bool:
        """
        Inspect the trainer state and report whether a checkpoint is due now.

        Nothing is due when the engine is disabled in its config. Otherwise the
        triggers are consulted in this order, and the first one that fires
        decides the result:

        1. ``save_every_n_steps`` - the model is at a gradient accumulation
           boundary and ``global_step`` is a multiple of the interval.
        2. ``save_every_n_epochs`` - the epoch has finished and ``epoch_idx``
           is a multiple of the interval.
        3. ``save_end_of_training`` - training has finished.
        """
        cfg = self.config
        if not cfg.enabled:
            return False

        trainer = self.trainer
        due = False

        # Step-based trigger.
        if trainer.model.is_gradient_accumulation_boundary() and cfg.save_every_n_steps:
            due = (trainer.global_step % cfg.save_every_n_steps) == 0

        # Epoch-based trigger, only relevant while nothing has fired yet.
        if not due and cfg.save_every_n_epochs:
            due = self.epoch_finished and (trainer.epoch_idx % cfg.save_every_n_epochs) == 0

        # End-of-training trigger, again only if still undecided.
        if not due and cfg.save_end_of_training:
            due = self.training_finished

        return due

    @property
    def checkpoint_dir(self) -> Path:
        """Directory for the checkpoint of the current global step (created on access)."""
        step_folder = f"global_step_{self.trainer.global_step}"
        target = self.config.output_dir / step_folder
        # Accessing the property guarantees the directory exists on disk.
        target.mkdir(parents=True, exist_ok=True)
        return target

    @property
    def latest_checkpoint_exists(self) -> bool:
        """Report whether the latest checkpoint exists; this base version always raises."""
        raise NotImplementedError("latest_checkpoint_exists method must be implemented")

    @abstractmethod
    @callback_wrapper("load")
    def load(self, model: Any) -> Any:
        """Restore model weights from a checkpoint when a training run is resumed."""
        raise NotImplementedError("load method must be implemented")

    @abstractmethod
    @callback_wrapper("save")
    def save(self, model: Any) -> None:
        """Write the model weights out to a checkpoint."""
        raise NotImplementedError("save method must be implemented")