Data Factory

Data Source

The Data Source is responsible for loading the raw data used in the training pipeline. A Data Source can be created by inheriting from the DataSource class and implementing the load() method.

class arctic_training.data.source.DataSource(data_factory, config)

Bases: ABC, CallbackMixin

Base DataSource class for loading training and evaluation data.

Attributes:
name: str

Name of the DataSource.

config: DataSourceConfig

The DataSourceConfig subclass that this DataSource uses. Any DataSource-specific options should be defined in that config class.

property trainer: Trainer
property data_factory: DataFactory
property world_size: int
property global_rank: int
property cache_path_args: Tuple[Dict, ...]

Returns the config fields that affect the cache path calculation, as a tuple of dictionaries.

property cache_path: Path

Returns the cache path for the data source split.
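This page does not spell out how cache_path_args is turned into cache_path. One common approach, shown here purely as an assumption (the helper name compute_cache_path and the hashing scheme are hypothetical, not ArcticTraining's implementation), is to hash a canonical JSON serialization of the relevant config fields so that any change to them produces a new cache location while identical settings always reuse the same one:

```python
import hashlib
import json
from pathlib import Path
from typing import Dict, Tuple


def compute_cache_path(cache_dir: Path, split: str, cache_path_args: Tuple[Dict, ...]) -> Path:
    """Hypothetical helper: derive a deterministic cache path from config fields.

    sort_keys=True makes the serialization independent of dict key order,
    so the same settings always hash to the same directory name.
    """
    payload = json.dumps(cache_path_args, sort_keys=True, default=str)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]
    return cache_dir / f"{split}-{digest}"


if __name__ == "__main__":
    base = Path("/tmp/arctic_cache")
    a = ({"name": "my_source", "max_length": 2048},)
    b = ({"max_length": 2048, "name": "my_source"},)  # same fields, different key order
    c = ({"name": "my_source", "max_length": 4096},)
    print(compute_cache_path(base, "train", a) == compute_cache_path(base, "train", b))  # True
    print(compute_cache_path(base, "train", a) == compute_cache_path(base, "train", c))  # False
```

Hashing only the fields that affect the output (rather than the whole config) lets unrelated settings change without invalidating the cache.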

abstract load(config, split)

Loads the data for the given split. It should return a datasets.Dataset or datasets.IterableDataset.

Return type:

Union[Dataset, IterableDataset]

Parameters:
  • config (DataSourceConfig)

  • split (str)

Attributes

To define a custom data source, you must subclass DataSource, define the name attribute, and give a type hint for the config attribute.

Methods

To define a custom data source, you must implement the load() method. It should return a HuggingFace datasets.Dataset or datasets.IterableDataset object.
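The shape of a custom data source can be illustrated without a running Trainer. The sketch below uses a small stand-in base class (DataSourceStandIn, hypothetical) that only mirrors the documented contract: a name attribute, a type hint for config, and an abstract load(config, split). JsonLinesSourceConfig and its path field are also hypothetical. A real source subclasses arctic_training.data.source.DataSource (whose constructor also takes data_factory) and returns a datasets.Dataset or datasets.IterableDataset rather than a list:

```python
import json
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List


@dataclass
class JsonLinesSourceConfig:
    """Hypothetical config: options specific to this source are defined here."""
    path: str


class DataSourceStandIn(ABC):
    """Minimal stand-in mirroring the documented DataSource contract (not the real class)."""
    name: str

    def __init__(self, config: Any) -> None:
        self.config = config

    @abstractmethod
    def load(self, config: Any, split: str) -> List[Dict[str, Any]]:
        ...


class JsonLinesSource(DataSourceStandIn):
    name = "jsonl_messages"        # identifies this data source
    config: JsonLinesSourceConfig  # type hint naming the config class this source uses

    def load(self, config: JsonLinesSourceConfig, split: str) -> List[Dict[str, Any]]:
        # Read one JSON object per non-empty line; a real source would return a datasets.Dataset.
        with open(config.path, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "train.jsonl"
        path.write_text('{"messages": [{"role": "user", "content": "hi"}]}\n', encoding="utf-8")
        source = JsonLinesSource(JsonLinesSourceConfig(path=str(path)))
        print(len(source.load(source.config, "train")))  # 1
```

Because load() is abstract, forgetting to implement it makes the subclass impossible to instantiate, which surfaces the mistake early.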

Data Factory

The Data Factory is responsible for creating the training and evaluation datasets used in the training pipeline.

class arctic_training.data.factory.DataFactory(trainer, config=None)

Bases: ABC, CallbackMixin

Base DataFactory class for loading training and evaluation data.

Attributes:
name: str

Name of the DataFactory. This name should be unique to each registered DataFactory object. This name can be used in the training recipe YAMLs to specify the DataFactory to use.

default_source_cls: Optional[Type] = None
config: DataConfig

The DataConfig subclass that this DataFactory uses. Any DataFactory-specific options should be defined in that config class.

property trainer: Trainer

The Trainer object that is using this DataFactory.

property tokenizer: transformers.PreTrainedTokenizerBase

The tokenizer object used by the Trainer.

property micro_batch_size: int

The micro batch size used by the Trainer.

property global_rank: int

The global rank of the current process.

property local_rank: int

The local rank of the current process.

property world_size: int

The total number of processes in the world.

property is_main_process_by_path: bool
cache_path(sources)

Returns the cache path for the processed and concatenated dataset.

Return type:

Path

Parameters:

sources (List[DataSource])

load(data_sources)

Loads data from one or more data sources and concatenates them into a single dataset.

Return type:

Union[Dataset, IterableDataset]

Parameters:

data_sources (List[DataSource])

process(dataset)

Process the dataset (e.g., tokenization for text data).

Return type:

Union[Dataset, IterableDataset]

Parameters:

dataset (datasets.Dataset | datasets.IterableDataset)

split_data(training_data)

Split the training data into training and evaluation datasets.

Return type:

Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset, None]]

Parameters:

training_data (datasets.Dataset | datasets.IterableDataset)
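How the split is configured is not described on this page. As an illustration only, a seeded random hold-out split that returns None for the evaluation set when nothing is held out (matching the optional second element of the return type) could look like this. The function holdout_split and its eval_fraction and seed parameters are hypothetical, and plain lists stand in for datasets:

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def holdout_split(
    training_data: Sequence[T], eval_fraction: float = 0.0, seed: int = 42
) -> Tuple[List[T], Optional[List[T]]]:
    """Hypothetical hold-out split: returns (train, eval), with eval=None when eval_fraction is 0."""
    if not 0.0 <= eval_fraction < 1.0:
        raise ValueError("eval_fraction must be in [0, 1)")
    if eval_fraction == 0.0:
        return list(training_data), None
    indices = list(range(len(training_data)))
    # A local RNG keeps the split reproducible without touching global random state.
    random.Random(seed).shuffle(indices)
    n_eval = max(1, int(len(indices) * eval_fraction))
    eval_idx = set(indices[:n_eval])
    train = [x for i, x in enumerate(training_data) if i not in eval_idx]
    evaluation = [x for i, x in enumerate(training_data) if i in eval_idx]
    return train, evaluation
```

With a real datasets.Dataset, dataset.train_test_split(test_size=0.1, seed=42) returns a DatasetDict with "train" and "test" splits that serve the same purpose.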

create_dataloader(dataset)

Create a torch DataLoader from the dataset.

Return type:

DataLoader

Parameters:

dataset (datasets.Dataset | datasets.IterableDataset)

Attributes

To define a custom data factory, you must subclass the DataFactory, define the name attribute, and give a type hint for the config attribute.

Properties

The DataFactory class provides several properties for accessing the state of the Trainer, the tokenizer, and the distributed environment at runtime. These include trainer, tokenizer, micro_batch_size, global_rank, local_rank, and world_size.

Methods

To define a custom data factory, you must implement the process() method. Additionally, you can override the load(), split_data(), and create_dataloader() methods to change default behaviors.

SFTDataFactory

To help you get started with creating custom trainers and data factories, ArcticTraining includes a Supervised Fine-Tuning (SFT) trainer (described in Trainer). This section also shows how to build a data factory from the base building blocks for use with the SFTTrainer. The resulting SFTDataFactory can be used with the SFTTrainer or with your own custom trainer, and it can be extended to fit other use cases.

To create the SFTDataFactory, we subclass the DataFactory and first define the process() method to tokenize the loaded datasets:

    def process(self, dataset: DatasetType) -> DatasetType:
        if "messages" not in dataset.column_names:
            raise ValueError("Dataset must have 'messages' column to tokenize for SFTDataFactory.")
        dataset = dataset.select_columns(["messages"])
        # SFT tokenization: "messages" is assumed to be a list of turns,
        # each of the form {'role': '...', 'content': '...'}
        return dataset.map(
            lambda ex: {
                **self.tokenize_messages(
                    ex["messages"],
                    self.tokenizer,
                    mask_inputs=self.config.mask_inputs,
                    ignore_empty_think=self.config.ignore_empty_think,
                )
            },
            remove_columns=dataset.column_names,
            num_proc=self.config.num_proc,
            desc="Tokenizing messages",
        )
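The mask_inputs option suggests that prompt tokens are excluded from the loss. A common way to do this, assumed here rather than taken from tokenize_messages, is to copy input_ids into labels and replace every non-assistant token with -100, the default ignore_index of PyTorch's cross-entropy loss. The function toy_tokenize_messages and its whitespace "tokenizer" are hypothetical stand-ins; the real method uses self.tokenizer:

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def toy_tokenize_messages(
    messages: List[Dict[str, str]], vocab: Dict[str, int], mask_inputs: bool = True
) -> Dict[str, List[int]]:
    """Hypothetical SFT tokenization: concatenate turns and mask non-assistant tokens in labels."""
    input_ids: List[int] = []
    labels: List[int] = []
    for message in messages:
        # Prefix each turn with a role marker, then split the content on whitespace.
        words = [f"<{message['role']}>"] + message["content"].split()
        # Assign a new id to each unseen word (len(vocab) is evaluated before insertion).
        ids = [vocab.setdefault(w, len(vocab)) for w in words]
        input_ids.extend(ids)
        is_target = message["role"] == "assistant" or not mask_inputs
        labels.extend(ids if is_target else [IGNORE_INDEX] * len(ids))
    return {"input_ids": input_ids, "labels": labels}


if __name__ == "__main__":
    msgs = [{"role": "user", "content": "what is 2+2"}, {"role": "assistant", "content": "4"}]
    print(toy_tokenize_messages(msgs, {}))
    # {'input_ids': [0, 1, 2, 3, 4, 5], 'labels': [-100, -100, -100, -100, 4, 5]}
```

Masking the prompt this way trains the model only on the assistant's responses while still conditioning on the full conversation.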

Next, we override the create_dataloader() method to attach a custom data collator:

    def create_dataloader(self, dataset: DatasetType) -> DataLoader:
        dataloader = super().create_dataloader(dataset)
        dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config)
        return dataloader
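DataCollatorForCausalLM is not documented on this page. As a rough sketch of what a causal-LM collator typically does (an assumption, not its actual code), the function below right-pads each example in a batch to the longest sequence, filling input_ids with a pad id, labels with -100, and building an attention_mask. The name pad_collate and the default pad_token_id=0 are hypothetical, and plain lists are returned where a real collator would return tensors:

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value ignored by the loss


def pad_collate(batch: List[Dict[str, List[int]]], pad_token_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical collator: right-pad input_ids and labels to the longest example in the batch."""
    max_len = max(len(example["input_ids"]) for example in batch)
    out: Dict[str, List[List[int]]] = {"input_ids": [], "labels": [], "attention_mask": []}
    for example in batch:
        n_pad = max_len - len(example["input_ids"])
        out["input_ids"].append(example["input_ids"] + [pad_token_id] * n_pad)
        # Padding positions must not contribute to the loss.
        out["labels"].append(example["labels"] + [IGNORE_INDEX] * n_pad)
        # 1 for real tokens, 0 for padding.
        out["attention_mask"].append([1] * len(example["input_ids"]) + [0] * n_pad)
    return out
```

Padding per batch (rather than to a global maximum) keeps short batches small; this is why the collator is attached to the DataLoader instead of being applied during process().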

Finally, we define two post-load callbacks: one filters out samples longer than the configured maximum length, and the other packs the remaining samples:

    def filter_dataset_length(self, dataset: DatasetType) -> DatasetType:
        if not self.config.filter_samples:
            return dataset

        dataset = dataset.filter(
            lambda x: len(x["input_ids"]) <= self.config.max_length,
            num_proc=self.config.num_proc,
            desc="Filtering dataset by max length",
        )
        if len(dataset) < 1:
            raise ValueError(
                f"No data left after filtering by max length {self.config.max_length} in"
                f" {self.__class__.__name__}. Consider increasing the `max_length`."
            )
        return dataset


    def pack_dataset(self, dataset: DatasetType) -> DatasetType:
        if not self.config.pack_samples:
            return dataset

        if self.config.repeat_to_pack_max_length:
            dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc)

        batch_size = len(dataset) // self.config.num_proc + 1
        # Cap the batch size so very large datasets do not exhaust CPU memory.
        batch_size = int(min(batch_size, 1e3))
        dataset = dataset.shuffle(seed=self.config.seed)
        dataset = dataset.map(
            lambda x: pack_sft_batch(
                x,
                max_length=self.config.max_length,
                always_max_length=self.config.always_max_length,
                drop_last=self.config.drop_last,
                fuse_positions_prob=self.config.fuse_positions_prob,
                seed=self.config.seed,
            ),
            batched=True,
            batch_size=batch_size,
            num_proc=self.config.num_proc,
            desc="Packing dataset",
        )
        if len(dataset) < 1:
            raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}")
        return dataset
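pack_sft_batch itself is not shown here. To illustrate the idea behind packing, the simplified greedy sketch below (an assumption, not the library's algorithm: it ignores always_max_length, fuse_positions_prob, and shuffling) concatenates consecutive samples until the next one would exceed max_length. It restarts position_ids at each sample boundary, a common convention in packed training, and, when drop_last is set, drops a final pack that is not completely full. The function pack_samples is hypothetical:

```python
from typing import Dict, List


def pack_samples(
    input_ids: List[List[int]], labels: List[List[int]], max_length: int, drop_last: bool = False
) -> Dict[str, List[List[int]]]:
    """Hypothetical greedy packer: group consecutive samples into packs of at most max_length tokens."""
    packed: Dict[str, List[List[int]]] = {"input_ids": [], "labels": [], "position_ids": []}
    cur_ids: List[int] = []
    cur_labels: List[int] = []
    cur_pos: List[int] = []

    def flush() -> None:
        # Emit the current pack (if any) and reset the buffers in place.
        if cur_ids:
            packed["input_ids"].append(list(cur_ids))
            packed["labels"].append(list(cur_labels))
            packed["position_ids"].append(list(cur_pos))
            cur_ids.clear()
            cur_labels.clear()
            cur_pos.clear()

    for ids, labs in zip(input_ids, labels):
        if len(ids) > max_length:
            continue  # skip samples that cannot fit in any pack
        if len(cur_ids) + len(ids) > max_length:
            flush()
        cur_ids.extend(ids)
        cur_labels.extend(labs)
        cur_pos.extend(range(len(ids)))  # positions restart at 0 for every sample
    if not drop_last or len(cur_ids) == max_length:
        flush()
    return packed
```

For example, with max_length=5 the samples [1, 2, 3] and [4, 5] share one pack with position_ids [0, 1, 2, 0, 1]. Packing this way reduces the padding the collator has to add, which is why the length filter runs first: samples longer than max_length could never be placed in a pack.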

These callbacks are registered on SFTDataFactory through its callbacks attribute, and they run on the concatenated dataset returned by the load() method.

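    from arctic_training.data.factory import DataFactory

    # SFTDataConfig is assumed to be defined alongside SFTDataFactory;
    # filter_dataset_length and pack_dataset are the functions shown above.


    class SFTDataFactory(DataFactory):
        name = "sft"
        config: SFTDataConfig
        callbacks = [
            ("post-load", filter_dataset_length),
            ("post-load", pack_dataset),
        ]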
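The CallbackMixin internals are not covered on this page. Assuming that callbacks registered for the same event run in list order, each receiving the output of the previous one (so the length filter runs before packing), the dispatch can be sketched with a hypothetical run_callbacks helper. The "other-event" name in the demo is made up purely to show that callbacks for other events are skipped:

```python
from typing import Any, Callable, List, Tuple

# (event name, function taking (owner, data) and returning the transformed data)
Callback = Tuple[str, Callable[[Any, Any], Any]]


def run_callbacks(owner: Any, callbacks: List[Callback], event: str, data: Any) -> Any:
    """Hypothetical dispatcher: apply each callback registered for `event`, in list order."""
    for callback_event, fn in callbacks:
        if callback_event == event:
            # Callbacks are written like methods, e.g. filter_dataset_length(self, dataset).
            data = fn(owner, data)
    return data


if __name__ == "__main__":
    class ToyFactory:
        max_length = 3

    def drop_long(self, rows):
        return [r for r in rows if len(r) <= self.max_length]

    def summarize(self, rows):
        return {"rows": rows, "count": len(rows)}

    callbacks = [
        ("post-load", drop_long),
        ("post-load", summarize),
        ("other-event", lambda self, rows: []),  # hypothetical event; skipped for "post-load"
    ]
    print(run_callbacks(ToyFactory(), callbacks, "post-load", [[1], [1, 2, 3, 4], [1, 2]]))
    # {'rows': [[1], [1, 2]], 'count': 2}
```

Under this assumption, swapping the two entries in SFTDataFactory.callbacks would pack first and filter afterwards, so list order is part of the factory's behavior.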