# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from functools import cached_property
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import deepspeed
import numpy as np
import torch
import torch.cuda
import torch.distributed.nn
import wandb
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter
from devtools import debug
from tqdm import tqdm
from transformers import set_seed
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from wandb.sdk.wandb_run import Run as WandbRun
from arctic_training.callback.logging import post_loss_log_cb
from arctic_training.callback.mixin import CallbackMixin
from arctic_training.callback.mixin import callback_wrapper
from arctic_training.checkpoint.engine import CheckpointEngine
from arctic_training.config.trainer import TrainerConfig
from arctic_training.data.factory import DataFactory
from arctic_training.data.utils import OverfitOneBatchDataLoader
from arctic_training.debug.utils import pr0
from arctic_training.debug.utils import see_memory_usage
from arctic_training.logging import logger
from arctic_training.metrics import Metrics
from arctic_training.model.factory import ModelFactory
from arctic_training.model.moe.utils import amoe_install_deepspeed_timers
from arctic_training.model.moe.utils import detect_if_moe_model
from arctic_training.model.tiled_compute import enable_tiled_mlp_compute
from arctic_training.optimizer.factory import OptimizerFactory
from arctic_training.registry import RegistryMeta
from arctic_training.registry import _validate_class_attribute_set
from arctic_training.registry import _validate_class_attribute_type
from arctic_training.registry import _validate_class_method
from arctic_training.scheduler.factory import SchedulerFactory
from arctic_training.tokenizer.factory import TokenizerFactory
from arctic_training.utils import append_json_file
[docs]
class Trainer(ABC, CallbackMixin, metaclass=RegistryMeta):
"""Base Trainer class."""
name: str
"""
Name of the trainer used for registering custom trainers. This name
should be unique and is used in the training recipe YAMLs to identify which
trainer to be used.
"""
config: TrainerConfig
"""
The type of the config class that the trainer uses. This should be a
subclass of TrainerConfig and add any trainer-specific fields.
"""
data_factory: DataFactory
"""
A List of valid data factory types that the trainer can use. These should
inherit from DataFactory. The first item in the list will be used as the
default if the type is not explicitly set in the YAML config.
"""
model_factory: ModelFactory
"""
A List of valid model factory types that the trainer can use. These should
inherit from ModelFactory. The first item in the list will be used as the
default if the type is not explicitly set in the YAML config.
"""
checkpoint_engine: CheckpointEngine
"""
A List of valid checkpoint engine types that the trainer can use. These
should inherit from CheckpointEngine. The first item in the list will be
used as the default if the type is not explicitly set in the YAML config.
"""
optimizer_factory: OptimizerFactory
"""
A List of valid optimizer factory types that the trainer can use. These
should inherit from OptimizerFactory. The first item in the list will be
used as the default if the type is not explicitly set in the YAML config.
"""
scheduler_factory: SchedulerFactory
"""
A List of valid scheduler factory types that the trainer can use. These
should inherit from SchedulerFactory. The first item in the list will be
used as the default if the type is not explicitly set in the YAML config.
"""
tokenizer_factory: TokenizerFactory
"""
A List of valid tokenizer factory types that the trainer can use. These
should inherit from TokenizerFactory. The first item in the list will be
used as the default if the type is not explicitly set in the YAML config.
"""
callbacks: List[Tuple[str, Callable]] = [
post_loss_log_cb,
]
"""
A list of callbacks for the trainer. Each callback is specified as a tuple of
a string naming the event the callback hooks into and a callable that
implements the callback. Callback events for the trainer include `pre-` and
`post-` variants for `init`, `train`, `epoch`, `step`, `loss`, `backward`,
`evaluate`, and `checkpoint`.
"""
@classmethod
def _validate_subclass(cls) -> None:
    """
    Check that a Trainer subclass declares everything the base class relies on.

    Delegates to the registry ``_validate_*`` helpers: ``name`` must be set, the
    config/factory/engine attributes must match their expected base types, and the
    core hook methods must exist with the expected argument names.
    """
    _validate_class_attribute_set(cls, "name")

    required_attribute_types = (
        ("config", TrainerConfig),
        ("data_factory", DataFactory),
        ("model_factory", ModelFactory),
        ("checkpoint_engine", CheckpointEngine),
        ("optimizer_factory", OptimizerFactory),
        ("scheduler_factory", SchedulerFactory),
        ("tokenizer_factory", TokenizerFactory),
    )
    for attribute_name, base_type in required_attribute_types:
        _validate_class_attribute_type(cls, attribute_name, base_type)

    required_method_signatures = (
        ("loss", ["self", "batch"]),
        ("step", ["self", "batch"]),
        ("epoch", ["self"]),
        ("train", ["self"]),
        ("checkpoint", ["self"]),
    )
    for method_name, arg_names in required_method_signatures:
        _validate_class_method(cls, method_name, arg_names)
