Skip to content

Commit

Permalink
Modify train_val_test_split to invert from modality/train to train/mo…
Browse files Browse the repository at this point in the history
…dality structure
  • Loading branch information
kdu4108 committed Aug 1, 2024
1 parent 8508f83 commit 7cdfe15
Showing 1 changed file with 41 additions and 4 deletions.
45 changes: 41 additions & 4 deletions pseudolabeling/train_val_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def main(args):
for d in os.listdir(args.source_dir)
if os.path.isdir(os.path.join(args.source_dir, d))
]
print(f"Subdirectories of source dir {args.source_dir}: {dset_dirs}")

for dset_dir in dset_dirs:
dset_path = os.path.join(args.source_dir, dset_dir)
Expand Down Expand Up @@ -44,16 +45,46 @@ def main(args):
for dataset, files in zip(
["train", "val", "test"], [train_files, val_files, test_files]
):
split_path = os.path.join(args.output_dir, dset_dir, dataset)
split_path = os.path.join(args.output_dir, dataset, dset_dir)
print(f"Move {dset_path} -----------> {split_path}")
os.makedirs(split_path, exist_ok=True) # Create class directory in split
for file in files:
shutil.move(
os.path.join(dset_path, file), os.path.join(split_path, file)
)
if args.copy:
shutil.copy(os.path.join(dset_path, file), os.path.join(split_path, file))
else:
shutil.move(
os.path.join(dset_path, file), os.path.join(split_path, file)
)


if __name__ == "__main__":
"""
Given a source directory containing the data for multiple modalities, e.g.,
```
|--source_dir/
| |--modality_a/
| |--modality_b/
| |--modality_c/
```
move (or, when --copy is set, copy) the files into a specified output_dir/ with the structure:
```
|--output_dir/
| |--train/
| | |--modality_a/
| | |--modality_b/
| | |--modality_c/
| |--val/
| | |--modality_a/
| | |--modality_b/
| | |--modality_c/
| |--test/
| | |--modality_a/
| | |--modality_b/
| | |--modality_c/
```
"""
parser = argparse.ArgumentParser(
description="Partition datasets into train, val, and test splits."
)
Expand Down Expand Up @@ -87,6 +118,12 @@ def main(args):
default=False,
help="Whether to shuffle shards befores splitting. Otherwise, train is 0, 1, 2, etc.",
)
# NOTE: `type=bool` would call bool() on the raw command-line string, so every
# non-empty value -- including "False" and "0" -- would turn copying ON.
# Parse the text explicitly instead: the usual "false" spellings (and the
# empty string) disable copying, any other value enables it, which keeps
# existing `--copy True` invocations working unchanged.
parser.add_argument(
    "--copy",
    type=lambda value: value.strip().lower()
    not in ("", "0", "false", "f", "no", "n"),
    default=False,
    help="Whether to copy the files instead of move. Defaults to False.",
)
args = parser.parse_args()

main(args)

0 comments on commit 7cdfe15

Please sign in to comment.