Source code for arctic_training.scheduler.factory

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import Any
from typing import Optional

from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.config.scheduler import SchedulerConfig
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method

if TYPE_CHECKING:
    from arctic_training.trainer.trainer import Trainer


class SchedulerFactory(ABC, CallbackMixin, metaclass=RegistryMeta):
    """Abstract base that every scheduler factory derives from.

    Concrete subclasses must define ``name`` and ``config`` and implement
    :meth:`create_scheduler`. Calling an instance builds a scheduler for the
    trainer's optimizer.
    """

    name: str
    """
    Identifier of this scheduler factory; the registry uses it to look the
    factory up.
    """

    config: SchedulerConfig
    """
    Configuration class of this scheduler factory; used to validate the
    configuration the factory receives.
    """

    @classmethod
    def _validate_subclass(cls) -> None:
        # Subclass contract checks. NOTE(review): presumably invoked by the
        # registry machinery when a subclass is defined — confirm the call site.
        # The order is kept deliberately so the first failing check is stable.
        _validate_class_attribute_set(cls, "name")
        _validate_class_attribute_type(cls, "config", SchedulerConfig)
        _validate_class_method(cls, "create_scheduler", ["self", "optimizer"])

    def __init__(
        self,
        trainer: "Trainer",
        scheduler_config: Optional[SchedulerConfig] = None,
    ) -> None:
        """Bind the factory to ``trainer``.

        Args:
            trainer: The trainer whose optimizer / device / ranks are exposed.
            scheduler_config: Explicit configuration; when omitted (``None``),
                ``trainer.config.scheduler`` is used instead.
        """
        # Resolve the config before touching ``self`` so evaluation order
        # matches a plain "fallback, then assign" sequence.
        resolved_config = (
            trainer.config.scheduler if scheduler_config is None else scheduler_config
        )
        self._trainer = trainer
        self.config = resolved_config

    def __call__(self) -> Any:
        """Build and return a scheduler for the trainer's optimizer."""
        return self.create_scheduler(optimizer=self.optimizer)

    @property
    def trainer(self) -> "Trainer":
        """The trainer this factory was constructed with."""
        return self._trainer

    @property
    def device(self) -> str:
        """Device reported by the bound trainer."""
        return self.trainer.device

    @property
    def world_size(self) -> int:
        """World size reported by the bound trainer."""
        return self.trainer.world_size

    @property
    def global_rank(self) -> int:
        """Global rank reported by the bound trainer."""
        return self.trainer.global_rank

    @property
    def optimizer(self) -> Any:
        """The trainer's optimizer; this is what ``__call__`` hands to ``create_scheduler``."""
        return self.trainer.optimizer

    @abstractmethod
    @callback_wrapper("create-scheduler")
    def create_scheduler(self, optimizer: Any) -> Any:
        """Build a scheduler for ``optimizer``; concrete factories must override this."""
        raise NotImplementedError