def __init__(self, config: TrainerConfig, mode: str = "train") -> None:
    """
    Build the complete training stack described by ``config``.

    Construction order matters and is: seeds -> tokenizer -> dataloaders ->
    checkpoint engines (resume detection) -> Ulysses SP registration -> optional
    monkey patches -> model -> optional ArcticMoE remap -> optimizer -> scheduler
    -> ``deepspeed.initialize`` -> SP dataloader wrapping -> checkpoint loading ->
    metrics -> wandb (rank 0 only).

    Args:
        config: The trainer configuration.
        mode: ``"train"`` (default) builds everything. ``"process-data"`` returns
            right after the tokenizer and dataloaders are created, so no model,
            optimizer, DeepSpeed engine or metrics are set up.

    Raises:
        ValueError: if ArcticMoE is active and the DeepSpeed config does not
            request ZeRO stage 2.
    """
    logger.info(f"Initializing Trainer with config:\n{debug.format(config)}")
    self.config = config
    self.epoch_idx = 0
    self.train_batch_idx = 0
    self.global_step = 0
    self.global_step_this_run = 0
    self.global_step_at_start_this_run = 0
    self.early_stop = False
    self.early_stop_reason = ""
    self.world_size = config.world_size
    self.global_rank = config.global_rank
    self.epoch_finished = False
    self.training_finished = False
    self.wandb_experiment: Optional[WandbRun] = None
    self.is_resume = False  # Track if we resumed from ckpt
    self.wandb_run_id = None

    self._set_seeds(self.config.seed)

    # record allocation history for the whole run; the snapshot is dumped at the end of train()
    if self.config.mem_profiler == "e2e":
        torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

    tokenizer_factory = self.config.tokenizer.factory(self)
    self.tokenizer = tokenizer_factory()

    data_factory = self.config.data.factory(self)
    self.train_dataloader, self.eval_dataloader = data_factory()

    # data-processing only: nothing below (model, optimizer, DeepSpeed) is needed
    if mode == "process-data":
        return

    if self.config.overfit_first_batch:
        self.train_dataloader = OverfitOneBatchDataLoader(self.train_dataloader)

    # checkpointing and resume
    self.checkpoint_engines = [engine(self) for engine in self.config.checkpoint_engines]
    for engine in self.checkpoint_engines:
        # currently only deepspeed engine supports resume from intermediate checkpoint
        if engine.name == "deepspeed" and engine.config.auto_resume and engine.latest_checkpoint_exists:
            self.is_resume = True

    # XXX: We can abstract this section further with AT-specific wrapper, but
    # UlyssesSPAttentionHF should not have any AT-specific objects / assumptions
    # The returned mpu is handed to deepspeed.initialize() further below.
    mpu = UlyssesSPAttentionHF.register_with_transformers(
        model_name_or_path=self.config.model.name_or_path,
        core_attn_implementation=self.config.model.attn_implementation,
        sequence_parallel_size=self.config.sequence_parallel_size,
        micro_batch_size=self.config.micro_batch_size,
        seq_length=self.config.data.max_length,
        seq_length_is_variable=True,
    )

    # Important: this is most likely not beneficial under seqlen=64k
    if self.config.activation_checkpoint_cpu_offload:
        # activation_checkpointing_cpu_offload becomes very beneficial at very long seqlen
        # e.g., llama 8b at 800k (100k effective per gpu) will save 24GB per gpu:
        # ((100_000*4096)*2*32/2**30), but for short sequences the offload will just slow things
        # down,
        #
        # XXX: could parameterize or run a few lengths to see at which threshold it becomes
        # beneficial - a user might still want this on even at shorter seqlen if they don't
        # mind slower performance. discussing adding this functionality to pytorch core
        # (https://pytorch.slack.com/archives/C3PDTEV8E/p1745274102600729)
        from arctic_training.monkey_patches import monkey_patch_checkpoint_function_with_cpu_offload

        monkey_patch_checkpoint_function_with_cpu_offload()

    # MLP tiling - has to happen before model is instantiated
    if self.config.tiled_mlp_compute:
        enable_tiled_mlp_compute(self.config.model.name_or_path)

    # Intentionally unused local (hence the noqa). Presumably it only needs to stay
    # alive while the model is built so transformers sees the DeepSpeed config.
    # NOTE(review): confirm against the HfDeepSpeedConfig docs.
    dschf = HfDeepSpeedConfig(self.config.deepspeed)  # noqa: F841
    model_factory = self.config.model.factory(self)
    self.model = model_factory()

    # must run before any MoE remapping below slices the experts
    self.count_model_params_in_original_model()

    # prevent causal mask from being created in HF Transformers - it's a huge `[bs, seqlen, seqlen]` tensor
    # XXX: This should also benefit a single gpu use case when SDPA is used - so perhaps remove the SP>1 check?
    if self.config.sequence_parallel_size > 1 and self.config.model.attn_implementation not in [
        "flash_attention_2",
        "flash_attention_3",
    ]:
        import transformers.masking_utils

        # the registered "sdpa" mask function returns None, i.e. no mask tensor is materialized
        transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", lambda *args, **kwargs: None)

    # Arctic MoE model remapping has to be called before an optimizer is created
    if self.config.arctic_moe == "auto":
        self.use_arctic_moe = detect_if_moe_model(self.model)
    else:
        self.use_arctic_moe = self.config.arctic_moe

    if self.use_arctic_moe:
        pr0("Activating ArcticMoE", force=False)
        import deepspeed.comm as dist
        from deepspeed.utils import groups

        from arctic_training.model.moe.utils import monkey_patch_ds_moe
        from arctic_training.model.moe.utils import remap_orig_moe_mlp_params_to_arctic_moe

        # expert-parallel group creation below needs an initialized process group
        if not dist.is_initialized():
            dist.init_distributed(dist_backend="nccl", dist_init_required=True)

        monkey_patch_ds_moe()
        # deepspeed.runtime.engine.DeepSpeedEngine.print_forward_breakdown = print_forward_breakdown

        # DeepspeedMoE is only integrated with ZeRO-2
        zero_stage = self.config.deepspeed.get("zero_optimization", {}).get("stage", 0)
        if zero_stage != 2:
            raise ValueError(
                "at the moment Deepspeed supports only ZeRO stage 2 with MoE, but the configuration asks for ZeRO"
                f" stage={zero_stage}"
            )

        # this config comes from use_data_before_expert_parallelism ds config which defaults to False
        # engine._config.use_data_before_expert_parallel_)
        # but we don't have the engine yet to get the ds config values - perhaps could extract this via AT-config?
        use_data_before_expert_parallel_ = False

        # the ep group has to be created before remap_orig_moe_mlp_params_to_arctic_moe as ep rank info is needed to remap pre-trained experts
        groups._create_expert_data_and_model_parallel(
            self.config.expert_parallel_size,
            mpu=None,
            use_data_before_expert_parallel_=use_data_before_expert_parallel_,
        )

        # self.is_resume was already determined above from the checkpoint engines (by testing whether
        # the latest checkpoint exists); actually loading the checkpoint happens much later, which would
        # be too late for the remap decision.
        remap_orig_moe_mlp_params_to_arctic_moe(
            self.model,
            ep_size=self.config.expert_parallel_size,
            is_resume=self.is_resume,
            enable_custom_moe_kernel=self.config.enable_arctic_moe_custom_optimization,
            enable_routing_replay=self.config.enable_routing_replay,
        )

        # XXX: check we can remap back
        # from arctic_training.model.moe.utils import remap_arctic_moe_params_to_orig_moe_mlp
        # remap_arctic_moe_params_to_orig_moe_mlp(self.model)

        see_memory_usage("after moe remap", force=False)

    # this is an optional debug instrumentation to trace overflows in params/grads
    #
    # inspectors are important to call after all model tweaks are done (e.g. after AMoE)
    #
    # from arctic_training.debug.underflow_overflow import DebugUnderflowOverflow
    # debug_overflow = DebugUnderflowOverflow(self.model, max_frames_to_save=100) # noqa
    #
    # from arctic_training.debug.underflow_overflow import DebugGradients
    # debug_grads = DebugGradients(self.model, trace_batch_nums=[1], max_frames_to_save=100) # noqa

    optimizer_factory = self.config.optimizer.factory(self)
    self.optimizer = optimizer_factory()

    scheduler_factory = self.config.scheduler.factory(self)
    self.scheduler = scheduler_factory()

    # from here on self.model is the DeepSpeed engine; use self.model_unwrapped for the raw model
    self.model, *_ = deepspeed.initialize(
        model=self.model,
        optimizer=self.optimizer,
        args=self.config,
        lr_scheduler=self.scheduler,
        config=self.config.deepspeed,
        mpu=mpu,
    )

    if self.use_arctic_moe:
        amoe_install_deepspeed_timers(self.model, self.model_unwrapped)

    # epoch() only merges DeepSpeed wall-clock timers into the logged metrics when the engine exposes them
    self.ds_wall_clock_available = hasattr(self.model, "get_wall_clock_timers")

    if self.config.sequence_parallel_size > 1:
        # deepspeed.initialize needs to run first
        from deepspeed.utils import groups

        # set SP-trainer attributes to be used later
        self.sp_group = groups._get_sequence_parallel_group()
        self.sp_world_size = groups._get_sequence_parallel_world_size()
        self.sp_rank = groups._get_sequence_parallel_rank()

        # wrap the DL with Ulysses one
        self.train_dataloader = UlyssesSPDataLoaderAdapter(
            self.train_dataloader,
            sp_rank=self.sp_rank,
            sp_group=self.sp_group,
            sp_world_size=self.sp_world_size,
            device=self.device,
        )
        if self.eval_dataloader is not None:
            self.eval_dataloader = UlyssesSPDataLoaderAdapter(
                self.eval_dataloader,
                sp_rank=self.sp_rank,
                sp_group=self.sp_group,
                sp_world_size=self.sp_world_size,
                device=self.device,
            )

    # load the checkpoint (if any) into the fully-initialized DeepSpeed engine
    for engine in self.checkpoint_engines:
        if engine.config.auto_resume:
            engine.load(self.model)

    self.metrics = Metrics(self)

    if self.global_rank == 0 and self.config.wandb.enable:
        # in order for resume to continue the same wandb run we need to re-use a run_id from the previous run
        if self.wandb_run_id is None:
            self.wandb_run_id = wandb.util.generate_id()  # type: ignore
        # Note: wandb.init() is not type annotated so we need to use type: ignore
        self.wandb_experiment = wandb.init(  # type: ignore
            id=self.wandb_run_id,
            entity=self.config.wandb.entity,
            project=self.config.wandb.project,
            name=self.config.wandb.name,
            config=self.config.model_dump(),
            # do not put `wandb` in the root of the repo as it conflicts with wandb package
            dir=f"{self.config.logger.output_dir}/wandb",
        )
