From 62e462cd77b08675757f17619a5161797cdc9519 Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Tue, 15 Aug 2023 21:00:10 +0530
Subject: [PATCH 1/3] added orca-best

---
 .../model_training/custom_datasets/prompt_dialogue.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/model/model_training/custom_datasets/prompt_dialogue.py b/model/model_training/custom_datasets/prompt_dialogue.py
index 1d30458cb3..48cec721ab 100644
--- a/model/model_training/custom_datasets/prompt_dialogue.py
+++ b/model/model_training/custom_datasets/prompt_dialogue.py
@@ -172,15 +172,18 @@ def __getitem__(self, index: int) -> list[str] | tuple[str]:
 class OrcaChat(Dataset):
     name = "orca-chat"
 
-    def __init__(self, data_files: Union[List[str], str] = "orca-chat-gpt4.json", cache_dir: str = None) -> None:
-        self.dataset = load_dataset("shahules786/orca-chat", split="train", data_files=data_files, cache_dir=cache_dir)
-
+    def __init__(self, rows_per_conv: int = 1, use_auth_token: Optional[Union[bool, str]] = None, cache_dir: str = None) -> None:
+        self.dataset = load_dataset("shahules786/orca-best", split="train",
+                                    use_auth_token=use_auth_token,
+                                    cache_dir=cache_dir)
+        self.rows_per_conv = rows_per_conv
+
     def __len__(self):
         return len(self.dataset)
 
     def __getitem__(self, idx):
         conversation, instruction = [self.dataset[idx][key] for key in ("conversation", "instruction")]
-        conversation = [(item["input"], item["output"]) for item in conversation]
+        conversation = [(item["input"], item["output"]) for item in conversation["samples"][:self.rows_per_conv]]
         conversation = list(sum(conversation, ()))
         conv_utt: list[Utterance] = [
             (

From 01b1111b215b4ee57f715ecf182b9baa412bcea5 Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Tue, 15 Aug 2023 22:49:25 +0530
Subject: [PATCH 2/3] patch for unknown error

---
 model/model_training/custom_datasets/prompt_dialogue.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/model/model_training/custom_datasets/prompt_dialogue.py b/model/model_training/custom_datasets/prompt_dialogue.py
index 48cec721ab..3d6dc806e3 100644
--- a/model/model_training/custom_datasets/prompt_dialogue.py
+++ b/model/model_training/custom_datasets/prompt_dialogue.py
@@ -11,7 +11,8 @@
 from model_training.custom_datasets.utils import _filter_by_words
 from torch import Generator, randperm
 from torch.utils.data import Dataset, random_split
-
+import datasets
+datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory='.': True
 
 def load_oig_file(
     source_url: str,

From edd63bdc798255e5607923cb3580d213ef044fbf Mon Sep 17 00:00:00 2001
From: Shahules786
Date: Sat, 19 Aug 2023 17:14:23 +0530
Subject: [PATCH 3/3] added best of megacode

---
 .../custom_datasets/__init__.py        |  4 ++-
 .../custom_datasets/prompt_dialogue.py | 28 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/model/model_training/custom_datasets/__init__.py b/model/model_training/custom_datasets/__init__.py
index 4c66d06008..f1b8c92082 100644
--- a/model/model_training/custom_datasets/__init__.py
+++ b/model/model_training/custom_datasets/__init__.py
@@ -13,7 +13,7 @@
 )
 from model_training.custom_datasets.oasst_dataset import load_oasst_export
 from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
-from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file
+from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file, BestOfMegacode
 from model_training.custom_datasets.qa_datasets import (
     SODA,
     AlpacaGpt4,
@@ -188,6 +188,8 @@ def get_one_dataset(
         dataset = DolphinMix(cache_dir=data_path, **kwargs)
     elif dataset_name in RAG_DATASETS.keys():
         dataset = RAGDataset(dataset_name, cache_dir=data_path, **kwargs)
+    elif dataset_name == "bestofmegacode":
+        dataset = BestOfMegacode(cache_dir=data_path, **kwargs)
     else:
         raise ValueError(f"Unknown dataset {dataset_name}")
 
diff --git a/model/model_training/custom_datasets/prompt_dialogue.py b/model/model_training/custom_datasets/prompt_dialogue.py
index 3d6dc806e3..d573798141 100644
--- a/model/model_training/custom_datasets/prompt_dialogue.py
+++ b/model/model_training/custom_datasets/prompt_dialogue.py
@@ -199,6 +199,34 @@ def __getitem__(self, idx):
 
         return DatasetEntrySft(conversation=conv_utt, system_message=instruction)
 
+class BestOfMegacode(Dataset):
+    name = "bestofmegacode"
+
+    def __init__(self, rows_per_conv: int = 1, use_auth_token: Optional[Union[bool, str]] = None, cache_dir: str = None) -> None:
+        self.dataset = load_dataset("shahules786/megacode-best", split="train",
+                                    use_auth_token=use_auth_token,
+                                    cache_dir=cache_dir)
+        self.rows_per_conv = rows_per_conv
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        conversation = self.dataset[idx]["conversation"]
+        conversation = [(item["USER"], item["ASSISTANT"]) for item in conversation["samples"][:self.rows_per_conv]]
+        conversation = list(sum(conversation, ()))
+        conv_utt: list[Utterance] = [
+            (
+                Utterance(
+                    text=conv,
+                    role=Role.prompter if i % 2 == 0 else Role.assistant,
+                )
+            )
+            for i, conv in enumerate(conversation)
+        ]
+
+        return DatasetEntrySft(conversation=conv_utt)
+
 
 class DolphinMix(Dataset):
     name = "dophin-mix"
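
A minimal usage sketch of the two dataset classes added by this series, assuming the repository's model_training package is importable, that the Hugging Face datasets shahules786/orca-best and shahules786/megacode-best are reachable, and that their rows follow the layout the patches rely on (a "conversation" field whose "samples" list holds input/output or USER/ASSISTANT pairs). The rows_per_conv and cache_dir values below are illustrative, not taken from the patches.

# Sketch only: exercise the new dataset wrappers outside the trainer.
# Assumes network access to the Hugging Face Hub and the row layout described above.
from model_training.custom_datasets.prompt_dialogue import BestOfMegacode, OrcaChat

orca = OrcaChat(rows_per_conv=2, cache_dir=".cache")            # system message comes from the row's "instruction" field
megacode = BestOfMegacode(rows_per_conv=1, cache_dir=".cache")

entry = megacode[0]                                             # DatasetEntrySft with alternating prompter/assistant turns
print(len(orca), len(megacode))
# Assumes DatasetEntrySft exposes the utterance list via its `conversation` attribute,
# matching the keyword it is constructed with in the patch.
print([utt.role for utt in entry.conversation])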