Model Factory
The Model Factory is responsible for generating the model used in training from
the ModelConfig. A Model Factory can be created by inheriting from the
ModelFactory class and implementing the
create_config() and create_model()
methods.
- class arctic_training.model.factory.ModelFactory(trainer, model_config=None)
Bases: ABC, CallbackMixin
Base class for model creation.
- Parameters:
  - trainer (Trainer)
  - model_config (ModelConfig | None)
- name: str
  Name of the model factory, used for registering custom model factories. This name should be unique and is used in training recipe YAMLs to identify which model factory to use.
- config: ModelConfig
  The type of config class that the model factory uses. This should contain all model-specific parameters.
- property device: str
- property world_size: int
- property global_rank: int
Attributes
Similar to other Factory classes in ArcticTraining, the ModelFactory class must
have a name attribute that is used to identify the
factory when registering it with ArcticTraining and a
config attribute type hint that is used to validate the
config object passed to the factory.
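The roles of the two attributes can be illustrated with a generic sketch. This is not ArcticTraining's registration code, which this page does not show. FactoryBase, FACTORY_REGISTRY, ToyModelConfig, and ToyFactory are all hypothetical names, and the sketch only assumes that factories are looked up by name and that the config type hint is checked against the object passed in.

```python
from dataclasses import dataclass
from typing import Dict, get_type_hints


@dataclass
class ToyModelConfig:
    """Hypothetical config type standing in for a ModelConfig subclass."""

    hidden_size: int = 16


# Hypothetical name -> class registry (not ArcticTraining's real registry).
FACTORY_REGISTRY: Dict[str, type] = {}


class FactoryBase:
    """Hypothetical base class that registers each subclass under its `name`."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        name = cls.__dict__.get("name")
        if not isinstance(name, str):
            raise TypeError(f"{cls.__name__} must define a string `name`")
        if name in FACTORY_REGISTRY:
            raise ValueError(f"factory name {name!r} is already registered")
        FACTORY_REGISTRY[name] = cls

    def __init__(self, config):
        # Check the passed object against the class-level `config` type hint.
        expected = get_type_hints(type(self))["config"]
        if not isinstance(config, expected):
            raise TypeError(
                f"expected {expected.__name__}, got {type(config).__name__}"
            )
        self.config = config


class ToyFactory(FactoryBase):
    name = "toy"  # unique key, as a recipe would refer to it
    config: ToyModelConfig  # type hint used for validation


# Look the factory up by name, then build it with a matching config object.
factory = FACTORY_REGISTRY["toy"](ToyModelConfig(hidden_size=32))
```

In this sketch a duplicate name fails when the class is defined and a config of the wrong type fails when the factory is constructed, which is why the name must be unique and the type hint must match the config class.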
Properties
A ModelFactory exposes several properties that give access to information
about the trainer and the distributed state at runtime, including
device, trainer,
world_size, and global_rank.
Methods
A ModelFactory subclass must define just two methods:
create_config() and create_model().
The create_config() method should return a config object
that can be used to generate the desired model, and the
create_model() method should return the model object
created from that config.
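The two-method contract described above can be sketched as follows. So that the example runs without ArcticTraining installed, it defines a local stand-in for ModelFactory that mirrors only the interface documented on this page. ToyModelConfig and ToyModelFactory are hypothetical, and passing the generated config to create_model() as an argument is an assumption, not something this page confirms.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyModelConfig:
    """Hypothetical stand-in for a ModelConfig subclass."""

    hidden_size: int = 8
    num_layers: int = 2


class ModelFactory(ABC):
    """Local stand-in for arctic_training.model.factory.ModelFactory."""

    name: str
    config: Any

    def __init__(self, trainer: Any, model_config: Optional[Any] = None) -> None:
        self.trainer = trainer
        self.config = model_config

    @abstractmethod
    def create_config(self) -> Any:
        ...

    @abstractmethod
    def create_model(self, model_config: Any) -> Any:
        # Assumption: the generated config is passed in as an argument.
        ...


class ToyModelFactory(ModelFactory):
    name = "toy"
    config: ToyModelConfig

    def create_config(self) -> dict:
        # Translate the training-side config into the model's own config.
        return {
            "hidden_size": self.config.hidden_size,
            "num_layers": self.config.num_layers,
        }

    def create_model(self, model_config: dict) -> list:
        # A list of layer shapes stands in for a real model object.
        shape = (model_config["hidden_size"], model_config["hidden_size"])
        return [shape] * model_config["num_layers"]


factory = ToyModelFactory(
    trainer=None,
    model_config=ToyModelConfig(hidden_size=4, num_layers=3),
)
model = factory.create_model(factory.create_config())
print(model)
```

With the real library, the subclass would inherit from arctic_training.model.factory.ModelFactory instead of the stand-in, and create_model() would return an actual model rather than a list of shapes.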
HuggingFace Style Factories
A custom model factory can be built on the ModelFactory building
block, but ArcticTraining also ships with two ModelFactory implementations that
can be used out of the box: HFModelFactory and
LigerModelFactory. Each of these loads a model given either a HuggingFace
Hub model name or a path to a local repo.
LigerModelFactory extends HFModelFactory and adds support
for the optimizations in the Liger-Kernel
library.