From 3aa6664cb493e2ebe9c14cd9ae097ae61c2fcc0a Mon Sep 17 00:00:00 2001
From: Feng Yunlong <20281571+AlongWY@users.noreply.github.com>
Date: Wed, 25 Dec 2024 06:47:44 +0000
Subject: [PATCH 1/8] support continuous observation and optional pre-cutoff

---
 .../data/processors/supervised.py     | 15 ++++---
 src/llamafactory/data/template.py     | 39 ++++++++++++++++++-
 src/llamafactory/hparams/data_args.py |  4 ++
 3 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 83bd8ba2a7..2c1df29ce5 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -53,13 +53,16 @@ def _encode_supervised_example(
         encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
+        if total_length >= cutoff_len and cutoff_len > 0:
             break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
-        source_ids = source_ids[:source_len]
-        target_ids = target_ids[:target_len]
-        total_length += source_len + target_len
+        if cutoff_len > 0:
+            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
+            source_ids = source_ids[:source_len]
+            target_ids = target_ids[:target_len]
+            total_length += source_len + target_len
+        else:
+            source_len, target_len = len(source_ids), len(target_ids)

         if train_on_prompt:
             source_label = source_ids
@@ -158,7 +161,7 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - 1 if data_args.allow_truncation else 0,  # reserved for the padding token
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 4c3b70bfea..2127f9577b 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -104,6 +104,7 @@ def _encode(
         """
         system = system or self.default_system
         encoded_messages = []
+        messages = self.merge_messages(messages)
        for i, message in enumerate(messages):
             elements = []

@@ -121,7 +122,11 @@ def _encode(
             elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
+                if isinstance(message["content"], list):
+                    for m in message["content"]:
+                        elements += self.format_observation.apply(content=m)
+                else:
+                    elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
@@ -152,6 +157,31 @@ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "

         return token_ids

+    @staticmethod
+    def merge_messages(messages: Sequence[Dict[str, str]]) -> List[Dict[str, str]]:
+        # merge observation messages
+        new_messages = []
+        waiting_message = []
+
+        def append_waiting_message():
+            if len(waiting_message) == 1:
+                new_messages.append(waiting_message[0])
+            else:
+                assert waiting_message[0]["role"] == Role.OBSERVATION.value
+                new_messages.append(
+                    {"role": Role.OBSERVATION.value, "content": [m["content"] for m in waiting_message]}
+                )
+
+        for message in messages:
+            if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]:
+                append_waiting_message()
+                waiting_message = []
+            waiting_message.append(message)
+
+        if len(waiting_message) > 0:
+            append_waiting_message()
+
+        return new_messages
+

 @dataclass
 class Llama2Template(Template):
@@ -170,6 +200,7 @@ def _encode(
         """
         system = system or self.default_system
         encoded_messages = []
+        messages = self.merge_messages(messages)
         for i, message in enumerate(messages):
             elements = []

@@ -188,7 +219,11 @@ def _encode(
             elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
             elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
+                if isinstance(message["content"], list):
+                    for m in message["content"]:
+                        elements += self.format_observation.apply(content=m)
+                else:
+                    elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index a33e626773..22479d76ce 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -109,6 +109,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
+    allow_truncation: bool = field(
+        default=False,
+        metadata={"help": "Allow truncation when processing supervised examples."},
+    )
     tool_format: Optional[str] = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
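Two behaviors change in patch 1: an observation turn whose content is a list is rendered as several consecutive format_observation blocks, and cutoff_len == 0 now means "do not pre-truncate", which is what preprocess_packed_supervised_dataset passes when allow_truncation is left False. Below is a minimal standalone sketch of the merging rule that Template.merge_messages introduces; the helper name and the plain string standing in for Role.OBSERVATION.value are illustrative, not part of the codebase:

    from typing import Dict, List, Sequence

    OBSERVATION = "observation"  # stands in for Role.OBSERVATION.value

    def merge_observations(messages: Sequence[Dict[str, object]]) -> List[Dict[str, object]]:
        """Collapse runs of consecutive same-role turns; only observation runs
        may repeat, and each run becomes one turn with list-valued content."""
        merged: List[Dict[str, object]] = []
        run: List[Dict[str, object]] = []

        def flush() -> None:
            if len(run) == 1:
                merged.append(run[0])
            else:
                assert run[0]["role"] == OBSERVATION  # only tool results may repeat
                merged.append({"role": OBSERVATION, "content": [m["content"] for m in run]})

        for message in messages:
            if run and message["role"] != run[-1]["role"]:
                flush()
                run.clear()
            run.append(message)

        if run:
            flush()
        return merged

    # two tool results in a row collapse into one list-valued observation turn
    msgs = [
        {"role": "user", "content": "check the weather and the time"},
        {"role": "observation", "content": "weather: sunny"},
        {"role": "observation", "content": "time: 09:00"},
    ]
    assert merge_observations(msgs)[1]["content"] == ["weather: sunny", "time: 09:00"]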
From e1b957d4071f77538f0f8b835e73858939c49620 Mon Sep 17 00:00:00 2001
From: Feng Yunlong <20281571+AlongWY@users.noreply.github.com>
Date: Wed, 25 Dec 2024 11:14:48 +0000
Subject: [PATCH 2/8] move merge_messages to aligner

---
 src/llamafactory/data/aligner.py    |  4 ++--
 src/llamafactory/data/data_utils.py | 28 ++++++++++++++++++++++++++++
 src/llamafactory/data/template.py   | 28 ----------------------------
 3 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 82bbfafb2d..a87dfdb7e7 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

 from ..extras import logging
-from .data_utils import Role
+from .data_utils import Role, merge_messages


 if TYPE_CHECKING:
@@ -152,7 +152,7 @@ def convert_sharegpt(
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
     accept_tags = (odd_tags, even_tags)
-    messages = example[dataset_attr.messages]
+    messages = merge_messages(example[dataset_attr.messages], dataset_attr)
     if (
         dataset_attr.system_tag
         and len(messages) != 0
diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index cbce026c8d..229d2345fa 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -17,6 +17,8 @@

 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

+from .parser import DatasetAttr
+
 from ..extras import logging


@@ -90,3 +92,29 @@ def split_dataset(
         val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
         dataset = dataset.train_test_split(test_size=val_size, seed=seed)
         return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
+
+
+def merge_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
+    # merge observation messages
+    new_messages = []
+    waiting_message = []
+
+    def append_waiting_message():
+        if len(waiting_message) == 1:
+            new_messages.append(waiting_message[0])
+        else:
+            assert waiting_message[0]["role"] == dataset_attr.observation_tag
+            new_messages.append(
+                {"role": dataset_attr.observation_tag, "content": [m["content"] for m in waiting_message]}
+            )
+
+    for message in messages:
+        if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]:
+            append_waiting_message()
+            waiting_message = []
+        waiting_message.append(message)
+
+    if len(waiting_message) > 0:
+        append_waiting_message()
+
+    return new_messages
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 2127f9577b..bd7642dd4c 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -104,7 +104,6 @@ def _encode(
         """
         system = system or self.default_system
         encoded_messages = []
-        messages = self.merge_messages(messages)
         for i, message in enumerate(messages):
             elements = []

@@ -157,32 +156,6 @@ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "

         return token_ids

-    @staticmethod
-    def merge_messages(messages: Sequence[Dict[str, str]]) -> List[Dict[str, str]]:
-        # merge observation messages
-        new_messages = []
-        waiting_message = []
-
-        def append_waiting_message():
-            if len(waiting_message) == 1:
-                new_messages.append(waiting_message[0])
-            else:
-                assert waiting_message[0]["role"] == Role.OBSERVATION.value
-                new_messages.append(
-                    {"role": Role.OBSERVATION.value, "content": [m["content"] for m in waiting_message]}
-                )
-
-        for message in messages:
-            if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]:
-                append_waiting_message()
-                waiting_message = []
-            waiting_message.append(message)
-
-        if len(waiting_message) > 0:
-            append_waiting_message()
-
-        return new_messages
-

 @dataclass
 class Llama2Template(Template):
@@ -200,7 +173,6 @@ def _encode(
         """
         system = system or self.default_system
         encoded_messages = []
-        messages = self.merge_messages(messages)
         for i, message in enumerate(messages):
             elements = []
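Note that merge_messages still indexes turns with the literal "role" and "content" keys even though it now receives dataset_attr; ShareGPT-style records key their turns with the dataset's own tags (commonly "from" and "value"), so the lookup can fail or silently skip merges. Patch 3 below switches to dataset_attr.role_tag and dataset_attr.content_tag. A small sketch of the mismatch, with SimpleNamespace standing in for the real DatasetAttr from data/parser.py and tag values that are typical rather than taken from this diff:

    from types import SimpleNamespace

    # only the attributes the merge logic reads; hypothetical values
    attr = SimpleNamespace(role_tag="from", content_tag="value", observation_tag="observation")

    turn = {"from": "observation", "value": "tool output"}
    assert turn[attr.role_tag] == attr.observation_tag  # dynamic keys resolve
    assert "role" not in turn  # a literal turn["role"] lookup would raise KeyError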
"validation": dataset["test"]}) + + +def merge_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]: + # merge obversation messages + new_messages = [] + waiting_message = [] + + def append_waiting_message(): + if len(waiting_message) == 1: + new_messages.append(waiting_message[0]) + else: + assert waiting_message[0]["role"] == dataset_attr.observation_tag + new_messages.append( + {"role": dataset_attr.observation_tag, "content": [m["content"] for m in waiting_message]} + ) + + for message in messages: + if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]: + append_waiting_message() + waiting_message = [] + waiting_message.append(message) + + if len(waiting_message) > 0: + append_waiting_message() + + return new_messages diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 2127f9577b..bd7642dd4c 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -104,7 +104,6 @@ def _encode( """ system = system or self.default_system encoded_messages = [] - messages = self.merge_messages(messages) for i, message in enumerate(messages): elements = [] @@ -157,32 +156,6 @@ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: " return token_ids - @staticmethod - def merge_messages(messages: Sequence[Dict[str, str]]) -> List[Dict[str, str]]: - # merge obversation messages - new_messages = [] - waiting_message = [] - - def append_waiting_message(): - if len(waiting_message) == 1: - new_messages.append(waiting_message[0]) - else: - assert waiting_message[0]["role"] == Role.OBSERVATION.value - new_messages.append( - {"role": Role.OBSERVATION.value, "content": [m["content"] for m in waiting_message]} - ) - for message in messages: - if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]: - append_waiting_message() - waiting_message = [] - waiting_message.append(message) - - if len(waiting_message) > 0: - append_waiting_message() - - return new_messages - - @dataclass class Llama2Template(Template): @override @@ -200,7 +173,6 @@ def _encode( """ system = system or self.default_system encoded_messages = [] - messages = self.merge_messages(messages) for i, message in enumerate(messages): elements = [] From 5050fad50ac9ae2269f4d46badc25aac43fac7ba Mon Sep 17 00:00:00 2001 From: Feng Yunlong <20281571+AlongWY@users.noreply.github.com> Date: Thu, 26 Dec 2024 06:37:38 +0000 Subject: [PATCH 3/8] fix processing error --- src/llamafactory/data/aligner.py | 9 ++++++--- src/llamafactory/data/data_utils.py | 18 +++++++++++++----- src/llamafactory/data/template.py | 19 +++++++++++++------ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index a87dfdb7e7..c0d84575c5 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from ..extras import logging -from .data_utils import Role, merge_messages +from .data_utils import Role, continuous_messages if TYPE_CHECKING: @@ -152,13 +152,16 @@ def convert_sharegpt( odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag) even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag) accept_tags = (odd_tags, even_tags) - messages = merge_messages(example[dataset_attr.messages], dataset_attr) + messages = continuous_messages(example[dataset_attr.messages], dataset_attr) if ( 
From 95bb98634b9fa3c9b880e2ec8c607502823a5a9f Mon Sep 17 00:00:00 2001
From: ylfeng
Date: Fri, 27 Dec 2024 00:43:59 +0800
Subject: [PATCH 4/8] fix broken tests

---
 src/llamafactory/data/data_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index 61e2a41929..1f9c14df01 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -125,4 +125,6 @@ def append_waiting_message():
     if len(waiting_message) > 0:
         append_waiting_message()

+    if len(new_messages) == len(messages):
+        return messages
     return new_messages
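The early return matters because patch 3 wraps every content in a list, including single-element ones, which changed the processed output for ordinary alternating conversations and broke the existing tests. A sketch of the restored behavior (hypothetical messages; maybe_keep_original is a name invented for illustration):

    def maybe_keep_original(messages, new_messages):
        # if merging changed nothing (same turn count), hand back the original
        # messages so contents stay plain strings rather than one-element lists
        return messages if len(new_messages) == len(messages) else new_messages

    orig = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
    merged = [{"role": "user", "content": ["hi"]}, {"role": "assistant", "content": ["hello"]}]
    assert maybe_keep_original(orig, merged) is orig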
From de906952698b5bf7666be470892e88e15ec94baf Mon Sep 17 00:00:00 2001
From: Feng Yunlong <20281571+AlongWY@users.noreply.github.com>
Date: Fri, 27 Dec 2024 14:53:27 +0800
Subject: [PATCH 5/8] fix endless loop

---
 src/llamafactory/data/processors/processor_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py
index 8e13d100bc..a09dd2af64 100644
--- a/src/llamafactory/data/processors/processor_utils.py
+++ b/src/llamafactory/data/processors/processor_utils.py
@@ -31,6 +31,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []

+    # filter out numbers that are larger than the capacity
+    numbers = [number for number in numbers if number <= capacity]
     while numbers:
         current_knapsack = []
         remaining_capacity = capacity

From d050bbf87d3311259a466f3de0bc907a9ca1e5b9 Mon Sep 17 00:00:00 2001
From: Feng Yunlong <20281571+AlongWY@users.noreply.github.com>
Date: Fri, 27 Dec 2024 15:06:30 +0800
Subject: [PATCH 6/8] Update data_utils.py

---
 src/llamafactory/data/data_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index 1f9c14df01..ce401f5c47 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -125,6 +125,7 @@ def append_waiting_message():
     if len(waiting_message) > 0:
         append_waiting_message()

-    if len(new_messages) == len(messages):
-        return messages
+    # all merged messages must have no consecutive same-role turns
+    # if len(new_messages) == len(messages):
+    #     return messages
     return new_messages

From 713d08879c3f64808118be243f31f95525cdbd73 Mon Sep 17 00:00:00 2001
From: Feng Yunlong <20281571+AlongWY@users.noreply.github.com>
Date: Fri, 27 Dec 2024 15:16:15 +0800
Subject: [PATCH 7/8] avoid endless loop

---
 src/llamafactory/data/processors/processor_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py
index a09dd2af64..f4107a76fa 100644
--- a/src/llamafactory/data/processors/processor_utils.py
+++ b/src/llamafactory/data/processors/processor_utils.py
@@ -45,6 +45,9 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
             remaining_capacity -= numbers[index]  # update the remaining capacity
             current_knapsack.append(numbers.pop(index))  # add the number to knapsack

+        if remaining_capacity == capacity:
+            break
+
         knapsacks.append(current_knapsack)

     return knapsacks
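Patches 5 and 7 both guard greedy_knapsack against spinning forever: patch 5 drops items that can never fit, and patch 7 ends a pass that packed nothing, which previously left numbers unchanged and the outer loop running for good. A self-contained reproduction of the fixed algorithm; this is a simplified re-implementation using a linear scan where the project's helper uses binary search:

    from typing import List

    def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
        numbers = sorted(n for n in numbers if n <= capacity)  # patch 5: drop oversize items
        knapsacks = []
        while numbers:
            current, remaining = [], capacity
            while True:
                # take the largest remaining number that still fits
                fitting = [i for i, n in enumerate(numbers) if n <= remaining]
                if not fitting:
                    break
                remaining -= numbers[fitting[-1]]
                current.append(numbers.pop(fitting[-1]))
            if remaining == capacity:  # patch 7: nothing fit at all, stop instead of looping
                break
            knapsacks.append(current)
        return knapsacks

    # without the patch-5 filter, [10] with capacity 4 never shrinks: nothing
    # fits, the inner loop packs nothing, and the outer loop never terminates
    assert greedy_knapsack([1, 2, 3, 10], capacity=4) == [[3, 1], [2]]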
From 73d96545b8cec7bdae678b951b9a17e4c928a5a3 Mon Sep 17 00:00:00 2001
From: Feng Yunlong <20281571+AlongWY@users.noreply.github.com>
Date: Fri, 27 Dec 2024 07:24:14 +0000
Subject: [PATCH 8/8] update tests

---
 tests/data/processors/test_feedback.py | 1 +
 tests/data/processors/test_pairwise.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/tests/data/processors/test_feedback.py b/tests/data/processors/test_feedback.py
index c04e823b75..b1cecfb988 100644
--- a/tests/data/processors/test_feedback.py
+++ b/tests/data/processors/test_feedback.py
@@ -51,6 +51,7 @@ def test_feedback_data(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         messages = original_data["messages"][index]
+        messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in messages]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
         prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
         ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
diff --git a/tests/data/processors/test_pairwise.py b/tests/data/processors/test_pairwise.py
index da50ca2426..079edcdac8 100644
--- a/tests/data/processors/test_pairwise.py
+++ b/tests/data/processors/test_pairwise.py
@@ -64,6 +64,8 @@ def test_pairwise_data(num_samples: int):
         rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
         chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
         rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
+        chosen_messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in chosen_messages]
+        rejected_messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in rejected_messages]
         ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
         chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
         ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
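Patch 8 adapts the reference path in the tests: convert_sharegpt now yields list-valued contents, while the reference tokenizer's apply_chat_template expects plain strings, so the tests unwrap each one-element list first. A sketch of that unwrapping with hypothetical messages (note it keeps only the first element, which matches the single-observation data these tests use):

    messages = [
        {"role": "user", "content": ["What is 1 + 1?"]},
        {"role": "assistant", "content": ["It is 2."]},
    ]

    # unwrap the one-element content lists before calling apply_chat_template
    flattened = [{"role": msg["role"], "content": msg["content"][0]} for msg in messages]
    assert flattened[0]["content"] == "What is 1 + 1?"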