Skip to content

Commit

Permalink
move merge_messages to aligner
Browse files Browse the repository at this point in the history
  • Loading branch information
AlongWY committed Dec 25, 2024
1 parent 3aa6664 commit e1b957d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 30 deletions.
4 changes: 2 additions & 2 deletions src/llamafactory/data/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras import logging
from .data_utils import Role
from .data_utils import Role, merge_messages


if TYPE_CHECKING:
Expand Down Expand Up @@ -152,7 +152,7 @@ def convert_sharegpt(
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = example[dataset_attr.messages]
messages = merge_messages(example[dataset_attr.messages], dataset_attr)
if (
dataset_attr.system_tag
and len(messages) != 0
Expand Down
28 changes: 28 additions & 0 deletions src/llamafactory/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from datasets import DatasetDict, concatenate_datasets, interleave_datasets

from .parser import DatasetAttr

from ..extras import logging


Expand Down Expand Up @@ -90,3 +92,29 @@ def split_dataset(
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})


def merge_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
    """Collapse consecutive messages that share a role into single entries.

    A run of length one passes through unchanged. A longer run must consist of
    observation messages (role == ``dataset_attr.observation_tag``) and is
    merged into one message whose ``content`` is the list of the run's contents.
    """
    merged: List[Dict[str, str]] = []
    pending: List[Dict[str, str]] = []

    def flush(run: List[Dict[str, str]]) -> None:
        if len(run) == 1:
            merged.append(run[0])
        else:
            # only tool-observation turns are expected to repeat back to back
            assert run[0]["role"] == dataset_attr.observation_tag
            merged.append(
                {"role": dataset_attr.observation_tag, "content": [item["content"] for item in run]}
            )

    for msg in messages:
        if pending and msg["role"] != pending[-1]["role"]:
            flush(pending)
            pending = []
        pending.append(msg)

    if pending:
        flush(pending)

    return merged
28 changes: 0 additions & 28 deletions src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def _encode(
"""
system = system or self.default_system
encoded_messages = []
messages = self.merge_messages(messages)
for i, message in enumerate(messages):
elements = []

Expand Down Expand Up @@ -157,32 +156,6 @@ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "

return token_ids

@staticmethod
def merge_messages(messages: Sequence[Dict[str, str]]) -> List[Dict[str, str]]:
    """Merge consecutive same-role messages into one entry.

    Single-message runs are emitted unchanged; longer runs must be
    observation messages (``Role.OBSERVATION``) and are combined into a
    single message whose ``content`` is the list of individual contents.
    """
    result: List[Dict[str, str]] = []
    run: List[Dict[str, str]] = []

    def emit() -> None:
        if len(run) == 1:
            result.append(run[0])
        else:
            # only tool-observation turns are expected to repeat back to back
            assert run[0]["role"] == Role.OBSERVATION.value
            result.append(
                {"role": Role.OBSERVATION.value, "content": [m["content"] for m in run]}
            )

    for msg in messages:
        if run and run[-1]["role"] != msg["role"]:
            emit()
            run.clear()
        run.append(msg)

    if run:
        emit()

    return result


@dataclass
class Llama2Template(Template):
@override
Expand All @@ -200,7 +173,6 @@ def _encode(
"""
system = system or self.default_system
encoded_messages = []
messages = self.merge_messages(messages)
for i, message in enumerate(messages):
elements = []

Expand Down

0 comments on commit e1b957d

Please sign in to comment.