def _set_seeds(self, seed: int) -> None:
    """
    Seed every random number generator the trainer touches with the same value.

    Seeds torch, numpy, Python's ``random`` module and ``transformers.set_seed``,
    in that order.
    """
    logger.info(f"Setting random seeds to {seed}")
    seeders = (torch.manual_seed, np.random.seed, random.seed, set_seed)
    for seed_rng in seeders:
        seed_rng(seed)
@property
def model_unwrapped(self):
    """
    The bare model, with any DeepSpeed engine wrapper peeled off.

    After ``deepspeed.initialize`` the engine keeps the original model in its
    ``module`` attribute; before that, ``self.model`` is already the bare model
    and is returned as-is.
    """
    return getattr(self.model, "module", self.model)
@property
def epochs(self) -> tqdm:
    """
    Progress-bar iterator over the epoch indices still to run.

    Iteration starts at ``self.epoch_idx`` (non-zero after a resume). When
    ``train_iters`` is set, the epoch count is derived from it: each optimizer
    iteration consumes ``gradient_accumulation_steps`` dataloader batches.
    Otherwise ``config.epochs`` is used. The bar is shown only on rank 0 and
    only when per-iteration log printing is off (``train_log_iter_interval == 0``).
    """
    if self.config.train_iters:
        micro_batches_needed = self.config.train_iters * self.config.gradient_accumulation_steps
        total_epochs = math.ceil(micro_batches_needed / len(self.train_dataloader))
    else:
        total_epochs = self.config.epochs

    hide_bar = self.global_rank != 0 or self.config.train_log_iter_interval != 0
    return tqdm(
        range(self.epoch_idx, total_epochs),
        desc="Epochs",
        unit="epoch",
        disable=hide_bar,
    )
