diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 362d57e9f8..03ec2928aa 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -143,21 +143,27 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    merge: bool = True,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
     r"""
-    Gets the merged datasets in the standard format.
+    Gets the merged datasets in the standard format, or a dict of named datasets if `merge` is False.
     """
     if dataset_names is None:
         return None
 
-    datasets = []
+    datasets = {}
     for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")
 
-        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
-
-    return merge_dataset(datasets, data_args, seed=training_args.seed)
+        datasets[f"{dataset_attr.dataset_name}_{dataset_attr.subset}"] = _load_single_dataset(
+            dataset_attr, model_args, data_args, training_args
+        )
+
+    if merge:
+        return merge_dataset(list(datasets.values()), data_args, seed=training_args.seed)
+    else:
+        return datasets
 
 
 def _get_preprocessed_dataset(
@@ -246,15 +252,24 @@ def get_dataset(
     # Load and preprocess dataset
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+        # in streaming mode the eval datasets are merged into one; otherwise keep a dict of named eval datasets
+        eval_dataset = _get_merged_dataset(
+            data_args.eval_dataset, model_args, data_args, training_args, stage, merge=data_args.streaming
+        )
 
     with training_args.main_process_first(desc="pre-process dataset"):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
-        eval_dataset = _get_preprocessed_dataset(
-            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-        )
+        if isinstance(eval_dataset, dict):
+            for eval_name, eval_data in eval_dataset.items():
+                eval_dataset[eval_name] = _get_preprocessed_dataset(
+                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+                )
+        else:
+            eval_dataset = _get_preprocessed_dataset(
+                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
 
         if data_args.val_size > 1e-6:
             dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
@@ -269,7 +284,11 @@ def get_dataset(
             if eval_dataset is not None:
                 if data_args.streaming:
                     eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
-
-                dataset_dict["validation"] = eval_dataset
+                    dataset_dict["validation"] = eval_dataset
+                elif isinstance(eval_dataset, dict):
+                    for eval_name, eval_data in eval_dataset.items():
+                        dataset_dict[f"validation_{eval_name}"] = eval_data
+                else:
+                    dataset_dict["validation"] = eval_dataset
 
             dataset_dict = DatasetDict(dataset_dict)
@@ -285,8 +304,16 @@ def get_dataset(
         dataset_module = {}
         if "train" in dataset_dict:
             dataset_module["train_dataset"] = dataset_dict["train"]
 
         if "validation" in dataset_dict:
             dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+        # collect the per-dataset validation splits so the trainer can evaluate each of them separately
+        eval_datasets_map = {}
+        for key, value in dataset_dict.items():
+            if key.startswith("validation_"):
+                eval_datasets_map[key] = value
+        if eval_datasets_map:
+            dataset_module["eval_dataset"] = eval_datasets_map
 
         return dataset_module
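
For reviewers, a minimal self-contained sketch (not part of the patch) of the shape this change produces when several eval datasets are configured without streaming. The dataset names and toy contents below are hypothetical; the point is that `dataset_module["eval_dataset"]` becomes a dict keyed by `validation_*`, which recent versions of Hugging Face's `Trainer` accept and evaluate per entry, prefixing metrics with the dict key (e.g. `eval_validation_alpaca_None_loss`).

# Illustrative sketch only -- dataset names and contents are made up.
from datasets import Dataset, DatasetDict

dataset_dict = DatasetDict({
    "train": Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]}),
    "validation_alpaca_None": Dataset.from_dict({"input_ids": [[5, 6]]}),
    "validation_identity_None": Dataset.from_dict({"input_ids": [[7, 8]]}),
})

dataset_module = {}
if "train" in dataset_dict:
    dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
    dataset_module["eval_dataset"] = dataset_dict["validation"]

# mirror the patched collection logic: gather every "validation_*" split into one dict
eval_datasets_map = {key: value for key, value in dataset_dict.items() if key.startswith("validation_")}
if eval_datasets_map:
    dataset_module["eval_dataset"] = eval_datasets_map

print(sorted(dataset_module["eval_dataset"]))
# ['validation_alpaca_None', 'validation_identity_None']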