Skip to content

Commit

Permalink
fix processing error
Browse files Browse the repository at this point in the history
  • Loading branch information
AlongWY committed Dec 26, 2024
1 parent e1b957d commit 5050fad
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
9 changes: 6 additions & 3 deletions src/llamafactory/data/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras import logging
from .data_utils import Role, merge_messages
from .data_utils import Role, continuous_messages


if TYPE_CHECKING:
Expand Down Expand Up @@ -152,13 +152,16 @@ def convert_sharegpt(
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
messages = merge_messages(example[dataset_attr.messages], dataset_attr)
messages = continuous_messages(example[dataset_attr.messages], dataset_attr)
if (
dataset_attr.system_tag
and len(messages) != 0
and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
):
system = messages[0][dataset_attr.content_tag]
if isinstance(messages[0][dataset_attr.content_tag], list):
system = "".join(messages[0][dataset_attr.content_tag])
else:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = example[dataset_attr.system] if dataset_attr.system else ""
Expand Down
18 changes: 13 additions & 5 deletions src/llamafactory/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,30 @@ def split_dataset(
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})


def merge_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
def continuous_messages(messages: Sequence[Dict[str, str]], dataset_attr: "DatasetAttr") -> List[Dict[str, str]]:
    # merge observation messages
new_messages = []
waiting_message = []

def append_waiting_message():
if len(waiting_message) == 1:
new_messages.append(waiting_message[0])
new_messages.append(
{
dataset_attr.role_tag: waiting_message[0][dataset_attr.role_tag],
dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
}
)
else:
assert waiting_message[0]["role"] == dataset_attr.observation_tag
assert waiting_message[0][dataset_attr.role_tag] == dataset_attr.observation_tag
new_messages.append(
{"role": dataset_attr.observation_tag, "content": [m["content"] for m in waiting_message]}
{
dataset_attr.role_tag: dataset_attr.observation_tag,
dataset_attr.content_tag: [m[dataset_attr.content_tag] for m in waiting_message],
}
)

for message in messages:
if len(waiting_message) > 0 and message["role"] != waiting_message[-1]["role"]:
if len(waiting_message) > 0 and message[dataset_attr.role_tag] != waiting_message[-1][dataset_attr.role_tag]:
append_waiting_message()
waiting_message = []
waiting_message.append(message)
Expand Down
19 changes: 13 additions & 6 deletions src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,20 @@ def _encode(
elements += self.format_separator.apply()

if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
elements += self.format_user.apply(content=content, idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
elements += self.format_assistant.apply(content=content)
elif message["role"] == Role.OBSERVATION.value:
if isinstance(message["content"], list):
for m in message["content"]:
elements += self.format_observation.apply(content=m)
else:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
elements += self.format_function.apply(content=content)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))

Expand Down Expand Up @@ -156,6 +159,7 @@ def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "

return token_ids


@dataclass
class Llama2Template(Template):
@override
Expand Down Expand Up @@ -187,17 +191,20 @@ def _encode(
elements += self.format_separator.apply()

if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=system_text + message["content"])
content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
elements += self.format_user.apply(content=system_text + content)
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
elements += self.format_assistant.apply(content=content)
elif message["role"] == Role.OBSERVATION.value:
if isinstance(message["content"], list):
for m in message["content"]:
elements += self.format_observation.apply(content=m)
else:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
content = "".join(message["content"]) if isinstance(message["content"], list) else message["content"]
elements += self.format_function.apply(content=content)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))

Expand Down

0 comments on commit 5050fad

Please sign in to comment.