Optimizer Factory
The OptimizerFactory creates the optimizer used in training from the
model produced by the
arctic_training.model.factory.ModelFactory. A custom optimizer factory
can be created by inheriting from the OptimizerFactory class and
implementing the create_optimizer() method.
- class arctic_training.optimizer.factory.OptimizerFactory(trainer, optimizer_config=None)
Bases: ABC, CallbackMixin
Base class for optimizer creation.
- Parameters:
trainer (Trainer)
optimizer_config (OptimizerConfig | None)
- name: str
Name of the optimizer factory, used for registering custom optimizer factories. This name should be unique; training recipe YAMLs use it to identify which optimizer factory to use.
- config: OptimizerConfig
The type of config class that the optimizer factory uses. This should contain all optimizer-specific parameters.
- property device: str
- property model: Any
- property world_size: int
- property global_rank: int
- abstract create_optimizer(model, optimizer_config)
Creates the optimizer given a model and an optimizer config.
- Parameters:
model (Any)
optimizer_config (OptimizerConfig)
- Return type:
Any
Attributes
Similar to other Factory classes in ArcticTraining, the
OptimizerFactory class must have a name
attribute that is used to identify the factory when registering it with
ArcticTraining and a config attribute type hint that
is used to validate the config object passed to the factory.
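As a minimal sketch of these two attributes, the snippet below declares a factory with a name and a config type hint. The OptimizerConfig and OptimizerFactory classes here are simplified stand-ins defined locally so that the example runs without ArcticTraining installed; in real code they would come from the library. SGDConfig, SGDFactory, the "my-sgd" name, and the learning_rate and momentum fields are hypothetical, and the runtime lookup at the end only illustrates one way a type hint can be read back, not how ArcticTraining performs its validation.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, get_type_hints


# Stand-in for ArcticTraining's OptimizerConfig (assumption: the real class
# holds the optimizer-specific parameters and is validated by the library).
@dataclass
class OptimizerConfig:
    learning_rate: float = 1e-3  # hypothetical field name


# Stand-in for arctic_training.optimizer.factory.OptimizerFactory.
class OptimizerFactory(ABC):
    name: str
    config: OptimizerConfig

    @abstractmethod
    def create_optimizer(self, model: Any, optimizer_config: OptimizerConfig) -> Any:
        ...


# Hypothetical config adding one optimizer-specific parameter.
@dataclass
class SGDConfig(OptimizerConfig):
    momentum: float = 0.9


class SGDFactory(OptimizerFactory):
    name = "my-sgd"    # unique identifier that a recipe would refer to
    config: SGDConfig  # type hint naming the config class this factory expects

    def create_optimizer(self, model, optimizer_config):
        raise NotImplementedError  # implementation omitted in this sketch


# The type hint can be read back at runtime, which is one way a framework
# could check that the config object it was handed has the expected type.
expected_config_type = get_type_hints(SGDFactory)["config"]
```

Because the subclass re-annotates config, the lookup resolves to SGDConfig rather than the base OptimizerConfig.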
Properties
OptimizerFactory has several properties that can be used to access
information about the trainer and distributed state at runtime, including
device, trainer, model, world_size, and global_rank.
Methods
The OptimizerFactory has just one method that must be defined:
create_optimizer(). Given a model and optimizer config,
the method should return the optimizer.
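The shape of such a method can be sketched as follows. To keep the example runnable without ArcticTraining or PyTorch, OptimizerConfig and OptimizerFactory are local stand-ins (the stand-in factory omits the trainer argument that the real constructor takes), and ToySGD, ToySGDFactory, the "toy-sgd" name, and the learning_rate field are hypothetical. It is an illustration of the contract under those assumptions, not the library's implementation.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class OptimizerConfig:  # stand-in for ArcticTraining's OptimizerConfig
    learning_rate: float = 0.1  # hypothetical field name


class OptimizerFactory(ABC):  # stand-in exposing only the abstract method
    @abstractmethod
    def create_optimizer(self, model: Any, optimizer_config: OptimizerConfig) -> Any:
        """Return an optimizer for `model`, configured by `optimizer_config`."""


class ToySGD:
    """Hypothetical pure-Python optimizer over a dict of scalar parameters."""

    def __init__(self, params: Dict[str, float], lr: float) -> None:
        self.params = params
        self.lr = lr

    def step(self, grads: Dict[str, float]) -> None:
        # Plain gradient descent: p <- p - lr * grad
        for key, grad in grads.items():
            self.params[key] -= self.lr * grad


class ToySGDFactory(OptimizerFactory):
    name = "toy-sgd"  # hypothetical registration name

    def create_optimizer(
        self, model: Dict[str, float], optimizer_config: OptimizerConfig
    ) -> ToySGD:
        # Build the optimizer from the model's parameters and the config values.
        return ToySGD(params=model, lr=optimizer_config.learning_rate)


model = {"w": 1.0}  # toy "model": a single scalar weight
optimizer = ToySGDFactory().create_optimizer(model, OptimizerConfig(learning_rate=0.5))
optimizer.step({"w": 2.0})  # w <- 1.0 - 0.5 * 2.0
```

With a PyTorch model, the same method would typically build an optimizer from model.parameters() and the values in the config instead of the toy class used here.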
Adam Optimizer Factory
As an example of how to create a new OptimizerFactory, we provide the
arctic_training.optimizer.factory.FusedAdamOptimizerFactory, which
returns the FusedAdam optimizer from the DeepSpeed library.