diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 82bbfafb2..a87dfdb7e 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
 from ..extras import logging
-from .data_utils import Role
+from .data_utils import Role, merge_messages
 
 
 if TYPE_CHECKING:
@@ -152,7 +152,7 @@ def convert_sharegpt(
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
     accept_tags = (odd_tags, even_tags)
-    messages = example[dataset_attr.messages]
+    messages = merge_messages(example[dataset_attr.messages], dataset_attr)
     if (
         dataset_attr.system_tag
         and len(messages) != 0
diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index cbce026c8..229d2345f 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -17,6 +17,9 @@
 
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
 
 from ..extras import logging
 
 
+if TYPE_CHECKING:
+    from .parser import DatasetAttr
+
@@ -90,3 +93,43 @@ def split_dataset(
     val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
     dataset = dataset.train_test_split(test_size=val_size, seed=seed)
     return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
+
+
+def merge_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
+    r"""
+    Merges consecutive observation (tool result) messages into one message.
+
+    The merged message keeps ``dataset_attr.observation_tag`` as its role and
+    collects the individual contents into a list; messages of any other role
+    pass through unchanged and in their original order.
+
+    Raises:
+        ValueError: if two or more consecutive messages share a non-observation role.
+    """
+    new_messages = []
+    waiting_message = []  # pending run of consecutive same-role messages
+
+    def append_waiting_message() -> None:
+        # A lone message passes through as-is; a run of two or more must be
+        # observation messages, which get merged into one multi-content message.
+        if len(waiting_message) == 1:
+            new_messages.append(waiting_message[0])
+        else:
+            if waiting_message[0]["role"] != dataset_attr.observation_tag:
+                raise ValueError("Only observation messages can appear consecutively.")
+
+            new_messages.append(
+                {"role": dataset_attr.observation_tag, "content": [m["content"] for m in waiting_message]}
+            )
+
+    for message in messages:
+        if waiting_message and message["role"] != waiting_message[-1]["role"]:
+            append_waiting_message()  # role changed: flush the finished run
+            waiting_message = []
+
+        waiting_message.append(message)
+
+    if waiting_message:  # flush the trailing run
+        append_waiting_message()
+
+    return new_messages
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 2127f9577..bd7642dd4 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -104,7 +104,6 @@ def _encode(
         """
         system = system or self.default_system
         encoded_messages = []
-        messages = self.merge_messages(messages)
 
         for i, message in enumerate(messages):
             elements = []
@@ -157,32 +156,6 @@ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "
         return token_ids
 
-    @staticmethod
-    def merge_messages(messages: Sequence[Dict[str, str]]) -> List[Dict[str, str]]:
-        # merge obversation messages
-        new_messages = []
-        waiting_message = []
-
-        def append_waiting_message():
-            if len(waiting_message) == 1:
-                new_messages.append(waiting_message[0])
-            else:
-                assert waiting_message[0]["role"] == Role.OBSERVATION.value
-                new_messages.append(
-                    {"role": Role.OBSERVATION.value, "content": [m["content"] for m in waiting_message]}
-                )
-        for message in messages:
-            if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]:
-                append_waiting_message()
-                waiting_message = []
-            waiting_message.append(message)
-
-        if len(waiting_message) > 0:
-            append_waiting_message()
-
-        return new_messages
-
-
 
 @dataclass
 class Llama2Template(Template):
     @override
@@ -200,7 +173,6 @@ def _encode(
         """
         system = system or self.default_system
         encoded_messages = []
-        messages = self.merge_messages(messages)
 
         for i, message in enumerate(messages):
             elements = []