# Source code for arctic_training.trainer.trainer

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from functools import cached_property
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import deepspeed
import numpy as np
import torch
import torch.cuda
import torch.distributed.nn
import wandb
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter
from devtools import debug
from tqdm import tqdm
from transformers import set_seed
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from wandb.sdk.wandb_run import Run as WandbRun

from arctic_training.callback.logging import post_loss_log_cb
from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.checkpoint.engine import CheckpointEngine
from arctic_training.config.trainer import TrainerConfig
from arctic_training.data.factory import DataFactory
from arctic_training.data.utils import OverfitOneBatchDataLoader
from arctic_training.debug.utils import pr0
from arctic_training.debug.utils import see_memory_usage
from arctic_training.logging import logger
from arctic_training.metrics import Metrics
from arctic_training.model.factory import ModelFactory
from arctic_training.model.moe.utils import amoe_install_deepspeed_timers
from arctic_training.model.moe.utils import detect_if_moe_model
from arctic_training.model.tiled_compute import enable_tiled_mlp_compute
from arctic_training.optimizer.factory import OptimizerFactory
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method
from arctic_training.scheduler.factory import SchedulerFactory
from arctic_training.tokenizer.factory import TokenizerFactory
from arctic_training.utils import append_json_file


class Trainer(ABC, CallbackMixin, metaclass=RegistryMeta):
    """Base Trainer class.

    Concrete trainers set the class attributes below (validated in ``_validate_subclass``)
    and implement ``loss``. The base class wires together tokenizer, data, model,
    optimizer, scheduler and checkpoint factories, wraps the model with DeepSpeed, and
    runs the epoch/step training loop.
    """

    name: str
    """
    Name of the trainer used for registering custom trainers. This name
    should be unique and is used in the training recipe YAMLs to identify which
    trainer to be used.
    """

    config: TrainerConfig
    """
    The type of the config class that the trainer uses. This should be a
    subclass of TrainerConfig and add any trainer-specific fields.
    """

    data_factory: DataFactory
    """
    A List of valid data factory types that the trainer can use. These should
    inherit from DataFactory. The first item in the list will be used as the
    default if the type is not explicitly set in the YAML config.
    """

    model_factory: ModelFactory
    """
    A List of valid model factory types that the trainer can use. These should
    inherit from ModelFactory. The first item in the list will be used as the
    default if the type is not explicitly set in the YAML config.
    """

    checkpoint_engine: CheckpointEngine
    """
    A List of valid checkpoint engine types that the trainer can use. These
    should inherit from CheckpointEngine. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    optimizer_factory: OptimizerFactory
    """
    A List of valid optimizer factory types that the trainer can use. These
    should inherit from OptimizerFactory. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    scheduler_factory: SchedulerFactory
    """
    A List of valid scheduler factory types that the trainer can use. These
    should inherit from SchedulerFactory. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    tokenizer_factory: TokenizerFactory
    """
    A List of valid tokenizer factory types that the trainer can use. These
    should inherit from TokenizerFactory. The first item in the list will be
    used as the default if the type is not explicitly set in the YAML config.
    """

    callbacks: List[Tuple[str, Callable]] = [
        post_loss_log_cb,
    ]
    """
    A list of callbacks for the trainer. Callbacks are specified as tuples of a
    string indicating where the callback should be placed and a callable that
    implements the callback. Callback events for the trainer include `pre-` and
    `post-` for `init`, `train`, `epoch`, `step`, and `checkpoint`. The `loss`,
    `backward`, and `evaluate` methods are also wrapped with `callback_wrapper`.
    """

    @classmethod
    def _validate_subclass(cls) -> None:
        """Check that a subclass defines the required attributes and method signatures."""
        _validate_class_attribute_set(cls, "name")
        _validate_class_attribute_type(cls, "config", TrainerConfig)
        _validate_class_attribute_type(cls, "data_factory", DataFactory)
        _validate_class_attribute_type(cls, "model_factory", ModelFactory)
        _validate_class_attribute_type(cls, "checkpoint_engine", CheckpointEngine)
        _validate_class_attribute_type(cls, "optimizer_factory", OptimizerFactory)
        _validate_class_attribute_type(cls, "scheduler_factory", SchedulerFactory)
        _validate_class_attribute_type(cls, "tokenizer_factory", TokenizerFactory)
        _validate_class_method(cls, "loss", ["self", "batch"])
        _validate_class_method(cls, "step", ["self", "batch"])
        _validate_class_method(cls, "epoch", ["self"])
        _validate_class_method(cls, "train", ["self"])
        _validate_class_method(cls, "checkpoint", ["self"])

    def __init__(self, config: TrainerConfig, mode: str = "train") -> None:
        """Build every training component described by ``config``.

        Args:
            config: The trainer configuration.
            mode: When ``"process-data"``, initialization stops right after the tokenizer
                and the train/eval dataloaders are created (no model, optimizer or
                DeepSpeed engine is built).
        """
        logger.info(f"Initializing Trainer with config:\n{debug.format(config)}")
        self.config = config
        self.epoch_idx = 0
        self.train_batch_idx = 0
        self.global_step = 0
        self.global_step_this_run = 0
        self.global_step_at_start_this_run = 0
        self.early_stop = False
        self.early_stop_reason = ""
        self.world_size = config.world_size
        self.global_rank = config.global_rank
        self.epoch_finished = False
        self.training_finished = False
        self.wandb_experiment: Optional[WandbRun] = None
        self.is_resume = False  # Track if we resumed from ckpt
        self.wandb_run_id = None

        self._set_seeds(self.config.seed)

        if self.config.mem_profiler == "e2e":
            torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

        tokenizer_factory = self.config.tokenizer.factory(self)
        self.tokenizer = tokenizer_factory()

        data_factory = self.config.data.factory(self)
        self.train_dataloader, self.eval_dataloader = data_factory()
        if mode == "process-data":
            return

        if self.config.overfit_first_batch:
            self.train_dataloader = OverfitOneBatchDataLoader(self.train_dataloader)

        # checkpointing and resume
        self.checkpoint_engines = [engine(self) for engine in self.config.checkpoint_engines]
        for engine in self.checkpoint_engines:
            # currently only deepspeed engine supports resume from intermediate checkpoint
            if engine.name == "deepspeed" and engine.config.auto_resume and engine.latest_checkpoint_exists:
                self.is_resume = True

        # XXX: We can abstract this section further with AT-specific wrapper, but
        # UlyssesSPAttentionHF should not have any AT-specific objects / assumptions
        mpu = UlyssesSPAttentionHF.register_with_transformers(
            model_name_or_path=self.config.model.name_or_path,
            core_attn_implementation=self.config.model.attn_implementation,
            sequence_parallel_size=self.config.sequence_parallel_size,
            micro_batch_size=self.config.micro_batch_size,
            seq_length=self.config.data.max_length,
            seq_length_is_variable=True,
        )

        # Important: this is most likely not beneficial under seqlen=64k
        if self.config.activation_checkpoint_cpu_offload:
            # activation_checkpointing_cpu_offload becomes very beneficial at very long seqlen
            # e.g., llama 8b at 800k (100k effective per gpu) will save 24GB per gpu:
            # ((100_000*4096)*2*32/2**30), but for short sequences the offload will just slow things
            # down,
            #
            # XXX: could parameterize or run a few lengths to see at which threshold it becomes
            # beneficial - a user might still want this on even at shorter seqlen if they don't
            # mind slower performance. discussing adding this functionality to pytorch core
            # (https://pytorch.slack.com/archives/C3PDTEV8E/p1745274102600729)
            from arctic_training.monkey_patches import monkey_patch_checkpoint_function_with_cpu_offload

            monkey_patch_checkpoint_function_with_cpu_offload()

        # MLP tiling - has to happen before model is instantiated
        if self.config.tiled_mlp_compute:
            enable_tiled_mlp_compute(self.config.model.name_or_path)

        # Keep a reference alive while the model is created so HF sees the DeepSpeed config.
        dschf = HfDeepSpeedConfig(self.config.deepspeed)  # noqa: F841
        model_factory = self.config.model.factory(self)
        self.model = model_factory()

        self.count_model_params_in_original_model()

        # prevent causal mask from being created in HF Transformers - it's a huge `[bs, seqlen, seqlen]` tensor
        # XXX: This should also benefit a single gpu use case when SDPA is used - so perhaps remove the SP>1 check?
        if self.config.sequence_parallel_size > 1 and self.config.model.attn_implementation not in [
            "flash_attention_2",
            "flash_attention_3",
        ]:
            import transformers.masking_utils

            transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", lambda *args, **kwargs: None)

        # Arctic MoE model remapping has to be called before an optimizer is created
        if self.config.arctic_moe == "auto":
            self.use_arctic_moe = detect_if_moe_model(self.model)
        else:
            self.use_arctic_moe = self.config.arctic_moe

        if self.use_arctic_moe:
            pr0("Activating ArcticMoE", force=False)
            import deepspeed.comm as dist
            from deepspeed.utils import groups

            from arctic_training.model.moe.utils import monkey_patch_ds_moe
            from arctic_training.model.moe.utils import remap_orig_moe_mlp_params_to_arctic_moe

            if not dist.is_initialized():
                dist.init_distributed(dist_backend="nccl", dist_init_required=True)

            monkey_patch_ds_moe()
            # deepspeed.runtime.engine.DeepSpeedEngine.print_forward_breakdown = print_forward_breakdown

            # DeepspeedMoE is only integrated with ZeRO-2
            zero_stage = self.config.deepspeed.get("zero_optimization", {}).get("stage", 0)
            if zero_stage != 2:
                raise ValueError(
                    "at the moment Deepspeed supports only ZeRO stage 2 with MoE, but the configuration asks for ZeRO"
                    f" stage={zero_stage}"
                )

            # this config comes from use_data_before_expert_parallelism ds config which defaults to False
            # (engine._config.use_data_before_expert_parallel_), but we don't have the engine yet to get
            # the ds config values - perhaps could extract this via AT-config?
            use_data_before_expert_parallel_ = False
            # the ep group has to be created before remap_orig_moe_mlp_params_to_arctic_moe as ep rank
            # info is needed to remap pre-trained experts
            groups._create_expert_data_and_model_parallel(
                self.config.expert_parallel_size,
                mpu=None,
                use_data_before_expert_parallel_=use_data_before_expert_parallel_,
            )

            # The real resume happens much later (engine.load), which is too late for the remap, so we
            # rely on `self.is_resume`, computed above from whether the latest checkpoint exists.
            remap_orig_moe_mlp_params_to_arctic_moe(
                self.model,
                ep_size=self.config.expert_parallel_size,
                is_resume=self.is_resume,
                enable_custom_moe_kernel=self.config.enable_arctic_moe_custom_optimization,
                enable_routing_replay=self.config.enable_routing_replay,
            )

            # XXX: check we can remap back
            # from arctic_training.model.moe.utils import remap_arctic_moe_params_to_orig_moe_mlp
            # remap_arctic_moe_params_to_orig_moe_mlp(self.model)

            see_memory_usage("after moe remap", force=False)

        # this is an optional debug instrumentation to trace overflows in params/grads
        #
        # inspectors are important to call after all model tweaks are done (e.g. after AMoE)
        #
        # from arctic_training.debug.underflow_overflow import DebugUnderflowOverflow
        # debug_overflow = DebugUnderflowOverflow(self.model, max_frames_to_save=100)  # noqa
        #
        # from arctic_training.debug.underflow_overflow import DebugGradients
        # debug_grads = DebugGradients(self.model, trace_batch_nums=[1], max_frames_to_save=100)  # noqa

        optimizer_factory = self.config.optimizer.factory(self)
        self.optimizer = optimizer_factory()

        scheduler_factory = self.config.scheduler.factory(self)
        self.scheduler = scheduler_factory()

        self.model, *_ = deepspeed.initialize(
            model=self.model,
            optimizer=self.optimizer,
            args=self.config,
            lr_scheduler=self.scheduler,
            config=self.config.deepspeed,
            mpu=mpu,
        )

        if self.use_arctic_moe:
            amoe_install_deepspeed_timers(self.model, self.model_unwrapped)

        self.ds_wall_clock_available = hasattr(self.model, "get_wall_clock_timers")

        if self.config.sequence_parallel_size > 1:
            # deepspeed.initialize needs to run first
            from deepspeed.utils import groups

            # set SP-trainer attributes to be used later
            self.sp_group = groups._get_sequence_parallel_group()
            self.sp_world_size = groups._get_sequence_parallel_world_size()
            self.sp_rank = groups._get_sequence_parallel_rank()

            # wrap the DL with Ulysses one
            self.train_dataloader = UlyssesSPDataLoaderAdapter(
                self.train_dataloader,
                sp_rank=self.sp_rank,
                sp_group=self.sp_group,
                sp_world_size=self.sp_world_size,
                device=self.device,
            )
            if self.eval_dataloader is not None:
                self.eval_dataloader = UlyssesSPDataLoaderAdapter(
                    self.eval_dataloader,
                    sp_rank=self.sp_rank,
                    sp_group=self.sp_group,
                    sp_world_size=self.sp_world_size,
                    device=self.device,
                )

        for engine in self.checkpoint_engines:
            if engine.config.auto_resume:
                engine.load(self.model)

        self.metrics = Metrics(self)

        if self.global_rank == 0 and self.config.wandb.enable:
            # in order for resume to continue the same wandb run we need to re-use a run_id from the previous run
            if self.wandb_run_id is None:
                self.wandb_run_id = wandb.util.generate_id()  # type: ignore
            # Note: wandb.init() is not type annotated so we need to use type: ignore
            self.wandb_experiment = wandb.init(  # type: ignore
                id=self.wandb_run_id,
                entity=self.config.wandb.entity,
                project=self.config.wandb.project,
                name=self.config.wandb.name,
                config=self.config.model_dump(),
                # do not put `wandb` in the root of the repo as it conflicts with wandb package
                dir=f"{self.config.logger.output_dir}/wandb",
            )

    def _set_seeds(self, seed: int) -> None:
        """Seed torch, numpy, python `random` and transformers with the same value."""
        logger.info(f"Setting random seeds to {seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        set_seed(seed)

    @property
    def model_unwrapped(self):
        """Return the original model before it was wrapped by deepspeed"""
        if hasattr(self.model, "module"):
            return self.model.module
        else:
            return self.model

    @property
    def epochs(self) -> tqdm:
        """Epochs iterator, starting at ``self.epoch_idx``.

        When ``train_iters`` is set, the epoch count is derived from it: each optimizer step
        consumes ``gradient_accumulation_steps`` micro-batches of the train dataloader.
        """
        total_epochs = self.config.epochs
        if self.config.train_iters:
            total_epochs = math.ceil(
                self.config.train_iters * self.config.gradient_accumulation_steps / len(self.train_dataloader)
            )

        return tqdm(
            range(self.epoch_idx, total_epochs),
            desc="Epochs",
            unit="epoch",
            disable=(self.global_rank != 0) or (self.config.train_log_iter_interval != 0),
        )

    @property
    def train_batches(self) -> tqdm:
        """Training data iterator."""
        return tqdm(
            self.train_dataloader,
            desc="Train Batches",
            unit="batch",
            disable=(self.global_rank != 0) or (self.config.train_log_iter_interval != 0),
        )

    @property
    def eval_batches(self) -> tqdm:
        """Evaluation data iterator."""
        return tqdm(
            self.eval_dataloader,
            desc="Eval Batches",
            unit="batch",
            disable=self.global_rank != 0 or not self.is_eval_log_iter(),
        )

    def is_eval_log_iter(self) -> bool:
        """Return True if evaluation results for the current global step should be logged.

        A zero (disabled) ``eval_interval`` or ``eval_log_iter_interval`` means "never log",
        instead of raising ZeroDivisionError.
        """
        if not self.config.eval_interval or not self.config.eval_log_iter_interval:
            return False
        return self.global_step // self.config.eval_interval % self.config.eval_log_iter_interval == 0

    @cached_property
    def device(self) -> torch.device:
        """Current device."""
        return torch.device(get_accelerator().device_name(self.config.local_rank))

    @property
    def training_horizon(self) -> int:
        """Total number of training iterations."""
        if self.train_dataloader is None:
            raise ValueError("Train dataloader not initialized.")
        if self.config.train_iters:
            return self.config.train_iters
        # XXX: this was incorrect for GAS
        return self.config.epochs * len(self.train_dataloader)  # // self.config.gradient_accumulation_steps

    @callback_wrapper("loss")
    @abstractmethod
    def loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Loss function for the trainer. This method should be implemented by the
        inheriting trainer class.
        """
        raise NotImplementedError("Loss method must be implemented by the trainer.")

    @callback_wrapper("backward")
    def backward(self, loss: torch.Tensor) -> None:
        """
        Backward function for the trainer. This method is called after the loss
        method and is responsible for backpropagating the loss through the model.
        """
        self.model.backward(loss)

    def need_early_exit(self) -> bool:
        """
        If we need to exit early, set `self.early_stop_reason` and return True.
        Otherwise return False.
        """
        # exit conditions in the order of likelihood
        if (
            self.config.exit_iteration_this_run > 0
            and self.config.exit_iteration_this_run == self.global_step_this_run
        ):
            self.early_stop_reason = f"reached exit_iteration_this_run of {self.global_step_this_run}"
            return True
        elif self.config.exit_iteration > 0 and self.config.exit_iteration == self.global_step:
            self.early_stop_reason = f"reached exit_iteration of {self.global_step}"
            return True
        elif self.config.kill_switch_path.exists():
            self.early_stop_reason = f"detected kill switch {self.config.kill_switch_path}"
            return True
        elif self.global_step >= self.training_horizon:
            self.early_stop_reason = f"reached training_horizon of {self.global_step}"
            return True

        return False

    @callback_wrapper("step")
    def step(self, batch: Dict[str, torch.Tensor]) -> None:
        """
        Step function for the trainer. Each batch of training data is passed to
        this method.
        """
        self.model.train()
        with deepspeed.runtime.engine.autocast_if_enabled(self.model):
            loss = self.loss(batch)

        self.backward(loss)

        def maybe_item(v):
            return v.item() if torch.is_tensor(v) else v

        self.metrics.record("loss", maybe_item(loss))

        # if needing to debug AMoE-EP grads
        # from deepspeed.utils import safe_get_full_grad
        #
        # if hasattr(self.model_unwrapped.model.layers[1].mlp, "router_gate"):
        #     pr0("------------------------->8------------- grads ------------->8----------",
        #         force=True)
        #     pr0(
        #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].mlp.router_gate))=}",
        #         force=True,
        #     )
        #     pr0(
        #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].mlp.expert_gate_up))=}",
        #         force=True,
        #     )
        #     pr0(
        #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].mlp.expert_down))=}",
        #         force=True,
        #     )
        #     pr0(
        #         f"grad: {torch.norm(safe_get_full_grad(self.model_unwrapped.model.layers[1].post_attention_layernorm.weight))=}",
        #         force=True,
        #     )
        #     pr0("------------------------->8------------- grads end --------->8----------",
        #         force=True)
        #     exit()

        self.model.step()

        self.checkpoint()

        # DeepSpeed increments its global step after the step() call, so we use it as the golden truth
        self.global_step = self.model.global_steps
        self.global_step_this_run = self.global_step - self.global_step_at_start_this_run

    @callback_wrapper("epoch")
    def epoch(self) -> None:
        """
        Epoch training loop. This method will be called for each epoch of
        training and iterates across batches of training data, calling the step
        method on each batch.
        """
        self.epoch_finished = False
        self.metrics.start_timer("iter")

        # enable memory allocation history, which will add tracebacks and event history to memory snapshots
        if self.config.mem_profiler == "step":
            torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

        batch_iterator = iter(self.train_batches)
        if self.is_resume:
            logger.info(f"Resumed from checkpoint at global step: {self.global_step}.")
            # global_step counts optimizer steps and each optimizer step consumes
            # `gradient_accumulation_steps` micro-batches (the same accounting the `epochs`
            # property uses), so convert to micro-batches before skipping within this epoch.
            consumed_batches = self.global_step * self.config.gradient_accumulation_steps
            batches_to_skip = consumed_batches % len(self.train_dataloader)
            logger.info(f"Advancing {batches_to_skip} batches.")
            for _ in range(batches_to_skip):
                next(batch_iterator)
            self.train_batch_idx += batches_to_skip
            self.is_resume = False

        for batch in batch_iterator:
            self.train_batch_idx += 1

            # Run the early exit checks before stepping to correctly deal with resume should the training not continue
            if self.need_early_exit():
                self.early_stop = True
                break

            self.gas_boundary = self.train_batch_idx % self.config.gradient_accumulation_steps == 0

            if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation:
                # deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflops separately
                sample_seqlens = batch.pop("packed_sample_seqlens")
            else:
                sample_seqlens = [
                    [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]
                    for idx in range(len(batch["input_ids"]))
                ]
            self.metrics.seqlens = sample_seqlens

            self.metrics.start_timer("step")
            self.step(batch)
            self.metrics.stop_timer("step")

            self.metrics.restart_timer("iter")

            if self.config.train_log_iter_interval != 0:
                self.metrics.print_summary()

            if self.gas_boundary:
                if (
                    self.global_rank == 0
                    and self.config.train_log_iter_interval != 0
                    and self.global_step % self.config.train_log_iter_interval == 0
                ):
                    metrics = {k: v for k, v in self.metrics.summary_dict.items()}
                    if self.ds_wall_clock_available:
                        ds_timers = self.model.get_wall_clock_timers()
                        metrics.update(ds_timers)

                    append_json_file(self.config.train_log_metrics_path, metrics)

                    # do not log the first train iteration to wandb, since it's a massive outlier
                    # on all performance metrics, which messes up the scale of the report
                    if self.wandb_experiment is not None and self.global_step > 1:
                        metrics = {k: v for k, v in metrics.items() if k not in ["iter"]}
                        self.wandb_experiment.log(metrics, step=self.global_step)

                if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0:
                    self.evaluate()

                    if self.is_eval_log_iter():
                        self.metrics.print_summary(prefix="eval")

                        if self.wandb_experiment is not None:
                            metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]}
                            self.wandb_experiment.log(metrics, step=self.global_step)

        self.metrics.stop_timer("iter")
        self.epoch_finished = True

    @callback_wrapper("train")
    def train(self) -> None:
        """
        Main training loop. Calls the epoch method for each epoch of training.

        A checkpoint is attempted after every completed epoch and once more after
        ``training_finished`` is set. Any exception is logged and re-raised with its
        original traceback; memory-profiler snapshots and the wandb run are finalized
        in all cases.
        """
        self.print_model_parameters_header()
        try:
            # to be able to keep track of number of steps of this run inside step()
            self.global_step_at_start_this_run = self.global_step

            for epoch_idx in self.epochs:
                self.epoch_idx = epoch_idx
                self.epoch()
                if self.early_stop:
                    break
                self.checkpoint()

            self.training_finished = True
            if self.global_rank == 0:
                if self.early_stop:
                    print(f"*** Exiting training early because training {self.early_stop_reason}")
                else:
                    print("*** Training finished normally.")

            self.checkpoint()
        except Exception as e:
            logger.error(f"Training failed with error: {e}")
            # logger.info(f"{self._trainer_state}")
            # bare `raise` keeps the original traceback intact
            raise
        finally:
            if self.config.mem_profiler is not None:
                torch.cuda.memory._dump_snapshot(self.config.mem_profiler_dir / f"{self.global_rank}.pickle")

            if self.wandb_experiment is not None:
                self.wandb_experiment.finish()

    @callback_wrapper("evaluate")
    def evaluate(self) -> None:
        """
        Evaluation loop. Measures the model's performance on the evaluation dataset.
        """
        self.model.eval()
        with torch.no_grad():
            losses = [self.loss(eval_batch).item() for eval_batch in self.eval_batches]
        self.metrics.record("loss/eval", losses)  # type: ignore

    @callback_wrapper("checkpoint")
    def checkpoint(self) -> None:
        """Save a checkpoint with every engine whose ``do_checkpoint`` is set.

        Nothing is saved if no optimizer step was taken during this run. For AMoE models the
        ``huggingface`` engine is only allowed once training has finished, because the
        parameters must first be remapped to the original MoE layout.

        Raises:
            ValueError: if an interim ``huggingface`` checkpoint is requested for an AMoE model.
        """
        if self.global_step_this_run == 0:
            logger.info("No steps were run this run, not saving the checkpoint")
            return

        for engine in self.checkpoint_engines:
            if engine.do_checkpoint:
                if engine.name == "huggingface" and self.use_arctic_moe:
                    if self.training_finished:
                        # export to the original moe mlp format/layout - this is slow but it's the end of the
                        # training so it's fine.
                        from arctic_training.model.moe.utils import remap_arctic_moe_to_orig_moe_mlp_params

                        logger.info("Exporting to the original MoE format before saving the checkpoint")
                        remap_arctic_moe_to_orig_moe_mlp_params(self.model)
                    else:
                        raise ValueError(
                            "Currently supporting saving to HF checkpoint for AMoE models only when the training is"
                            " finished, because conversion will be very slow. For interim checkpoints use `deepspeed`"
                            " type of the checkpoint as it'd be much faster to save to and resume from. "
                        )
                logger.info(f"Saving Checkpoint at global step: {self.global_step}.")
                engine.save(self.model)

    def count_model_parameters(self) -> Dict[str, int]:
        """Returns a dictionary containing "total" and "trainable" parameters.

        Both keys are always present (0 when nothing matches). ``ds_numel`` is used when a
        parameter has it, so that parameters partitioned by DeepSpeed report their full size.
        """
        # Seed both keys so callers can always index them, even if no parameter requires grad.
        sizes: Dict[str, int] = defaultdict(int, total=0, trainable=0)

        def numel_fn(p):
            return p.ds_numel if hasattr(p, "ds_numel") else p.numel()

        for param in self.model.parameters():
            numel = numel_fn(param)
            sizes["total"] += numel
            if param.requires_grad:
                sizes["trainable"] += numel

        # Converting defaultdict --> dict for nicer printing.
        return dict(sizes)

    def count_model_params_in_original_model(self) -> None:
        """This counts total params in the model before it got sliced into MoE EP slices.

        Only rank 0 stores the result (as ``self.original_hf_model_params``), matching
        ``print_model_parameters_header`` which only reads it on rank 0.
        """
        # XXX: perhaps add a new class for various stats? or may be add to metrics class?
        if self.global_rank == 0:
            self.original_hf_model_params = self.count_model_parameters()

    def print_model_parameters_header(self) -> None:
        """Always print stats about the model we are about to train on rank 0"""
        if self.global_rank != 0:
            return

        def fmt(n: int) -> str:
            return f"{n:,} ({n / 1e9:.2f}B)"

        orig_model_params = self.original_hf_model_params
        world_size = self.world_size
        gas = self.config.gradient_accumulation_steps
        mbs = self.config.micro_batch_size
        gbs = mbs * gas * world_size

        lines = [
            "",
            "-------------------------------------",
            f"Original model: {self.config.model.name_or_path}",
            f"- Total params    : {fmt(orig_model_params['total'])}",
            f"- Trainable params: {fmt(orig_model_params['trainable'])}",
        ]

        # XXX: if possible add MoE passive/activate params breakdown if AMoE is used?
        if self.config.expert_parallel_size > 1:
            # EP>1 spreads the experts across ranks
            curr_model_params = self.count_model_parameters()
            lines += [
                "",
                f"Rank 0 model with EP={self.config.expert_parallel_size}:",
                f"- Total params    : {fmt(curr_model_params['total'])}",
                f"- Trainable params: {fmt(curr_model_params['trainable'])}",
            ]

        # DP is world size w/ EP>1 and SP>1 (but this might change with other parallelism)
        lines += [
            "",
            "Parallelism:",
            f"- EP: {self.config.expert_parallel_size}",
            f"- SP: {self.config.sequence_parallel_size}",
            f"- DP: {world_size}",
            "",
            f"Maximum number of optimizer steps: {self.config.exit_iteration}",
            f"Maximum number of epochs: {self.config.epochs}",
            f"Number of gradient accumulation steps: {gas}",
            f"Number of processes: {world_size}",
            "",
            "Batch sizes:",
            f"- Micro batch size: {mbs}",
            f"- Global batch size: {gbs}",
            "-------------------------------------",
        ]
        print("\n".join(lines))