Data Factory
Data Source
The Data Source is responsible for loading the raw data used in the training
pipeline. A Data Source can be created by inheriting from the
DataSource class and implementing the
load() method.
- class arctic_training.data.source.DataSource(data_factory, config)
Bases: ABC, CallbackMixin
Base DataSource class for loading training and evaluation data.
- Parameters:
data_factory (DataFactory)
config (DataSourceConfig)
- name: str
Name of the DataSource.
- config: DataSourceConfig
The type of the DataSourceConfig object that this DataSource uses. Any DataSource-specific options should be specified in this class.
- property data_factory: DataFactory
- property world_size: int
- property global_rank: int
- property cache_path_args: Tuple[Dict, ...]
Returns the config fields, as a tuple of dictionaries, that affect the cache path calculation.
- property cache_path: Path
Returns the cache path for the data source split.
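The page does not spell out how `cache_path_args` turns into `cache_path`. A common approach is to serialize the relevant config fields deterministically and hash them, so that any change to a field that affects the data produces a different cache location. The sketch below assumes that approach. The function `hypothetical_cache_path`, the cache directory, and the field names are illustrative only and are not ArcticTraining's actual implementation.

```python
import hashlib
import json
from pathlib import Path
from typing import Dict, Tuple


def hypothetical_cache_path(cache_dir: Path, cache_path_args: Tuple[Dict, ...]) -> Path:
    # Hypothetical helper, not ArcticTraining's real scheme.
    # sort_keys removes dict-ordering effects; default=str handles values such as
    # Path objects that json cannot encode natively.
    payload = json.dumps(cache_path_args, sort_keys=True, default=str)
    # Changing any field that affects the data changes the hash, and thus the path.
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]
    return cache_dir / digest


base = Path("/tmp/arctic_cache")  # hypothetical cache root
a = hypothetical_cache_path(base, ({"name_or_path": "my/data", "split": "train"},))
b = hypothetical_cache_path(base, ({"split": "train", "name_or_path": "my/data"},))
c = hypothetical_cache_path(base, ({"name_or_path": "my/data", "split": "test"},))
```

Here `a` and `b` resolve to the same path because only key order differs. `c` gets a different path because the split changed.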
Attributes
To define a custom data source, you must subclass DataSource, define the
name attribute, and give a type hint for the
config attribute.
Methods
To define a custom data source, you must implement the
load() method. This method should return a Hugging Face Dataset
object.
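The attribute and method requirements above can be sketched as follows. So that the example runs without ArcticTraining installed, `DataSourceConfig` and `DataSource` here are minimal hypothetical stand-ins. The `load(config, split)` signature is an assumption, so check the real base class before copying it. `InlineChatSource` and its config are also hypothetical, and the sketch returns a plain list where a real source would return a Hugging Face Dataset.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List

Row = Dict[str, Any]


# Hypothetical stand-ins for ArcticTraining's classes (assumed load signature).
@dataclass
class DataSourceConfig:
    split: str = "train"


class DataSource(ABC):
    name: str
    config: DataSourceConfig

    def __init__(self, data_factory: Any, config: DataSourceConfig) -> None:
        self.data_factory = data_factory
        self.config = config

    @abstractmethod
    def load(self, config: DataSourceConfig, split: str) -> List[Row]:
        """Return the raw rows for `split`."""


# A custom source: set `name`, narrow the `config` type hint, implement load().
@dataclass
class InlineChatConfig(DataSourceConfig):
    num_examples: int = 2


class InlineChatSource(DataSource):
    name = "inline_chat"  # hypothetical name a recipe would refer to
    config: InlineChatConfig

    def load(self, config: InlineChatConfig, split: str) -> List[Row]:
        # A real source would return a Hugging Face Dataset, for example
        # datasets.Dataset.from_list(rows); a list keeps this dependency-free.
        return [
            {
                "messages": [
                    {"role": "user", "content": f"{split} question {i}"},
                    {"role": "assistant", "content": f"{split} answer {i}"},
                ]
            }
            for i in range(config.num_examples)
        ]


source = InlineChatSource(data_factory=None, config=InlineChatConfig(num_examples=3))
rows = source.load(source.config, source.config.split)
```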
Data Factory
The Data Factory is responsible for creating the training and evaluation datasets used in the training pipeline.
- class arctic_training.data.factory.DataFactory(trainer, config=None)
Bases: ABC, CallbackMixin
Base DataFactory class for loading training and evaluation data.
- Parameters:
trainer (Trainer)
config (DataConfig)
- name: str
Name of the DataFactory. This name should be unique to each registered DataFactory object. This name can be used in the training recipe YAMLs to specify the DataFactory to use.
- default_source_cls: Optional[Type] = None
- config: DataConfig
The type of the DataConfig object that this DataFactory uses. Any DataFactory-specific options should be specified in this class.
- property tokenizer: transformers.PreTrainedTokenizerBase
The tokenizer object used by the Trainer.
- property micro_batch_size: int
The micro batch size used by the Trainer.
- property global_rank: int
The global rank of the current process.
- property local_rank: int
The local rank of the current process.
- property world_size: int
The total number of processes in the world.
- property is_main_process_by_path: bool
- cache_path(sources)
Returns the cache path for the processed and concatenated dataset.
- Parameters:
sources (List[DataSource])
- Return type:
Path
- load(data_sources)
Loads data from one or more data sources and concatenates them into a single dataset.
- Parameters:
data_sources (List[DataSource])
- Return type:
Union[Dataset, IterableDataset]
- process(dataset)
Process the dataset (e.g., tokenization for text data).
- Parameters:
dataset (datasets.Dataset | datasets.IterableDataset)
- Return type:
Union[Dataset, IterableDataset]
Attributes
To define a custom data factory, you must subclass the DataFactory, define the
name attribute, and give a type hint for the
config attribute.
Properties
The Data Factory class provides several properties that can be used to access
information about the state of the Trainer, the tokenizer, and the distributed
environment at runtime. These include trainer,
tokenizer, micro_batch_size,
global_rank, local_rank, and world_size.
Methods
To define a custom data factory, you must implement the
process() method. Additionally, you can override the
load(), split_data(), and
create_dataloader() methods to change default behaviors.
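The contract between load() and process() can be sketched in a few lines. In this sketch, `DataFactory` is a hypothetical stand-in rather than the real class, data sources are modeled as zero-argument callables that return lists of rows, and `CharTokenizerFactory` uses a toy one-id-per-character "tokenizer" instead of a real tokenizer. The sketch only shows the shape of the pattern: the default load() concatenates the sources, and the subclass supplies process().

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List

Row = Dict[str, Any]


class DataFactory(ABC):
    """Hypothetical stand-in for arctic_training's DataFactory."""

    name: str

    def load(self, data_sources: List[Callable[[], List[Row]]]) -> List[Row]:
        # Concatenate the rows from every source into one dataset.
        rows: List[Row] = []
        for source in data_sources:
            rows.extend(source())
        return rows

    @abstractmethod
    def process(self, dataset: List[Row]) -> List[Row]:
        """Turn raw rows into model inputs."""


class CharTokenizerFactory(DataFactory):
    name = "char_tokenizer"  # hypothetical registered name

    def process(self, dataset: List[Row]) -> List[Row]:
        # Toy "tokenization": one id per character (its Unicode code point).
        return [{"input_ids": [ord(c) for c in row["text"]]} for row in dataset]


factory = CharTokenizerFactory()
dataset = factory.process(factory.load([lambda: [{"text": "hi"}], lambda: [{"text": "ok"}]]))
print(dataset)  # [{'input_ids': [104, 105]}, {'input_ids': [111, 107]}]
```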
SFTDataFactory
To help get started with creating custom trainers and data factories, ArcticTraining includes a Supervised Fine-Tuning (SFT) trainer (described in Trainer). Here we also show how to build a data factory from the base building blocks for use with the SFTTrainer. The SFTDataFactory can be used with the SFTTrainer or your own custom trainer. It can also be extended to fit other use cases.
To create the SFTDataFactory, we subclass the DataFactory and first define the
process() method to tokenize the loaded datasets:
```python
# DatasetType denotes datasets.Dataset | datasets.IterableDataset
def process(self, dataset: DatasetType) -> DatasetType:
    if "messages" not in dataset.column_names:
        raise ValueError("Dataset must have 'messages' column to tokenize for SFTDataFactory.")
    dataset = dataset.select_columns(["messages"])
    # SFT-style tokenization; each entry in "messages" is assumed to have the form
    # {'role': '...', 'content': '...'}
    return dataset.map(
        lambda ex: {
            **self.tokenize_messages(
                ex["messages"],
                self.tokenizer,
                mask_inputs=self.config.mask_inputs,
                ignore_empty_think=self.config.ignore_empty_think,
            )
        },
        remove_columns=dataset.column_names,
        num_proc=self.config.num_proc,
        desc="Tokenizing messages",
    )
```
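The `mask_inputs` option passed to `tokenize_messages` suggests that non-assistant tokens are excluded from the loss. The sketch below assumes the common SFT convention. Labels equal the input ids for assistant turns and are set to -100 elsewhere, which is the default `ignore_index` of PyTorch's cross-entropy loss. `toy_tokenize_messages` and its `encode` callable are hypothetical. A real implementation applies the tokenizer's chat template, which also inserts role markers, and handles options such as `ignore_empty_think` that are not modeled here.

```python
from typing import Callable, Dict, List

IGNORE_INDEX = -100  # PyTorch CrossEntropyLoss ignores this label by default


def toy_tokenize_messages(
    messages: List[Dict[str, str]],
    encode: Callable[[str], List[int]],
    mask_inputs: bool = True,
) -> Dict[str, List[int]]:
    """Hypothetical simplification of SFT tokenization with input masking."""
    input_ids: List[int] = []
    labels: List[int] = []
    for message in messages:
        ids = encode(message["content"])
        input_ids.extend(ids)
        if mask_inputs and message["role"] != "assistant":
            # Prompt tokens are seen by the model but contribute no loss.
            labels.extend([IGNORE_INDEX] * len(ids))
        else:
            labels.extend(ids)
    return {"input_ids": input_ids, "labels": labels}


encode = lambda s: [ord(c) for c in s]  # toy character-level "tokenizer"
msgs = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "ok"}]
masked = toy_tokenize_messages(msgs, encode)
print(masked)  # {'input_ids': [104, 105, 111, 107], 'labels': [-100, -100, 111, 107]}
```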
Next we override the create_dataloader() method to add a custom data collator:
```python
def create_dataloader(self, dataset: DatasetType) -> DataLoader:
    dataloader = super().create_dataloader(dataset)
    # Replace the default collate function with the SFT-specific collator
    dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config)
    return dataloader
```
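A collator turns a list of variable-length samples into one rectangular batch. The sketch below shows the basic idea under stated assumptions. It right-pads `input_ids` with a pad id, pads `labels` with -100 so that padding adds no loss, and builds an `attention_mask`. `toy_causal_lm_collate` is hypothetical and is not the real `DataCollatorForCausalLM`, which returns tensors and may handle packed samples differently. In PyTorch, a function like this is what `torch.utils.data.DataLoader` calls through its `collate_fn`.

```python
from typing import Dict, List

IGNORE_INDEX = -100


def toy_causal_lm_collate(
    batch: List[Dict[str, List[int]]], pad_token_id: int
) -> Dict[str, List[List[int]]]:
    """Hypothetical right-padding collator (returns lists, not tensors)."""
    max_len = max(len(ex["input_ids"]) for ex in batch)
    out: Dict[str, List[List[int]]] = {"input_ids": [], "labels": [], "attention_mask": []}
    for ex in batch:
        pad = max_len - len(ex["input_ids"])
        out["input_ids"].append(ex["input_ids"] + [pad_token_id] * pad)
        # Padded positions are ignored by the loss...
        out["labels"].append(ex["labels"] + [IGNORE_INDEX] * pad)
        # ...and masked out of attention.
        out["attention_mask"].append([1] * len(ex["input_ids"]) + [0] * pad)
    return out


batch = [
    {"input_ids": [1, 2, 3], "labels": [-100, 2, 3]},
    {"input_ids": [4], "labels": [4]},
]
collated = toy_causal_lm_collate(batch, pad_token_id=0)
```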
Finally, we define two post-load callbacks that filter the dataset by a maximum desired length and then pack the remaining samples:
```python
def filter_dataset_length(self, dataset: DatasetType) -> DatasetType:
    if not self.config.filter_samples:
        return dataset
    # Drop samples longer than the configured maximum length
    dataset = dataset.filter(
        lambda x: len(x["input_ids"]) <= self.config.max_length,
        num_proc=self.config.num_proc,
        desc="Filtering dataset by max length",
    )
    if len(dataset) < 1:
        raise ValueError(
            f"No data left after filtering by max length {self.config.max_length} in"
            f" {self.__class__.__name__}. Consider increasing the `max_length`."
        )
    return dataset


def pack_dataset(self, dataset: DatasetType) -> DatasetType:
    if not self.config.pack_samples:
        return dataset
    if self.config.repeat_to_pack_max_length:
        dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc)
    batch_size = len(dataset) // self.config.num_proc + 1
    # for huge datasets keep the bs to a sane size to avoid cpu-oom
    batch_size = int(min(batch_size, 1e3))
    dataset = dataset.shuffle(seed=self.config.seed)
    dataset = dataset.map(
        lambda x: pack_sft_batch(
            x,
            max_length=self.config.max_length,
            always_max_length=self.config.always_max_length,
            drop_last=self.config.drop_last,
            fuse_positions_prob=self.config.fuse_positions_prob,
            seed=self.config.seed,
        ),
        batched=True,
        batch_size=batch_size,
        num_proc=self.config.num_proc,
        desc="Packing dataset",
    )
    if len(dataset) < 1:
        raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}")
    return dataset
```
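The internals of `pack_sft_batch` are not shown on this page. The sketch below illustrates the general idea of sample packing with a simple greedy strategy. Consecutive samples are concatenated until the next one would exceed `max_length`, and `position_ids` restart at 0 for each sample so that sample boundaries stay recoverable, which is a common convention for packed SFT data. `toy_greedy_pack` is hypothetical. It assumes samples were already filtered to at most `max_length` tokens, and it interprets `drop_last` as discarding the final partial pack. The real function also supports options such as `always_max_length` and `fuse_positions_prob` that are not modeled here.

```python
from typing import Dict, List

Sample = Dict[str, List[int]]


def toy_greedy_pack(samples: List[Sample], max_length: int, drop_last: bool = False) -> List[Sample]:
    """Hypothetical greedy packer; assumes every sample has <= max_length tokens."""
    packed: List[Sample] = []
    current: Sample = {"input_ids": [], "labels": [], "position_ids": []}
    for sample in samples:
        n = len(sample["input_ids"])
        # Close the current pack if this sample would overflow it.
        if current["input_ids"] and len(current["input_ids"]) + n > max_length:
            packed.append(current)
            current = {"input_ids": [], "labels": [], "position_ids": []}
        current["input_ids"].extend(sample["input_ids"])
        current["labels"].extend(sample["labels"])
        # Positions restart per sample, marking where each sample begins.
        current["position_ids"].extend(range(n))
    if current["input_ids"] and not drop_last:
        packed.append(current)
    return packed


samples = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "labels": [4, 5]},
    {"input_ids": [6, 7], "labels": [6, 7]},
]
packs = toy_greedy_pack(samples, max_length=5)
```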
These callback functions are attached to SFTDataFactory through its callbacks
attribute, and they run on the concatenated dataset returned by the
load() method.
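```python
from arctic_training.data.factory import DataFactory


class SFTDataFactory(DataFactory):
    name = "sft"
    config: SFTDataConfig  # ArcticTraining's SFT config class (import omitted here)
    callbacks = [
        ("post-load", filter_dataset_length),
        ("post-load", pack_dataset),
    ]
    # process() and create_dataloader() shown above are also defined on this class
```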
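The page does not show how CallbackMixin dispatches these entries. The sketch below assumes a simple model in which each `(event, function)` pair registered for an event is applied in list order, and each function receives the factory instance plus the previous function's output. This model explains why the module-level callbacks take `self` and why filtering is listed before packing. `ToyCallbackMixin`, `run_callbacks`, and the toy callbacks are hypothetical and do not reproduce ArcticTraining's internals.

```python
from typing import Any, Callable, List, Tuple

Callback = Tuple[str, Callable[[Any, Any], Any]]


class ToyCallbackMixin:
    """Hypothetical stand-in for CallbackMixin's dispatch logic."""

    callbacks: List[Callback] = []

    def run_callbacks(self, event: str, value: Any) -> Any:
        # Apply each callback registered for `event` in declaration order,
        # feeding one callback's output into the next.
        for name, fn in self.callbacks:
            if name == event:
                value = fn(self, value)
        return value


def drop_long(self: Any, rows: List[str]) -> List[str]:
    return [r for r in rows if len(r) <= self.max_length]


def add_marker(self: Any, rows: List[str]) -> List[str]:
    return rows + ["<end>"]


class ToyFactory(ToyCallbackMixin):
    max_length = 3
    callbacks = [("post-load", drop_long), ("post-load", add_marker)]


result = ToyFactory().run_callbacks("post-load", ["ab", "abcd"])
print(result)  # ['ab', '<end>']
```

Under this model, order matters. If `add_marker` ran first, `drop_long` would remove the five-character `"<end>"` marker as well.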