From e686e662b15da65d6fcc050e22545f87b6b2dc2b Mon Sep 17 00:00:00 2001 From: Kevin Du Date: Fri, 2 Aug 2024 16:50:48 +0200 Subject: [PATCH] Update configs to work with a filtered raw + add debug script with pdb --- cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml | 4 ++-- fourm/data/unified_datasets.py | 5 +++-- run_training_4m.py | 5 +++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml b/cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml index 5ac205c..55800e6 100644 --- a/cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml +++ b/cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml @@ -21,7 +21,7 @@ train: # Data can either be local or on cloud storage (e.g. S3), see data docs for more info # Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar) # Use brackets to indicate multiple modalities (e.g. [modality1,modality2,modality3]) - data_path: '/store/swissai/a08/data/4m/splits2/train/[video_rgb,video_rgb_tok]/00000{00000..00100}.tar' # TODO: need to reformat the data correctly here. + data_path: '/store/swissai/a08/data/4m/cleaned/train/[video_rgb,video_tok_rgb]/0000000000.tar' # TODO: need to reformat the data correctly here. 
use_wds: True # Use webdataset wds_n_repeats: 4 # Number of repeats for webdataset loader to improve efficiency wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files @@ -41,4 +41,4 @@ train: # val: # datasets: # my_video_dataset: -# data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_rgb_tok]/00000{00175..00199}.tar' \ No newline at end of file +# data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_tok_rgb]/00000{00175..00199}.tar' \ No newline at end of file diff --git a/fourm/data/unified_datasets.py b/fourm/data/unified_datasets.py index cf35d3c..28ce53b 100644 --- a/fourm/data/unified_datasets.py +++ b/fourm/data/unified_datasets.py @@ -136,6 +136,7 @@ def build_fm_transfer_dataset( def _keyless_map(data, f, handler=reraise_exception): """Map samples without adding __key__.""" for sample in data: + import pdb; pdb.set_trace() try: result = f(sample) except Exception as exn: @@ -390,7 +391,7 @@ def build_wds_fm_pretraining_dataloader( if batch_size is not None: # Perform multi-threaded dataloading - return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None) + return wds.WebLoader(datapipe, num_workers=0, batch_size=None) else: return datapipe @@ -552,6 +553,6 @@ def build_mixture_dataloader(data_iters, weights, modality_info, batch_size, num wds.batched(batch_size, collation_fn=default_collate, partial=False), ).with_epoch(epoch_size // (num_gpus * num_workers * batch_size)) # Pre-define iterator length - mixture_loader = wds.WebLoader(mixture_pipe, num_workers=num_workers, batch_size=None) + mixture_loader = wds.WebLoader(mixture_pipe, num_workers=0, batch_size=None) return mixture_loader diff --git a/run_training_4m.py b/run_training_4m.py index db85bc0..ef065fd 100755 --- a/run_training_4m.py +++ b/run_training_4m.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License.
-# python run_training_4m.py --data_config cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml -import argparse +# FOR DEBUGGING with pdb: +# 1. make sure wds.WebLoader uses num_workers=0 in a) build_wds_fm_pretraining_dataloader and b) build_mixture_dataloader +# python -m torch.distributed.launch --nproc_per_node 1 --use-env run_training_4m.py --config /store/swissai/a08/kdu/ml-4m/cfgs/default/4m/models/video/4m-b_mod3.yaml +import argparse import datetime import json import math