diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 82bbfafb2d..c0d84575c5 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

 from ..extras import logging
-from .data_utils import Role
+from .data_utils import Role, continuous_messages


 if TYPE_CHECKING:
@@ -152,13 +152,16 @@ def convert_sharegpt(
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
     accept_tags = (odd_tags, even_tags)
-    messages = example[dataset_attr.messages]
+    messages = continuous_messages(example[dataset_attr.messages], dataset_attr)
     if (
         dataset_attr.system_tag
         and len(messages) != 0
         and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
     ):
-        system = messages[0][dataset_attr.content_tag]
+        if isinstance(messages[0][dataset_attr.content_tag], list):
+            system = "".join(messages[0][dataset_attr.content_tag])
+        else:
+            system = messages[0][dataset_attr.content_tag]
         messages = messages[1:]
     else:
         system = example[dataset_attr.system] if dataset_attr.system else ""
diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index cbce026c8d..ce401f5c47 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -17,6 +17,8 @@

 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

+from .parser import DatasetAttr
+
 from ..extras import logging


@@ -90,3 +92,40 @@ def split_dataset(
     val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
     dataset = dataset.train_test_split(test_size=val_size, seed=seed)
     return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
+
+
+def continuous_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
+    # merge consecutive observation messages into a single message
+    new_messages = []
+    waiting_message = []
+
+    def append_waiting_message():
+        if len(waiting_message) == 1:
+            new_messages.append(
+                {
+                    dataset_attr.role_tag: waiting_message[0][dataset_attr.role_tag],
+                    dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
+                }
+            )
+        else:
+            assert waiting_message[0][dataset_attr.role_tag] == dataset_attr.observation_tag
+            new_messages.append(
+                {
+                    dataset_attr.role_tag: dataset_attr.observation_tag,
+                    dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
+                }
+            )
+
+    for message in messages:
+        if len(waiting_message) > 0 and message[dataset_attr.role_tag] != waiting_message[-1][dataset_attr.role_tag]:
+            append_waiting_message()
+            waiting_message = []
+        waiting_message.append(message)
+
+    if len(waiting_message) > 0:
+        append_waiting_message()
+
+    # if nothing was merged, the roles already alternate
+    # if len(new_messages) == len(messages):
+    #     return messages
+    return new_messages
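
For reviewers: `continuous_messages` collapses each run of consecutive same-role turns (in practice, tool observations) into one message whose content becomes a list of the original strings. A minimal standalone sketch of the idea, using plain `role`/`content` keys instead of the `dataset_attr` tag indirection — the helper name and data below are illustrative, not part of the patch:

```python
from itertools import groupby
from typing import Any, Dict, List, Sequence


def merge_consecutive(messages: Sequence[Dict[str, str]]) -> List[Dict[str, Any]]:
    """Group adjacent same-role messages; each content becomes a list of strings."""
    merged = []
    for role, run in groupby(messages, key=lambda m: m["role"]):
        merged.append({"role": role, "content": [m["content"] for m in run]})
    return merged


msgs = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "observation", "content": "tool result 1"},
    {"role": "observation", "content": "tool result 2"},
    {"role": "assistant", "content": "It is sunny."},
]
# The two observation turns collapse into a single message:
# {"role": "observation", "content": ["tool result 1", "tool result 2"]}
print(merge_consecutive(msgs))
```

Unlike this sketch, the real helper asserts that only observation runs may contain more than one message; note that single messages are also wrapped in one-element lists, which is why the template and test code below must handle list-valued contents.
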
diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py
index 8e13d100bc..f4107a76fa 100644
--- a/src/llamafactory/data/processors/processor_utils.py
+++ b/src/llamafactory/data/processors/processor_utils.py
@@ -31,6 +31,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []

+    # filter out numbers that are larger than the capacity
+    numbers = [number for number in numbers if number <= capacity]
     while numbers:
         current_knapsack = []
         remaining_capacity = capacity
@@ -43,6 +45,9 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
             remaining_capacity -= numbers[index]  # update the remaining capacity
             current_knapsack.append(numbers.pop(index))  # add the number to knapsack

+        if remaining_capacity == capacity:
+            break
+
         knapsacks.append(current_knapsack)

     return knapsacks
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 83bd8ba2a7..2c1df29ce5 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -53,13 +53,16 @@ def _encode_supervised_example(
         encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
+        if total_length >= cutoff_len and cutoff_len > 0:
             break

-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
-        source_ids = source_ids[:source_len]
-        target_ids = target_ids[:target_len]
-        total_length += source_len + target_len
+        if cutoff_len > 0:
+            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
+            source_ids = source_ids[:source_len]
+            target_ids = target_ids[:target_len]
+            total_length += source_len + target_len
+        else:
+            source_len, target_len = len(source_ids), len(target_ids)

         if train_on_prompt:
             source_label = source_ids
@@ -158,7 +161,7 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - 1 if data_args.allow_truncation else 0,  # reserve one slot for padding when truncating
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
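
The net effect of the `_encode_supervised_example` changes is that `cutoff_len == 0` now acts as a "no truncation" sentinel, and the packed preprocessing path only truncates when `allow_truncation` is enabled. A hedged sketch of the per-turn budgeting — the `infer_seqlen` stand-in below is a deliberate simplification, not the library's implementation:

```python
from typing import List, Tuple


def fit_turn(
    source_ids: List[int], target_ids: List[int], total_length: int, cutoff_len: int
) -> Tuple[List[int], List[int]]:
    # cutoff_len == 0 means "keep every turn whole"; packing handles length later
    if cutoff_len == 0:
        return source_ids, target_ids

    budget = cutoff_len - total_length  # caller breaks before this goes non-positive
    # crude stand-in for infer_seqlen: the target keeps what the source leaves over
    source_len = min(len(source_ids), max(budget // 2, budget - len(target_ids)))
    target_len = min(len(target_ids), budget - source_len)
    return source_ids[:source_len], target_ids[:target_len]
```

This also explains the new `greedy_knapsack` guard above: with truncation disabled, an example can exceed the capacity, so oversized lengths are filtered out and an iteration that packs nothing terminates the loop instead of spinning forever.
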
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 4c3b70bfea..d9aa5de713 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -117,13 +117,20 @@ def _encode(
                 elements += self.format_separator.apply()

             if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_assistant.apply(content=content)
             elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
+                if isinstance(message["content"], list):
+                    for m in message["content"]:
+                        elements += self.format_observation.apply(content=m)
+                else:
+                    elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -184,13 +191,20 @@ def _encode(
                 elements += self.format_separator.apply()

             if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_user.apply(content=system_text + content)
             elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_assistant.apply(content=content)
             elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
+                if isinstance(message["content"], list):
+                    for m in message["content"]:
+                        elements += self.format_observation.apply(content=m)
+                else:
+                    elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index a33e626773..22479d76ce 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -109,6 +109,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
+    allow_truncation: bool = field(
+        default=False,
+        metadata={"help": "Allow truncation when processing packed supervised examples."},
+    )
     tool_format: Optional[str] = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
diff --git a/tests/data/processors/test_feedback.py b/tests/data/processors/test_feedback.py
index c04e823b75..b1cecfb988 100644
--- a/tests/data/processors/test_feedback.py
+++ b/tests/data/processors/test_feedback.py
@@ -51,6 +51,7 @@ def test_feedback_data(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         messages = original_data["messages"][index]
+        messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in messages]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
         prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
         ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
diff --git a/tests/data/processors/test_pairwise.py b/tests/data/processors/test_pairwise.py
index da50ca2426..079edcdac8 100644
--- a/tests/data/processors/test_pairwise.py
+++ b/tests/data/processors/test_pairwise.py
@@ -64,6 +64,8 @@ def test_pairwise_data(num_samples: int):
         rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
         chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
         rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
+        chosen_messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in chosen_messages]
+        rejected_messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in rejected_messages]
         ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
         chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
         ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]
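
The test updates mirror the aligner change: after `continuous_messages`, every `content` field is a list (usually one element), while the reference tokenizer's `apply_chat_template` expects plain strings, hence the unwrapping comprehension. Illustrated here with made-up data:

```python
aligned = [
    {"role": "user", "content": ["Hello"]},
    {"role": "assistant", "content": ["Hi there"]},
]
# unwrap the single-element content lists before comparing against the reference
flattened = [{"role": msg["role"], "content": msg["content"][0]} for msg in aligned]
assert flattened == [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
```

Note that `msg["content"][0]` would silently drop text for genuinely merged runs; it is safe here only if the test datasets contain strictly alternating roles.
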