Model Factory

The Model Factory is responsible for generating the model used in training from the ModelConfig. A Model Factory can be created by inheriting from the ModelFactory class and implementing the create_config() and create_model() methods.

class arctic_training.model.factory.ModelFactory(trainer, model_config=None)

Bases: ABC, CallbackMixin

Base class for model creation.

Parameters:
name: str

Name of the model factory, used when registering custom model factories. This name should be unique; training recipe YAMLs use it to identify which model factory to use.

config: ModelConfig

The type of config class that the model factory uses. This should contain all model-specific parameters.

property trainer: Trainer
property device: str
property world_size: int
property global_rank: int
abstract create_config()

Creates the model config (e.g., a HuggingFace model config).

Return type:

Any

abstract create_model(model_config)

Creates the model.

Return type:

PreTrainedModel

Attributes

Like other Factory classes in ArcticTraining, a ModelFactory subclass must define a name attribute, which identifies the factory when it is registered with ArcticTraining, and a config attribute type hint, which is used to validate the config object passed to the factory.
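One way a `config` type hint can drive validation is sketched below: the base class reads the concrete subclass's `config` annotation with `typing.get_type_hints` and checks the incoming object against it. The dataclasses stand in for the real config classes, and `ModelFactoryBase`, `MyModelConfig`, `name_or_path`, and `hidden_size` are hypothetical names; this is not ArcticTraining's actual validation code.

```python
from dataclasses import dataclass
from typing import get_type_hints


@dataclass
class ModelConfig:
    # Hypothetical base config; the real ModelConfig fields may differ.
    name_or_path: str


@dataclass
class MyModelConfig(ModelConfig):
    # Hypothetical model-specific parameter.
    hidden_size: int = 1024


class ModelFactoryBase:
    name: str
    config: ModelConfig

    def __init__(self, model_config):
        # get_type_hints merges annotations along the MRO, so the most
        # derived class's `config` annotation wins.
        expected = get_type_hints(type(self))["config"]
        if not isinstance(model_config, expected):
            raise TypeError(
                f"{type(self).__name__} expects {expected.__name__}, "
                f"got {type(model_config).__name__}"
            )
        self.config = model_config


class MyModelFactory(ModelFactoryBase):
    name = "my-model-factory"  # hypothetical factory name
    config: MyModelConfig  # narrows the accepted config type


factory = MyModelFactory(MyModelConfig(name_or_path="org/model"))
print(factory.config.hidden_size)  # 1024

try:
    MyModelFactory(ModelConfig(name_or_path="org/model"))
except TypeError as err:
    print(err)  # MyModelFactory expects MyModelConfig, got ModelConfig
```

Because the check reads the annotation rather than a separate class attribute, a subclass only has to narrow the type hint to tighten validation.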

Properties

A ModelFactory exposes several properties that give access to the trainer and to the distributed state at runtime: trainer, device, world_size, and global_rank.
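A minimal sketch of how such properties can be exposed, assuming they simply forward to the trainer's runtime state. The source does not say where the real values come from, so `FakeTrainer` and the forwarding itself are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class FakeTrainer:
    # Hypothetical stand-in for the Trainer's runtime/distributed state.
    device: str = "cpu"
    world_size: int = 1
    global_rank: int = 0


class ModelFactoryBase:
    def __init__(self, trainer):
        self._trainer = trainer

    @property
    def trainer(self):
        return self._trainer

    # The distributed-state properties forward to the trainer, so a
    # factory never has to query the process group itself.
    @property
    def device(self) -> str:
        return self.trainer.device

    @property
    def world_size(self) -> int:
        return self.trainer.world_size

    @property
    def global_rank(self) -> int:
        return self.trainer.global_rank


factory = ModelFactoryBase(FakeTrainer(device="cuda:1", world_size=8, global_rank=1))
print(factory.device, factory.world_size, factory.global_rank)  # cuda:1 8 1
```

Read-only properties keep a factory from accidentally mutating trainer state; a factory might, for example, check `global_rank == 0` before logging so that a message appears once per job.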

Methods

A ModelFactory subclass must define just two methods: create_config() and create_model(). create_config() should return a config object from which the desired model can be built, and create_model() should return the model built from that config, which it receives as its model_config argument.
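The sketch below implements the two-method contract on a hypothetical `ModelFactoryBase` stand-in (no `trainer` argument, to stay self-contained) and uses toy `TinyConfig`/`TinyModel` classes in place of a HuggingFace config and `PreTrainedModel`. The `__call__` that runs `create_config()` and then `create_model()` is an assumption about how the two methods are chained, not ArcticTraining's confirmed behavior.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


class ModelFactoryBase(ABC):
    """Hypothetical stand-in mirroring ModelFactory's two abstract methods."""

    @abstractmethod
    def create_config(self) -> Any:
        """Return a config object describing the model to build."""

    @abstractmethod
    def create_model(self, model_config: Any) -> Any:
        """Build and return a model from the config returned by create_config()."""

    def __call__(self) -> Any:
        # Assumed orchestration: build the config first, then the model from it.
        model_config = self.create_config()
        return self.create_model(model_config)


@dataclass
class TinyConfig:
    # Hypothetical model config (a real factory might return a HF config object).
    vocab_size: int
    hidden_size: int


class TinyModel:
    # Toy stand-in for a PreTrainedModel.
    def __init__(self, config: TinyConfig):
        self.config = config
        # One weight row per vocabulary entry, like an embedding table.
        self.embeddings = [[0.0] * config.hidden_size for _ in range(config.vocab_size)]


class TinyModelFactory(ModelFactoryBase):
    name = "tiny-model"  # hypothetical registration name

    def create_config(self) -> TinyConfig:
        return TinyConfig(vocab_size=16, hidden_size=4)

    def create_model(self, model_config: TinyConfig) -> TinyModel:
        return TinyModel(model_config)


factory = TinyModelFactory()
model = factory()
print(len(model.embeddings), len(model.embeddings[0]))  # 16 4
```

Splitting config creation from model creation lets a subclass override only one step, for example changing how the config is built while reusing the base model-loading logic.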

HuggingFace Style Factories

A custom model factory can be built on the ModelFactory base class, but ArcticTraining also ships two ModelFactory implementations that can be used out of the box: HFModelFactory and LigerModelFactory. Both load HuggingFace models given either a model name on the HuggingFace Hub or a path to a local repo. LigerModelFactory extends HFModelFactory and adds support for the optimized kernels in the Liger-Kernel library.
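For intuition, here is a hedged sketch of what a HuggingFace-style factory and a Liger-style subclass could look like. It is not the source of HFModelFactory or LigerModelFactory: the class names, the `name` values, and the plain `name_or_path` constructor argument are hypothetical. It uses `transformers.AutoConfig.from_pretrained`, `AutoModelForCausalLM.from_pretrained`, and Liger-Kernel's `AutoLigerKernelForCausalLM`, imported lazily so the module loads even when those packages are absent.

```python
from typing import Any


class HFStyleModelFactory:
    """Hypothetical HuggingFace-style factory (not ArcticTraining's HFModelFactory)."""

    name = "hf-style"  # hypothetical registration name

    def __init__(self, name_or_path: str):
        # A Hub model ID (e.g. "org/model") or a path to a local repo.
        self.name_or_path = name_or_path

    def create_config(self) -> Any:
        from transformers import AutoConfig  # lazy import of an optional dependency

        return AutoConfig.from_pretrained(self.name_or_path)

    def create_model(self, model_config: Any) -> Any:
        from transformers import AutoModelForCausalLM

        return AutoModelForCausalLM.from_pretrained(self.name_or_path, config=model_config)


class LigerStyleModelFactory(HFStyleModelFactory):
    """Reuses the config step; loads the model through Liger-Kernel instead."""

    name = "liger-style"  # hypothetical registration name

    def create_model(self, model_config: Any) -> Any:
        from liger_kernel.transformers import AutoLigerKernelForCausalLM

        # AutoLigerKernelForCausalLM reads the config from name_or_path itself,
        # so this sketch does not forward model_config.
        return AutoLigerKernelForCausalLM.from_pretrained(self.name_or_path)
```

In this design the Liger variant overrides only `create_model()` and inherits `create_config()` unchanged, which mirrors the documented relationship of LigerModelFactory extending HFModelFactory.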