Source code for arctic_training.optimizer.factory

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import Any
from typing import Optional

from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.config.optimizer import OptimizerConfig
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method

if TYPE_CHECKING:
    from arctic_training.trainer.trainer import Trainer


[docs] class OptimizerFactory(ABC, CallbackMixin, metaclass=RegistryMeta): """Base class for optimizer creation.""" name: str """ Name of the optimizer factory used for registering custom optimizer factories. This name should be unique and is used in training recipe YAMLs to identify which optimizer factory to be used. """ config: OptimizerConfig """ The type of config class that the optimizer factory uses. This should contain all optimizer-specific parameters. """ @classmethod def _validate_subclass(cls) -> None: _validate_class_attribute_set(cls, "name") _validate_class_attribute_type(cls, "config", OptimizerConfig) _validate_class_method(cls, "create_optimizer", ["self", "model", "optimizer_config"]) def __init__( self, trainer: "Trainer", optimizer_config: Optional[OptimizerConfig] = None, ) -> None: if optimizer_config is None: optimizer_config = trainer.config.optimizer self._trainer = trainer self.config = optimizer_config def __call__(self) -> Any: optimizer = self.create_optimizer(self.trainer.model, self.config) return optimizer @property def trainer(self) -> "Trainer": return self._trainer @property def device(self) -> str: return self.trainer.device @property def model(self) -> Any: return self.trainer.model @property def world_size(self) -> int: return self.trainer.world_size @property def global_rank(self) -> int: return self.trainer.global_rank
[docs] @abstractmethod @callback_wrapper("create-optimizer") def create_optimizer(self, model: Any, optimizer_config: "OptimizerConfig") -> Any: """Creates the optimizer given a model and an optimizer config.""" pass