@property
def train_batches(self) -> tqdm:
    """
    Progress-bar wrapper around the training dataloader.

    Shown only on rank 0 and only when per-iteration log printing is off
    (``train_log_iter_interval == 0``).
    """
    quiet = self.global_rank != 0 or self.config.train_log_iter_interval != 0
    return tqdm(self.train_dataloader, desc="Train Batches", unit="batch", disable=quiet)
@property
def eval_batches(self) -> tqdm:
    """
    Progress-bar wrapper around the evaluation dataloader.

    Shown only on rank 0, and only for evaluation rounds that
    :meth:`is_eval_log_iter` marks as logged.
    """
    # is_eval_log_iter() is only consulted on rank 0
    show_progress = self.global_rank == 0 and self.is_eval_log_iter()
    return tqdm(
        self.eval_dataloader,
        desc="Eval Batches",
        unit="batch",
        disable=not show_progress,
    )
[docs]
def is_eval_log_iter(self) -> bool:
    """
    Return True if the current evaluation round should be logged.

    The number of completed evaluation intervals (``global_step // eval_interval``)
    must be a multiple of ``eval_log_iter_interval``. Both config values are used
    as divisors, so they must be non-zero when this is called.
    """
    completed_eval_rounds = self.global_step // self.config.eval_interval
    return completed_eval_rounds % self.config.eval_log_iter_interval == 0
