# Source code for arctic_training.model.factory

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import Any
from typing import Optional

from transformers import PreTrainedModel

from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.config.model import ModelConfig
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method

if TYPE_CHECKING:
    from arctic_training.trainer.trainer import Trainer


class ModelFactory(ABC, CallbackMixin, metaclass=RegistryMeta):
    """Abstract base from which model factories are derived.

    Calling an instance runs ``create_config`` and hands its result to
    ``create_model``; whatever ``create_model`` returns is the result of the
    call. Subclasses supply both methods.
    """

    name: str
    """
    Identifier under which a custom model factory is registered. It should be
    unique; training recipe YAMLs use it to pick the model factory.
    """

    config: ModelConfig
    """
    Config class type used by the model factory. It should hold every
    model-specific parameter.
    """

    @classmethod
    def _validate_subclass(cls) -> None:
        """Run the registry validation helpers against this class.

        Covers the ``name`` attribute, the ``config`` attribute (checked
        against ``ModelConfig``), and the ``create_config`` / ``create_model``
        methods together with the name lists given below.
        """
        _validate_class_attribute_set(cls, "name")
        _validate_class_attribute_type(cls, "config", ModelConfig)

        # Same two method checks as before, in the same order.
        required_methods = (
            ("create_config", ["self"]),
            ("create_model", ["self", "model_config"]),
        )
        for method_name, expected_names in required_methods:
            _validate_class_method(cls, method_name, expected_names)

    def __init__(
        self,
        trainer: "Trainer",
        model_config: Optional[ModelConfig] = None,
    ) -> None:
        """Bind the factory to a trainer and a model config.

        Args:
            trainer: Trainer this factory belongs to.
            model_config: Config to store on the factory. When ``None``,
                ``trainer.config.model`` is used instead.
        """
        # Resolve the config before touching ``self`` so a failed lookup
        # leaves no attribute assigned.
        chosen_config = trainer.config.model if model_config is None else model_config

        self._trainer = trainer
        self.config = chosen_config

    def __call__(self) -> PreTrainedModel:
        """Create the config, then create and return the model from it."""
        return self.create_model(model_config=self.create_config())

    @property
    def trainer(self) -> "Trainer":
        """Trainer supplied at construction time."""
        return self._trainer

    @property
    def device(self) -> str:
        """The ``device`` value of the trainer."""
        return self.trainer.device

    @property
    def world_size(self) -> int:
        """The ``world_size`` value of the trainer."""
        return self.trainer.world_size

    @property
    def global_rank(self) -> int:
        """The ``global_rank`` value of the trainer."""
        return self.trainer.global_rank

    @abstractmethod
    @callback_wrapper("create-config")
    def create_config(self) -> Any:
        """Build the model config (for example a huggingface model config).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("create_config method must be implemented")

    @abstractmethod
    @callback_wrapper("create-model")
    def create_model(self, model_config) -> PreTrainedModel:
        """Build the model from ``model_config``.

        Args:
            model_config: The value produced by ``create_config``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("create_model method must be implemented")