diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index a87dfdb7e..c0d84575c 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
 from ..extras import logging
-from .data_utils import Role, merge_messages
+from .data_utils import Role, continuous_messages
 
 
 if TYPE_CHECKING:
@@ -152,13 +152,16 @@
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
     accept_tags = (odd_tags, even_tags)
-    messages = merge_messages(example[dataset_attr.messages], dataset_attr)
+    messages = continuous_messages(example[dataset_attr.messages], dataset_attr)
     if (
         dataset_attr.system_tag
         and len(messages) != 0
         and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
     ):
-        system = messages[0][dataset_attr.content_tag]
+        if isinstance(messages[0][dataset_attr.content_tag], list):
+            system = "".join(messages[0][dataset_attr.content_tag])
+        else:
+            system = messages[0][dataset_attr.content_tag]
         messages = messages[1:]
     else:
         system = example[dataset_attr.system] if dataset_attr.system else ""
diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index 229d2345f..61e2a4192 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -94,22 +94,30 @@
     return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
 
 
-def merge_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
+def continuous_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
     # merge obversation messages
     new_messages = []
     waiting_message = []
 
     def append_waiting_message():
         if len(waiting_message) == 1:
-            new_messages.append(waiting_message[0])
+            new_messages.append(
+                {
+                    dataset_attr.role_tag: waiting_message[0][dataset_attr.role_tag],
+                    dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
+                }
+            )
         else:
-            assert waiting_message[0]["role"] == dataset_attr.observation_tag
+            assert waiting_message[0][dataset_attr.role_tag] == dataset_attr.observation_tag
             new_messages.append(
-                {"role": dataset_attr.observation_tag, "content": [m["content"] for m in waiting_message]}
+                {
+                    dataset_attr.role_tag: dataset_attr.observation_tag,
+                    dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
+                }
             )
 
     for message in messages:
-        if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]:
+        if len(waiting_message) > 0 and message[dataset_attr.role_tag] != waiting_message[-1][dataset_attr.role_tag]:
             append_waiting_message()
             waiting_message = []
         waiting_message.append(message)
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index bd7642dd4..d9aa5de71 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -117,9 +117,11 @@
                 elements += self.format_separator.apply()
 
             if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_assistant.apply(content=content)
             elif message["role"] == Role.OBSERVATION.value:
                 if isinstance(message["content"], list):
                     for m in message["content"]:
@@ -127,7 +129,8 @@
                 else:
                     elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
 
@@ -156,6 +159,7 @@
         return token_ids
 
 
+@dataclass
 class Llama2Template(Template):
     @override
     def _encode(
@@ -187,9 +191,11 @@
                 elements += self.format_separator.apply()
 
             if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_user.apply(content=system_text + content)
             elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_assistant.apply(content=content)
             elif message["role"] == Role.OBSERVATION.value:
                 if isinstance(message["content"], list):
                     for m in message["content"]:
@@ -197,7 +203,8 @@
                 else:
                     elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))