From 0b042dffa0f0b115b177bf3d5b0c0c35c99d4cd1 Mon Sep 17 00:00:00 2001
From: SrWYG
Date: Tue, 24 Sep 2024 10:06:03 +0800
Subject: [PATCH] [Update] loader.py: evaluate will run separate evaluations
 on each dataset

`If you pass a dictionary with names of datasets as keys and datasets as
values, evaluate will run separate evaluations on each dataset. This can be
useful to monitor how training affects other datasets or simply to get a
more fine-grained evaluation`

Seq2SeqTrainer supports eval_dataset as a Dict.
---
 src/llamafactory/data/loader.py | 46 ++++++++++++++++++++++++++++++++--------------
 1 file changed, 32 insertions(+), 14 deletions(-)

diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 362d57e9f8..03ec2928aa 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -143,21 +143,21 @@ def _get_merged_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+    merge: bool = True
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
-    r"""
-    Gets the merged datasets in the standard format.
-    """
     if dataset_names is None:
         return None
 
-    datasets = []
+    datasets = {}
     for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
         if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
             raise ValueError("The dataset is not applicable in the current training stage.")
 
-        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
-
-    return merge_dataset(datasets, data_args, seed=training_args.seed)
+        datasets[f'{dataset_attr.dataset_name}_{dataset_attr.subset}'] = _load_single_dataset(dataset_attr, model_args, data_args, training_args)
+    if merge:
+        return merge_dataset([data for _, data in datasets.items()], data_args, seed=training_args.seed)
+    else:
+        return datasets
 
 
 def _get_preprocessed_dataset(
@@ -246,15 +246,21 @@ def get_dataset(
     # Load and preprocess dataset
     with training_args.main_process_first(desc="load dataset"):
         dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
-        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
+        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage, merge=data_args.streaming)
 
     with training_args.main_process_first(desc="pre-process dataset"):
         dataset = _get_preprocessed_dataset(
             dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
         )
-        eval_dataset = _get_preprocessed_dataset(
-            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
-        )
+        if isinstance(eval_dataset, dict):
+            for eval_name, eval_data in eval_dataset.items():
+                eval_dataset[eval_name] = _get_preprocessed_dataset(
+                    eval_data, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+                )
+        else:
+            eval_dataset = _get_preprocessed_dataset(
+                eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
+            )
 
         if data_args.val_size > 1e-6:
             dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
@@ -269,7 +275,12 @@ def get_dataset(
             if eval_dataset is not None:
                 if data_args.streaming:
                     eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
-
-                dataset_dict["validation"] = eval_dataset
+                    dataset_dict["validation"] = eval_dataset
+                else:
+                    if isinstance(eval_dataset, dict):
+                        for eval_name, eval_data in eval_dataset.items():
+                            dataset_dict[f"validation_{eval_name}"] = eval_data
+                    else:
+                        dataset_dict["validation"] = eval_dataset
 
             dataset_dict = DatasetDict(dataset_dict)
@@ -285,8 +296,15 @@ def get_dataset(
     dataset_module = {}
     if "train" in dataset_dict:
         dataset_module["train_dataset"] = dataset_dict["train"]
-
+
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]
+
+    eval_datasets_map = {}
+    for key, value in dataset_dict.items():
+        if 'validation_' in key:
+            eval_datasets_map[key] = dataset_dict[key]
+    if len(eval_datasets_map):
+        dataset_module["eval_dataset"] = eval_datasets_map
 
     return dataset_module
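
Usage sketch (not part of the diff above): the behaviour the commit message quotes comes from the base transformers Trainer, which Seq2SeqTrainer inherits. When eval_dataset is a dict, each entry is evaluated on its own and the dict key is prepended to the metric names. The snippet below illustrates, under that assumption, how a dataset_module shaped like the one get_dataset() now returns gets consumed; the toy GPT-2 config, the split names, and the output directory are placeholders, and the per-entry loop in Trainer.evaluate() needs a reasonably recent transformers release.

# Illustrative sketch: evaluating a dict of datasets with the HF Trainer.
# Only the dict-valued eval_dataset shape mirrors what get_dataset() returns
# after this patch; everything else is a stand-in.
from datasets import Dataset
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, default_data_collator

# Tiny randomly initialized model so the sketch runs without downloads.
model = GPT2LMHeadModel(GPT2Config(vocab_size=16, n_positions=8, n_embd=8, n_layer=1, n_head=1))

def toy_split(token_id):
    # Stand-in for a preprocessed (already tokenized) evaluation split.
    return Dataset.from_dict({
        "input_ids": [[token_id] * 4] * 2,
        "attention_mask": [[1] * 4] * 2,
        "labels": [[token_id] * 4] * 2,
    })

# Keys follow the "validation_<dataset_name>_<subset>" pattern used by the patch.
eval_splits = {
    "validation_alpaca_None": toy_split(1),
    "validation_identity_None": toy_split(2),
}

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="eval_sketch", report_to="none"),
    eval_dataset=eval_splits,          # dict -> one evaluation pass per entry
    data_collator=default_data_collator,
)

# Metrics come back prefixed with each dict key, e.g.
# eval_validation_alpaca_None_loss and eval_validation_identity_None_loss.
print(trainer.evaluate())

During training the same dict flows through the Trainer's periodic evaluation, which is what makes it possible to monitor how training affects each evaluation dataset separately.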