@cached_property
def device(self) -> torch.device:
    """The accelerator device for this process's ``local_rank`` (computed once, then cached)."""
    accelerator = get_accelerator()
    device_name = accelerator.device_name(self.config.local_rank)
    return torch.device(device_name)
@property
def training_horizon(self) -> int:
    """
    Total number of training iterations; ``need_early_exit`` compares ``global_step`` against it.

    Returns ``config.train_iters`` when set, otherwise ``config.epochs`` times the
    dataloader length.

    Raises:
        ValueError: if the training dataloader has not been created yet.
    """
    if self.train_dataloader is None:
        raise ValueError("Train dataloader not initialized.")

    explicit_iters = self.config.train_iters
    if explicit_iters:
        return explicit_iters

    # XXX: this was incorrect for GAS
    # NOTE(review): len(train_dataloader) counts micro-batches while global_step (and train_iters,
    # see `epochs`) count optimizer steps, so with gradient_accumulation_steps > 1 this horizon is
    # larger than the number of optimizer steps that will actually run - confirm what consumers
    # (e.g. the LR scheduler) expect before changing it.
    return len(self.train_dataloader) * self.config.epochs
[docs]
@callback_wrapper("loss")
@abstractmethod
def loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Compute the loss for a single batch.

    Concrete trainers must override this; the base implementation only raises.
    It is called from :meth:`step` during training and from :meth:`evaluate`
    under ``torch.no_grad()``.

    Args:
        batch: Mapping from input name to tensor, as yielded by the dataloader.

    Returns:
        The loss for ``batch``.
    """
    raise NotImplementedError("Loss method must be implemented by the trainer.")
[docs]
@callback_wrapper("backward")
def backward(self, loss: torch.Tensor) -> None:
    """
    Back-propagate ``loss`` through the model.

    Runs right after :meth:`loss` inside :meth:`step` and delegates to the
    DeepSpeed engine's own ``backward``.
    """
    engine = self.model
    engine.backward(loss)
[docs]
def need_early_exit(self):
    """
    Decide whether training must stop before the next step.

    Returns True and records a human-readable ``self.early_stop_reason`` for the
    first condition that holds; returns False otherwise. Conditions are checked
    in order: per-run iteration limit, absolute iteration limit, kill-switch
    file, training horizon.
    """
    cfg = self.config

    # exit conditions in the order of likelihood
    per_run_limit = cfg.exit_iteration_this_run
    if per_run_limit > 0 and per_run_limit == self.global_step_this_run:
        self.early_stop_reason = f"reached exit_iteration_this_run of {self.global_step_this_run}"
        return True

    absolute_limit = cfg.exit_iteration
    if absolute_limit > 0 and absolute_limit == self.global_step:
        self.early_stop_reason = f"reached exit_iteration of {self.global_step}"
        return True

    # filesystem check: only reached when the cheaper checks above did not fire
    if cfg.kill_switch_path.exists():
        self.early_stop_reason = f"detected kill switch {cfg.kill_switch_path}"
        return True

    if self.global_step >= self.training_horizon:
        self.early_stop_reason = f"reached training_horizon of {self.global_step}"
        return True

    return False
[docs]
@callback_wrapper("step")
def step(self, batch: Dict[str, torch.Tensor]) -> None:
    """
    Run one training micro-step on ``batch``.

    Computes the loss under DeepSpeed autocast (when enabled) and back-propagates it.
    Records the scalar loss in the metrics and calls the DeepSpeed engine's ``step``.
    Lets the checkpoint engines save, then syncs ``global_step`` and
    ``global_step_this_run`` from the engine.
    """
    self.model.train()
    with deepspeed.runtime.engine.autocast_if_enabled(self.model):
        loss = self.loss(batch)
    self.backward(loss)

    # the loss may be a tensor or an already-reduced Python number
    loss_value = loss.item() if torch.is_tensor(loss) else loss
    self.metrics.record("loss", loss_value)

    # To debug ArcticMoE expert-parallel gradients, inspect parameters here (before the
    # optimizer step) with deepspeed.utils.safe_get_full_grad(<param>).

    self.model.step()

    # NOTE(review): checkpoint() runs before global_step is refreshed below, so it still sees the
    # value from before this engine step - confirm this labelling is intended.
    self.checkpoint()

    # DeepSpeed increments its global step after the step() call, so we use it as the golden truth
    self.global_step = self.model.global_steps
    self.global_step_this_run = self.global_step - self.global_step_at_start_this_run
[docs]
@callback_wrapper("epoch")
def epoch(self) -> None:
    """
    Run one epoch of training, calling :meth:`step` on every dataloader batch.

    On the first epoch after a resume, the micro-batches already consumed before
    the checkpoint are skipped. This assumes the dataloader replays the same order.
    Before each step the early-exit conditions are checked. On a gradient
    accumulation boundary, metrics are logged (rank 0) and evaluation may run.

    Sets ``self.early_stop`` when :meth:`need_early_exit` fires. Sets
    ``self.epoch_finished`` once the batch loop ends, including on an early exit.
    """
    self.epoch_finished = False
    self.metrics.start_timer("iter")

    # enable memory allocation history, which will add tracebacks and event history to memory snapshots
    if self.config.mem_profiler == "step":
        torch.cuda.memory._record_memory_history(max_entries=self.config.mem_profiler_max_entries)

    batch_iterator = iter(self.train_batches)

    if self.is_resume:
        logger.info(f"Resumed from checkpoint at global step: {self.global_step}.")
        # global_step counts optimizer steps (DeepSpeed bumps it once per gradient accumulation
        # boundary), while the dataloader yields micro-batches - the same conversion the `epochs`
        # property uses. Convert before locating the position inside the current epoch; using
        # global_step directly skipped too few batches whenever gradient_accumulation_steps > 1.
        micro_batches_consumed = self.global_step * self.config.gradient_accumulation_steps
        batches_to_skip = micro_batches_consumed % len(self.train_dataloader)
        logger.info(f"Advancing {batches_to_skip} batches.")
        for _ in range(batches_to_skip):
            next(batch_iterator)
        self.train_batch_idx += batches_to_skip
        self.is_resume = False

    for batch in batch_iterator:
        self.train_batch_idx += 1

        # Run the early exit checks before stepping to correctly deal with resume should the training not continue
        if self.need_early_exit():
            self.early_stop = True
            break

        # trainer-side boundary flag (derived from its own micro-batch counter), used to gate logging/eval
        self.gas_boundary = self.train_batch_idx % self.config.gradient_accumulation_steps == 0

        if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation:
            # deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflos separately
            sample_seqlens = batch.pop("packed_sample_seqlens")
        else:
            # one entry per sample: its local length scaled back up to the full (pre-SP-split) length
            sample_seqlens = [
                [len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]
                for idx in range(len(batch["input_ids"]))
            ]
        self.metrics.seqlens = sample_seqlens

        self.metrics.start_timer("step")
        self.step(batch)
        self.metrics.stop_timer("step")

        self.metrics.restart_timer("iter")

        if self.config.train_log_iter_interval != 0:
            self.metrics.print_summary()

        if self.gas_boundary:
            if (
                self.global_rank == 0
                and self.config.train_log_iter_interval != 0
                and self.global_step % self.config.train_log_iter_interval == 0
            ):
                metrics = dict(self.metrics.summary_dict)
                if self.ds_wall_clock_available:
                    ds_timers = self.model.get_wall_clock_timers()
                    metrics.update(ds_timers)

                append_json_file(self.config.train_log_metrics_path, metrics)

                # do not log the first train iteration to wandb, since it's a massive outlier
                # on all performance metrics, which messes up the scale of the report
                if self.wandb_experiment is not None and self.global_step > 1:
                    metrics = {k: v for k, v in metrics.items() if k != "iter"}
                    self.wandb_experiment.log(metrics, step=self.global_step)

            if self.config.eval_interval != 0 and self.global_step % self.config.eval_interval == 0:
                self.evaluate()

                if self.is_eval_log_iter():
                    self.metrics.print_summary(prefix="eval")

                    if self.wandb_experiment is not None:
                        metrics = {k: self.metrics.summary_dict[k] for k in ["loss/eval"]}
                        self.wandb_experiment.log(metrics, step=self.global_step)

    self.metrics.stop_timer("iter")
    self.epoch_finished = True
[docs]
@callback_wrapper("train")
def train(self) -> None:
    """
    Main training loop: run :meth:`epoch` for each remaining epoch, then save a final checkpoint.

    A checkpoint is attempted after each fully completed epoch. One more is attempted
    after the loop, with ``self.training_finished`` already set so that
    :meth:`checkpoint` can treat it as the end-of-training save. Any exception is
    logged and re-raised. In all cases, the optional memory-profiler snapshot is
    dumped and the wandb run (if any) is finished.
    """
    self.print_model_parameters_header()

    try:
        # to be able to keep track of number of steps of this run inside step()
        self.global_step_at_start_this_run = self.global_step

        for epoch_idx in self.epochs:
            self.epoch_idx = epoch_idx
            self.epoch()
            if self.early_stop:
                break
            self.checkpoint()

        self.training_finished = True
        if self.global_rank == 0:
            if self.early_stop:
                print(f"*** Exiting training early because training {self.early_stop_reason}")
            else:
                print("*** Training finished normally.")

        self.checkpoint()
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        # bare `raise` re-raises the active exception with its original traceback untouched
        raise
    finally:
        if self.config.mem_profiler is not None:
            torch.cuda.memory._dump_snapshot(self.config.mem_profiler_dir / f"{self.global_rank}.pickle")

        if self.wandb_experiment is not None:
            self.wandb_experiment.finish()
[docs]
@callback_wrapper("evaluate")
def evaluate(self) -> None:
    """
    Evaluate the model on the evaluation dataloader.

    Switches the model to eval mode, computes the loss of every eval batch
    without gradient tracking, and records the per-batch losses (as Python
    floats) under the ``"loss/eval"`` metric.
    """
    self.model.eval()
    eval_losses = []
    with torch.no_grad():
        for eval_batch in self.eval_batches:
            eval_losses.append(self.loss(eval_batch).item())
    self.metrics.record("loss/eval", eval_losses)  # type: ignore
[docs]
@callback_wrapper("checkpoint")
def checkpoint(self) -> None:
    """
    Let every checkpoint engine that wants to save do so.

    Nothing is saved if no step has run yet in this run. For ArcticMoE models, the
    ``huggingface`` engine first converts the weights back to the original MoE
    layout, which is only allowed once training has finished.

    Raises:
        ValueError: if a ``huggingface`` checkpoint is requested for an ArcticMoE
            model before training has finished.
    """
    if self.global_step_this_run == 0:
        logger.info("No steps were run this run, not saving the checkpoint")
        return

    for engine in self.checkpoint_engines:
        if not engine.do_checkpoint:
            continue

        exporting_amoe_to_hf = engine.name == "huggingface" and self.use_arctic_moe
        if exporting_amoe_to_hf:
            if not self.training_finished:
                raise ValueError(
                    "Currently supporting saving to HF checkpoint for AMoE models only when the training is"
                    " finished, because conversion will be very slow. For interim checkpoints use `deepspeed`"
                    " type of the checkpoint as it'd be much faster to save to and resume from. "
                )
            # export to the original moe mlp format/layout - this is slow but it's the end of the training so it's fine.
            from arctic_training.model.moe.utils import remap_arctic_moe_to_orig_moe_mlp_params

            logger.info("Exporting to the original MoE format before saving the checkpoint")
            remap_arctic_moe_to_orig_moe_mlp_params(self.model)

        logger.info(f"Saving Checkpoint at global step: {self.global_step}.")
        engine.save(self.model)
[docs]
def count_model_parameters(self):
    """
    Count the model's parameters.

    Returns a plain dict with ``"total"`` (all parameters) and ``"trainable"``
    (those with ``requires_grad``) element counts. A key is present only if at
    least one parameter contributed to it, so a model without parameters yields
    ``{}``. When a parameter carries a ``ds_numel`` attribute, that value is used
    instead of ``numel()``. Presumably this is DeepSpeed ZeRO-3's full size of a
    partitioned parameter; confirm against the DeepSpeed docs.
    """
    counts: Dict[str, int] = {}
    for param in self.model.parameters():
        n_elements = param.ds_numel if hasattr(param, "ds_numel") else param.numel()
        counts["total"] = counts.get("total", 0) + n_elements
        if param.requires_grad:
            counts["trainable"] = counts.get("trainable", 0) + n_elements
    return counts
[docs]
def count_model_params_in_original_model(self):
    """
    Snapshot parameter counts of the model before it is sliced into MoE expert-parallel shards.

    Only rank 0 stores the result (the dict returned by ``count_model_parameters``)
    in ``self.original_hf_model_params``. On other ranks this method does not set
    the attribute.
    """
    # XXX: perhaps add a new class for various stats? or may be add to metrics class?
    # NOTE(review): torch.distributed.get_rank() needs an initialized default process group - confirm
    # one always exists by the time this is called from __init__.
    if torch.distributed.get_rank() == 0:
        self.original_hf_model_params = self.count_model_parameters()