support continuous observation and optional pre-cutoff #6441

Draft · wants to merge 8 commits into base: main
9 changes: 6 additions & 3 deletions src/llamafactory/data/aligner.py
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
 from ..extras import logging
-from .data_utils import Role
+from .data_utils import Role, continuous_messages
 
 
 if TYPE_CHECKING:
@@ -152,13 +152,16 @@ def convert_sharegpt(
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
     accept_tags = (odd_tags, even_tags)
-    messages = example[dataset_attr.messages]
+    messages = continuous_messages(example[dataset_attr.messages], dataset_attr)
     if (
         dataset_attr.system_tag
        and len(messages) != 0
         and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
     ):
-        system = messages[0][dataset_attr.content_tag]
+        if isinstance(messages[0][dataset_attr.content_tag], list):
+            system = "".join(messages[0][dataset_attr.content_tag])
+        else:
+            system = messages[0][dataset_attr.content_tag]
         messages = messages[1:]
     else:
         system = example[dataset_attr.system] if dataset_attr.system else ""
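Once `continuous_messages` runs first (added in data_utils.py below), every content field arrives as a list of strings, which is why the system turn needs the join above. A toy sketch of the two shapes this branch handles (illustrative data, not from the PR):

```python
# Hypothetical sharegpt-style record; "from"/"value" stand in for
# dataset_attr.role_tag / dataset_attr.content_tag.
first = {"from": "system", "value": ["You are ", "a helpful assistant."]}

# Same logic as the branch above: join list contents, pass plain strings through.
if isinstance(first["value"], list):
    system = "".join(first["value"])
else:
    system = first["value"]

assert system == "You are a helpful assistant."
```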
39 changes: 39 additions & 0 deletions src/llamafactory/data/data_utils.py
@@ -17,6 +17,8 @@
 
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
 
+from .parser import DatasetAttr
+
 from ..extras import logging
 
 
@@ -90,3 +92,40 @@ def split_dataset(
     val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
     dataset = dataset.train_test_split(test_size=val_size, seed=seed)
     return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
+
+
+def continuous_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
+    # merge consecutive observation messages into one message whose content is a list
+    new_messages = []
+    waiting_message = []
+
+    def append_waiting_message():
+        if len(waiting_message) == 1:
+            new_messages.append(
+                {
+                    dataset_attr.role_tag: waiting_message[0][dataset_attr.role_tag],
+                    dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
+                }
+            )
+        else:
+            assert waiting_message[0][dataset_attr.role_tag] == dataset_attr.observation_tag
+            new_messages.append(
+                {
+                    dataset_attr.role_tag: dataset_attr.observation_tag,
+                    dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
+                }
+            )
+
+    for message in messages:
+        if len(waiting_message) > 0 and message[dataset_attr.role_tag] != waiting_message[-1][dataset_attr.role_tag]:
+            append_waiting_message()
+            waiting_message = []
+        waiting_message.append(message)
+
+    if len(waiting_message) > 0:
+        append_waiting_message()
+
+    # if the lengths match, no roles repeated and nothing was merged
+    # if len(new_messages) == len(messages):
+    #     return messages
+    return new_messages
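A minimal usage sketch (toy attribute object and messages; a real call receives a `DatasetAttr`): consecutive observation turns collapse into a single message whose content is a list, and every other turn's content becomes a one-element list.

```python
class ToyAttr:  # hypothetical stand-in for DatasetAttr
    role_tag = "from"
    content_tag = "value"
    observation_tag = "observation"

msgs = [
    {"from": "human", "value": "check the weather"},
    {"from": "observation", "value": "sunny"},
    {"from": "observation", "value": "25C"},  # same role twice in a row
    {"from": "gpt", "value": "It is sunny and 25C."},
]

merged = continuous_messages(msgs, ToyAttr())
assert merged[1] == {"from": "observation", "value": ["sunny", "25C"]}
assert len(merged) == 3  # human, merged observation, gpt
```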
5 changes: 5 additions & 0 deletions src/llamafactory/data/processors/processor_utils.py
@@ -31,6 +31,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []
 
+    # filter out numbers that are larger than the capacity
+    numbers = [number for number in numbers if number <= capacity]
     while numbers:
         current_knapsack = []
         remaining_capacity = capacity
@@ -43,6 +45,9 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
             remaining_capacity -= numbers[index]  # update the remaining capacity
             current_knapsack.append(numbers.pop(index))  # add the number to knapsack
 
+        if remaining_capacity == capacity:
+            break
+
         knapsacks.append(current_knapsack)
 
     return knapsacks
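These are two guards against the same pathology: previously a sample longer than `capacity` made `binary_search` return -1 on every pass while `numbers` never shrank, so the loop spun forever. The filter drops such samples up front, and the `remaining_capacity == capacity` check bails out if a pass packs nothing at all. A quick sanity check (knapsack ordering may vary):

```python
from llamafactory.data.processors.processor_utils import greedy_knapsack

# 9 exceeds the capacity of 8: before this patch the call never returned;
# now the oversized sample is dropped and the rest are packed normally.
print(greedy_knapsack([3, 5, 9], capacity=8))  # e.g. [[5, 3]]
```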
15 changes: 9 additions & 6 deletions src/llamafactory/data/processors/supervised.py
@@ -53,13 +53,16 @@ def _encode_supervised_example(
         encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
 
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
-        if total_length >= cutoff_len:
+        if total_length >= cutoff_len and cutoff_len > 0:
             break
 
-        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
-        source_ids = source_ids[:source_len]
-        target_ids = target_ids[:target_len]
-        total_length += source_len + target_len
+        if cutoff_len > 0:
+            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
+            source_ids = source_ids[:source_len]
+            target_ids = target_ids[:target_len]
+            total_length += source_len + target_len
+        else:
+            source_len, target_len = len(source_ids), len(target_ids)
 
         if train_on_prompt:
             source_label = source_ids
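`cutoff_len == 0` thus becomes a sentinel for "no pre-cutoff": the break never fires and each turn keeps its full length. A toy sketch of the two modes, with `infer_seqlen` replaced by a naive split purely for illustration:

```python
def encode_turn(source_ids, target_ids, cutoff_len, total_length):
    """Illustrative only; the real code budgets with infer_seqlen()."""
    if cutoff_len > 0:  # truncating mode
        budget = cutoff_len - total_length
        source_len = min(len(source_ids), budget // 2)
        target_len = min(len(target_ids), budget - source_len)
        return source_ids[:source_len], target_ids[:target_len]
    return source_ids, target_ids  # pre-cutoff disabled: keep the turn whole
```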
@@ -158,7 +161,7 @@ def preprocess_packed_supervised_dataset(
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - 1 if data_args.allow_truncation else 0,  # reserved for the padding token
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
30 changes: 22 additions & 8 deletions src/llamafactory/data/template.py
@@ -117,13 +117,20 @@ def _encode(
                 elements += self.format_separator.apply()
 
             if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_user.apply(content=content, idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_assistant.apply(content=content)
             elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
+                if isinstance(message["content"], list):
+                    for m in message["content"]:
+                        elements += self.format_observation.apply(content=m)
+                else:
+                    elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
 
@@ -184,13 +191,20 @@ def _encode(
                 elements += self.format_separator.apply()
 
             if message["role"] == Role.USER.value:
-                elements += self.format_user.apply(content=system_text + message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_user.apply(content=system_text + content)
             elif message["role"] == Role.ASSISTANT.value:
-                elements += self.format_assistant.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_assistant.apply(content=content)
             elif message["role"] == Role.OBSERVATION.value:
-                elements += self.format_observation.apply(content=message["content"])
+                if isinstance(message["content"], list):
+                    for m in message["content"]:
+                        elements += self.format_observation.apply(content=m)
+                else:
+                    elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION.value:
-                elements += self.format_function.apply(content=message["content"])
+                content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
+                elements += self.format_function.apply(content=content)
             else:
                 raise NotImplementedError("Unexpected role: {}".format(message["role"]))
 
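The visible effect in both `_encode` variants: a merged observation turn renders as one formatted block per tool result rather than a single concatenated blob. A toy sketch with a hypothetical wrapper standing in for `format_observation.apply` (the real formatter lives in `llamafactory.data.formatter`):

```python
def apply_observation(text: str) -> list:
    # hypothetical stand-in for self.format_observation.apply(content=...)
    return ["<tool_response>\n{}\n</tool_response>\n".format(text)]

elements = []
content = ["sunny", "25C"]  # list produced by continuous_messages
if isinstance(content, list):
    for m in content:
        elements += apply_observation(m)
else:
    elements += apply_observation(content)

assert len(elements) == 2  # each tool result keeps its own wrapper
```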
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/data_args.py
@@ -109,6 +109,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
+    allow_truncation: bool = field(
+        default=False,
+        metadata={"help": "Allow truncation when processing supervised examples."},
+    )
     tool_format: Optional[str] = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
1 change: 1 addition & 0 deletions tests/data/processors/test_feedback.py
@@ -51,6 +51,7 @@ def test_feedback_data(num_samples: int):
     indexes = random.choices(range(len(original_data)), k=num_samples)
     for index in indexes:
         messages = original_data["messages"][index]
+        messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in messages]
         ref_input_ids = ref_tokenizer.apply_chat_template(messages)
         prompt_len = len(ref_tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=True))
         ref_labels = [IGNORE_INDEX] * prompt_len + ref_input_ids[prompt_len:]
2 changes: 2 additions & 0 deletions tests/data/processors/test_pairwise.py
@@ -64,6 +64,8 @@ def test_pairwise_data(num_samples: int):
         rejected_messages = original_data["conversations"][index] + [original_data["rejected"][index]]
         chosen_messages = _convert_sharegpt_to_openai(chosen_messages)
         rejected_messages = _convert_sharegpt_to_openai(rejected_messages)
+        chosen_messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in chosen_messages]
+        rejected_messages = [{"role": msg["role"], "content": msg["content"][0]} for msg in rejected_messages]
         ref_chosen_input_ids = ref_tokenizer.apply_chat_template(chosen_messages)
         chosen_prompt_len = len(ref_tokenizer.apply_chat_template(chosen_messages[:-1], add_generation_prompt=True))
         ref_chosen_labels = [IGNORE_INDEX] * chosen_prompt_len + ref_chosen_input_ids[chosen_prompt_len:]