diff --git a/cfgs/default/4m/alphas_mixture/main/mix_mod7_all2all_rgb2dense_a0.5.yaml b/cfgs/default/4m/alphas_mixture/main/mix_mod7_all2all_rgb2dense_a0.5.yaml index 4594a23..1236711 100644 --- a/cfgs/default/4m/alphas_mixture/main/mix_mod7_all2all_rgb2dense_a0.5.yaml +++ b/cfgs/default/4m/alphas_mixture/main/mix_mod7_all2all_rgb2dense_a0.5.yaml @@ -1,7 +1,7 @@ # Half of all samples from this mixture are rgb2dense, half are all2all -rgb@224: - input_alphas: [1000.0, 0.5] - target_alphas: [0., 0.] # RGB is not a target +rgb@224: # num cols is the number of sampling strategies to try. + input_alphas: [1000.0, 0.5] # NOTE: input alphas: how much weight to put on the input modalities for encoding. + target_alphas: [0., 0.] # RGB is not a target # NOTE: target alphas - how much weight to put on the modalities to be unmasked/sampled during decoding. tok_rgb@224: input_alphas: [0., 0.5] target_alphas: [0.5, 0.5] diff --git a/cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_a0.5.yaml b/cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_a0.5.yaml new file mode 100644 index 0000000..90cde68 --- /dev/null +++ b/cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_a0.5.yaml @@ -0,0 +1,7 @@ +# What's up with the naming configs here? +video_rgb@224: + input_alphas: [1000.0, 0.5] + target_alphas: [0., 0.] # RGB is not a target +video_tok_rgb@224: + input_alphas: [0., 0.5] + target_alphas: [0.5, 0.5] \ No newline at end of file diff --git a/cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_det_a0.5.yaml b/cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_det_a0.5.yaml new file mode 100644 index 0000000..73745ab --- /dev/null +++ b/cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_det_a0.5.yaml @@ -0,0 +1,11 @@ +# What's up with the naming configs here? +video_rgb@224: + input_alphas: [1000.0, 0.5] + target_alphas: [0., 0.] # RGB is not a target +video_tok_rgb@224: + input_alphas: [0., 0.5] + target_alphas: [0.5, 0.5] +det: + input_alphas: [0., 0.5] + target_alphas: [0.5, 0.5] + keep: ['random', 'random'] diff --git a/cfgs/default/4m/data/video/mix_mod3_rgb_tok_det_to_all_a0.5.yaml b/cfgs/default/4m/data/video/mix_mod3_rgb_tok_det_to_all_a0.5.yaml new file mode 100644 index 0000000..c0da88f --- /dev/null +++ b/cfgs/default/4m/data/video/mix_mod3_rgb_tok_det_to_all_a0.5.yaml @@ -0,0 +1,44 @@ +train: + datasets: + my_video_dataset: + type: multimodal + + # Input and output domain names, separated by hyphen + in_domains: video_rgb@224-video_tok_rgb@224-video_det + out_domains: video_rgb@224-video_tok_rgb@224-video_det + + # Dirichlet alphas concentration parameter for input and output. + # Can be either one value, or one value per input modality separated by hyphen. + input_alphas: null + target_alphas: null + # Path to specific alphas configuration to enable mixture of Dirichlets. + # If provided, overrides input_alphas and target_alphas + alphas_config: "cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_det_a0.5.yaml" + + # Optionally, min_input_tokens, min_target_tokens, num_input_tokens, num_target_tokens can be specified here + # If so, they will override the values provided in the main config + + # Data can either be local or on cloud storage (e.g. S3), see data docs for more info + # Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar) + # Use brackets to indicate multiple modalities (e.g. 
[modality1,modality2,modality3]) + data_path: '/store/swissai/a08/data/4m/train/[video_rgb,video_rgb_tok,video_det]/shard-{00000..00100}.tar' # TODO: need to reformat the data correctly here. + use_wds: True # Use webdataset + wds_n_repeats: 4 # Number of repeats for webdataset loader to improve efficiency + wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files + wds_shuffle_buffer_repeat: 1_000 # Webdatasets shuffle buffer after repeating samples + + main_augment_domain: video_rgb@224 # Select from which modality to get the original full image size (mostly important for resizing bounding boxes) + aligned_captions: True # Align captions to crop_settings # TODO: tbd? + tok_train_aug: True # Apply data augmentation to tokens (if multiple crop settings are available) # TODO: tbd? + + # modality_name_map: # Use modality_name_map to define a mapping from a folder name to a modality name + # tok_rgb_folder_name: tok_rgb@224 + # tok_depth_folder_nme: tok_depth@224 + # ... + + weights: [1.0] # Sampling weights for the training datasets + +val: + datasets: + my_video_dataset: + data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_rgb_tok,video_det]/shard-{00000..00100}.tar' \ No newline at end of file diff --git a/cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml b/cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml new file mode 100644 index 0000000..55800e6 --- /dev/null +++ b/cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml @@ -0,0 +1,44 @@ +train: + datasets: + my_video_dataset: + type: multimodal + + # Input and output domain names, separated by hyphen + in_domains: video_rgb@224-video_tok_rgb@224 + out_domains: video_rgb@224-video_tok_rgb@224 + + # Dirichlet alphas concentration parameter for input and output. + # Can be either one value, or one value per input modality separated by hyphen. + input_alphas: null + target_alphas: null + # Path to specific alphas configuration to enable mixture of Dirichlets. + # If provided, overrides input_alphas and target_alphas + alphas_config: "cfgs/default/4m/alphas_mixture/video/mix_mod3_rgb_tok_a0.5.yaml" + + # Optionally, min_input_tokens, min_target_tokens, num_input_tokens, num_target_tokens can be specified here + # If so, they will override the values provided in the main config + + # Data can either be local or on cloud storage (e.g. S3), see data docs for more info + # Use braceexpand notation to indicate shard range (e.g. shard-{0000..9999}.tar) + # Use brackets to indicate multiple modalities (e.g. [modality1,modality2,modality3]) + data_path: '/store/swissai/a08/data/4m/cleaned/train/[video_rgb,video_tok_rgb]/0000000000.tar' # TODO: need to reformat the data correctly here. + use_wds: True # Use webdataset + wds_n_repeats: 4 # Number of repeats for webdataset loader to improve efficiency + wds_shuffle_buffer_tar: 1_000 # Webdatasets shuffle buffer after loading tar files + wds_shuffle_buffer_repeat: 1_000 # Webdatasets shuffle buffer after repeating samples + + main_augment_domain: video_rgb@224 # Select from which modality to get the original full image size (mostly important for resizing bounding boxes) + aligned_captions: True # Align captions to crop_settings # TODO: tbd? + tok_train_aug: True # Apply data augmentation to tokens (if multiple crop settings are available) # TODO: tbd? + + # modality_name_map: # Use modality_name_map to define a mapping from a folder name to a modality name + # tok_rgb_folder_name: tok_rgb@224 + # tok_depth_folder_nme: tok_depth@224 + # ... 
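      # For the video shards referenced above, this would presumably look something like the
      # following (the folder/key names are assumptions, not final):
      # modality_name_map:
      #   video_rgb: video_rgb@224
      #   video_tok_rgb: video_tok_rgb@224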
+ + weights: [1.0] # Sampling weights for the training datasets + +# val: +# datasets: +# my_video_dataset: +# data_path: '/store/swissai/a08/data/4m/val/[video_rgb,video_tok_rgb]/00000{00175..00199}.tar' \ No newline at end of file diff --git a/cfgs/default/4m/models/video/4m-b_mod3.yaml b/cfgs/default/4m/models/video/4m-b_mod3.yaml new file mode 100644 index 0000000..c921053 --- /dev/null +++ b/cfgs/default/4m/models/video/4m-b_mod3.yaml @@ -0,0 +1,46 @@ +# Config for DDP + +# Arch: SwiGLU No Bias +# Modalities: Mix of rgb2all and all2all, with alphas=0.5 +# To be run on 64 GPUs for batch size = 8192 +run_name: auto + +# Input & output +num_input_tokens: 128 +num_target_tokens: 128 +loss_type: mod + +# Architecture +model: fm_base_12e_12d_swiglu_nobias +patch_size: 16 +input_size: 224 +dtype: bfloat16 +tokenizer_path: "fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json" + +# Train +epochs: -1 +total_tokens: 500 # in billions +opt: adamw +blr: 0.0001 # this is base_lr = 1e-4, lr = base_lr * batch_size / 256 +min_blr: 0. +warmup_epochs: -1 +warmup_tokens: 10 # in billions +batch_size: 128 # 128 x 64 = 8192 + +# Data + +data_config: "cfgs/default/4m/data/video/mix_mod3_rgb_tok_to_all_a0.5.yaml" +s3_data_endpoint: null # Change me +eval_freq: 1 +fixed_eval: True +epoch_size: 10_000_000 # Number of samples per "epoch" + +# Saving +save_ckpt_freq: 1 +output_dir: 'output/auto' + +# Wandb +log_wandb: False # Set to True to log to Weights & Biases +wandb_project: '4m-train' +wandb_entity: null # Change if needed +wandb_run_name: auto diff --git a/fourm/data/masking.py b/fourm/data/masking.py index 860861a..87697da 100644 --- a/fourm/data/masking.py +++ b/fourm/data/masking.py @@ -128,7 +128,7 @@ def chunk_span_masking(sequence_chunks: List[List[int]], sentinel_to_id: Dict[in -class UnifiedMasking(object): +class UnifiedMasking(object): # this defines the masking logic def __init__(self, modality_info: Dict, text_tokenizer: Optional[Tokenizer], diff --git a/fourm/data/modality_info.py b/fourm/data/modality_info.py index dd85882..c11bb24 100644 --- a/fourm/data/modality_info.py +++ b/fourm/data/modality_info.py @@ -29,6 +29,11 @@ ColorPaletteTransform, SAMInstanceTokTransform, SAMInstanceTransform, + VideoDescriptionTransform, + VideoDetectionTransform, + VideoRGBTransform, + VideoTokTransform, + VideoTranscriptTransform, ) from fourm.models.decoder_embeddings import ImageTokenDecoderEmbedding, SequenceDecoderEmbedding from fourm.models.encoder_embeddings import ( @@ -39,6 +44,7 @@ ) from fourm.utils import generate_uint15_hash +# Specifications about different modalities MODALITY_INFO = { # 4M-7 modalities "rgb@224": { @@ -406,6 +412,38 @@ }, } +VIDEO_MODALITY_INFO = { + ### Video modalities + # TODO: do we need to keep the image versions? These probably should generalize over those, right? + "video_rgb@224": { + **MODALITY_INFO["rgb@224"], + "id": generate_uint15_hash("video_rgb@224"), + "path": "video_rgb", # TODO: video_rgb or keep aas rgb? (probably only keep rgb if this generalizes over single images too) + }, + "video_description": { + **MODALITY_INFO["caption"], # TODO: do we want to increase the default 'max_tokens/max_length' from 256? + "id": generate_uint15_hash("video_description"), + }, + "video_transcript": { + **MODALITY_INFO["caption"], # TODO: do we want to increase the default 'max_tokens/max_length' from 256? 
+ "id": generate_uint15_hash("video_transcript"), + }, + "video_det": { + **MODALITY_INFO["det"], + "id": generate_uint15_hash("video_det"), + }, + "video_tok_rgb@224": { + **MODALITY_INFO["tok_rgb@224"], + "id": generate_uint15_hash("video_tok_rgb@224"), + }, + "video_tok_clip@224": { + **MODALITY_INFO["tok_clip@224"], + "id": generate_uint15_hash("video_tok_clip@224"), + }, +} + +MODALITY_INFO = {**MODALITY_INFO, **VIDEO_MODALITY_INFO} + # Note: @res suffix is ignored for modality transforms MODALITY_TRANSFORMS = { # 4M-7 modalities @@ -414,7 +452,7 @@ "det": DetectionTransform( det_threshold=0.6, det_max_instances=None, bbox_order="dist_to_orig", coord_bins=1000, min_visibility=0.0 ), - "tok_rgb": TokTransform(), + "tok_rgb": TokTransform(), # tok_ indicates its a token representation "tok_depth": TokTransform(), "tok_normal": TokTransform(), "tok_semseg": TokTransform(), @@ -435,6 +473,15 @@ "tok_imagebind_global": TokTransform(), # Other "mask_valid": MaskTransform(mask_pool_size=1), + # Video + "video_rgb": VideoRGBTransform(imagenet_default_mean_and_std=True), # TODO: check parameters + "video_tok_rgb": VideoTokTransform(), # tok_ indicates its a token representation + "video_tok_clip": VideoTokTransform(), # TODO: check parameters + "video_description": VideoDescriptionTransform(aligned_captions=True), # TODO: check parameters + "video_transcript": VideoTranscriptTransform(aligned_captions=True), # TODO: check parameters + "video_det": VideoDetectionTransform( + det_threshold=0.6, det_max_instances=None, bbox_order="dist_to_orig", coord_bins=1000, min_visibility=0.0 + ), # TODO: check parameters } MODALITY_TRANSFORMS_DIVAE = { diff --git a/fourm/data/modality_transforms.py b/fourm/data/modality_transforms.py index bdc7dbf..e672a29 100644 --- a/fourm/data/modality_transforms.py +++ b/fourm/data/modality_transforms.py @@ -13,6 +13,7 @@ # limitations under the License. import gzip import json +import jsonlines import random from pathlib import Path from typing import Optional, Tuple, List, Dict @@ -40,9 +41,11 @@ PAD_MASK_VALUE, ) +# start with data, then models # The @-symbol is used to specify the resolution of a modality. Syntax: modality@resolution def get_transform_key(mod_name): + # TODO: do we need to modify this for the video keys? return mod_name.split("@")[0] @@ -149,7 +152,6 @@ def __repr__(self): class AbstractTransform(ABC): - @abstractmethod def load(self, sample): pass @@ -177,7 +179,6 @@ def postprocess(self, v): class ImageTransform(AbstractTransform): - @staticmethod def pil_loader(path: str) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) @@ -216,8 +217,10 @@ def image_crop_and_resize(img: Image, crop_coords: Tuple, target_size: Tuple, re return img -class RGBTransform(ImageTransform): - +class RGBTransform( + ImageTransform +): # For RGB in raw format for 4M because it's passed in directly (and also training tokenizer) + def __init__(self, imagenet_default_mean_and_std=True, color_jitter=False, color_jitter_strength=0.5): self.rgb_mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN self.rgb_std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD @@ -283,6 +286,47 @@ def postprocess(self, sample): return sample +class VideoRGBTransform(RGBTransform): + """ + A video transform applied to a sequence of RGB images. 
+ For now, I'm assuming the input to the load function points to a webdataset containing mp4 which contains the frames in the specified modality (e.g, RGB). + This format almost certainly is subject to change. TODO: figure out what the right format for this input should be (aka, how are we storing the raw videos?) + + Output: a tensor of shape (num_frames, C, H, W) where C is the number of channels (3 for RGB). OR should this already be unrolled into something like (num_frames * C, H, W)? + """ + + # raise NotImplementedError("I'm not ") + def load(self, path): + raise NotImplementedError( + "TODO: implement the loader for video frames, probably from a webdataset. It might be helpful to see how to load data from video2dataset format into a dataloader here: https://github.com/swiss-ai/ml-4m/blob/kdu/pseudolabeler/notebooks/pseudolabeler.py. However we don't want to load it into a dataset so maybe that's overkill? Instead might be helpful to just load directly using webd utilities." + ) + + def preprocess(self, sample): + raise NotImplementedError( + "TODO: what preprocessing do we want? do we also want to convert to RGB and do color jitter like with normal RGBTransform?" + ) + + def image_augment( + self, + v, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): + raise NotImplementedError( + "TODO: what augmentations do we want? do we also want to same as with normal image RGBTransform?" + ) + + def postprocess(self, v): + # TODO: deicde + raise NotImplementedError( + "TODO: postprocess should convert the frames into a tensor of shape (num_frames, C, H, W) where C is the number of channels (3 for RGB)." + ) + + class DepthTransform(ImageTransform): def __init__(self, standardize_depth=True): @@ -336,7 +380,6 @@ def postprocess(self, sample): class NormalTransform(ImageTransform): - def __init__(self, standardize_surface_normals=False): self.normal_mean = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_MEAN self.normal_std = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_STD @@ -382,6 +425,7 @@ def postprocess(self, sample): class SemsegTransform(ImageTransform): + # apply for learning the tokens? def __init__( self, scale_factor=1.0, shift_idx_by_one=False, id_mapping: Optional[Dict] = None, select_channel=None @@ -393,7 +437,7 @@ def __init__( def map_semseg_values(self, sample): sample = np.asarray(sample) - mapping_fn = lambda x: self.id_mapping.get(x, x) + mapping_fn = lambda x: self.id_mapping.get(x, x) # noqa: E731 sample = np.vectorize(mapping_fn)(sample) sample = Image.fromarray(sample, mode="P") return sample @@ -503,7 +547,7 @@ def fn(x, xn): lmbda = np.linalg.solve(A, b) if 0 <= lmbda[0] <= 1 and 0 <= lmbda[1] <= 1: output.append(lmbda[1] * xn + (1 - lmbda[1]) * x) - except: + except: # noqa: E722 continue return output @@ -677,6 +721,7 @@ def postprocess(self, sample): class TokTransform(AbstractTransform): + # Transformation on the tokens def __init__(self): pass @@ -703,13 +748,51 @@ def image_augment( "Crop settings / augmentation index are missing but a pre-tokenized modality is being used" ) v = torch.tensor(v[rand_aug_idx]) - return v + return v # Since we augment before saving (create 5 augmented versions of the same image), in this transform you randomly sample one of the augmentations to work with. 
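    # In other words, each pre-tokenized .npy presumably stores one row of tokens per saved
    # crop setting, i.e. shape (num_crop_settings, num_tokens), and rand_aug_idx (chosen once
    # per sample by the image augmenter and shared across modalities) selects the row whose
    # crop matches the one applied to the other modalities, e.g.:
    #
    #   v = np.load("tok_rgb/xxxxx.npy")    # e.g. shape (5, 196) if 5 crop settings were saved
    #   v = torch.tensor(v[rand_aug_idx])   # keep the version aligned with this sample's crop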
def postprocess(self, sample): return sample -class DetectionTransform(AbstractTransform): +class VideoTokTransform(AbstractTransform): + """ + Assume input tokens is an ndarray of shape (num_frames, num_tokens_per_frame). + Transform the tokens to a torch tensor of shape (num_frames * num_tokens,). + """ + # Transformation on the tokens + + def __init__(self): + pass + + def load(self, path): + sample = np.load(path).astype(int) + return sample # shape: (num_frames, num_tokens_per_frame) + + def preprocess(self, sample): + return sample + + def image_augment( + self, + v, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ): + # TODO: do we want an image_augment for videos frames? What would those even look like? + # only implement this if we implement augmentations during the tokenization/saving process? + print( + "WARNING: no image augmentations implemented for video tokens at the moment. Decide on what we should do and then remove this warning/implement as needed." + ) + return v + + def postprocess(self, sample): + return torch.ravel(sample) # shape: (num_frames * num_tokens_per_frame,) + + +class DetectionTransform(AbstractTransform): # bounding boxes def __init__( self, @@ -751,13 +834,13 @@ def order_bboxes_by_score(bboxes): def shuffle_bboxes(bboxes): return sorted(bboxes, key=lambda x: random.random()) - def convert_detection_instance(self, instances): - """Convert instances dict to list of lists where each list takes the form: + def convert_detection_instance(self, instances: dict) -> List[Tuple]: + """Convert instances dict to list of tuples where each list takes the form: [xmin, ymin, xmax, ymax, class_name, score] """ instances = [ - inst["boxes"] + [inst["class_name"], inst["score"]] + tuple(inst["boxes"] + [inst["class_name"], inst["score"]]) for inst in instances if inst["score"] >= self.det_threshold ] @@ -804,7 +887,7 @@ def order_and_filter_bboxes(self, bboxes): return self.bbox_order(bboxes) - def convert_bboxes_to_string(self, bboxes: List[Tuple]): + def convert_bboxes_to_string(self, bboxes: List[Tuple]) -> str: """Convert bounding boxes to a string. xmin, ymin, xmax, ymax are mapped to v0, v1, v2, v3 special tokens. @@ -864,8 +947,129 @@ def postprocess(self, bboxes): return bboxes -class CaptionTransform(AbstractTransform): +class VideoDetectionTransform(DetectionTransform): + """ + Bounding boxes for videos. Read bounding boxes for a given video (sequence of frames) as a JSONL file containing the list of bounding boxes. + + Example video bounding boxes input: + [ + # FRAME 0 Bounding boxes + { + "num_instances": 5, + "image_height": 512, + "image_width": 906, + "instances": [ + { + "boxes": [ + 0.4229210317134857, + 0.00020096010121051222, + 0.5715101361274719, + 0.13699540495872498 + ], + "score": 0.9029952883720398, + "class_id": 74, + "class_name": "clock", + "segmentation": [ + [ + 0.5055187637969095, + 0.1337890625, + ... + ] + ] + }, + { + "boxes": [ + ... + ], + ... + }, + ... + ] + }, + # FRAME 1 Bounding boxes + { + "num_instances": 5, + "image_height": 512, + "image_width": 906, + "instances": [ + ..., + ], + ... + } + ] + + Input: path to a List[dicts] representing bounding boxes representations for each frame. + Output: a str representing the bounding boxes for each frame in the video, separated by special frame tokens to indicate which frame each bounding box representation corresponds to. 
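    Note that the frame-separator and EOS special tokens used in postprocess() are currently
    placeholders and still need to be defined alongside the text tokenizer vocabulary.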
+ """ + + def load(self, path): + """ + Load jsonl file containing bounding box representations per frame. + """ + with jsonlines.open(path, "r") as jsonl_f: + bbs_per_frame = [obj for obj in jsonl_f] + return bbs_per_frame + + def preprocess(self, sample: List[dict]) -> List[List[tuple]]: + """ + Given a list of dicts (each element in the list is a frame, each dict is a representation of bounding boxes for that frame), + convert the instances within each frame's bounding boxes representation. + # TODO: any issue with returning a list of lists here? We sorta might break the type consistency of the class? I'm not sure how robust it is. + # ^ Actually I think this is fine because it just gets passed into the image_augment function, so as long as that's implemented correspondingly we're good. + + Returns: + List of lists of tuples. + The outermost list represents each frame. + The middle list represents each bounding box instances in that frame. + The innermost tuple is the actual representation of a bounding box instance (of the form [xmin, ymin, xmax, ymax, class_name, score]). + """ + instances = [frame["instances"] for frame in sample] + return self.convert_detection_instance(instances) + + def image_augment( + self, + bboxes_per_frame: List[List[Tuple]], + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx=None, + resample_mode: str = None, + ) -> List[List[Tuple]]: + """ + Apply bbox augmentations to the bounding boxes of each frame (TODO: do we actually wanna do this? Would some of these augmentations break the video nature/spatio-temporal relations/etc.?) + """ + bboxes_per_frame = [] + for bboxes in bboxes_per_frame: + bboxes = self.bboxes_crop_and_resize(bboxes, crop_coords, orig_size) + bboxes = self.bboxes_hflip(bboxes, target_size, flip) + bboxes = self.order_and_filter_bboxes(bboxes) + + bboxes_per_frame.append(bboxes) + + return bboxes_per_frame + + def postprocess(self, bboxes_per_frame): + """ + Given bounding box representations per frame, should return a string like: + " bbox0 string representation blah blah blah ... bbox0 string representation blah blah blah ... " + """ + if self.return_raw: + raise NotImplementedError( + "I'm not sure what the correct behavior of returning bboxes_per_frame should be when returned raw, so throwing error for now. Set return_raw=False to avoid this." + ) + return bboxes_per_frame + output_str = "" # TODO: implement the special tokens better than just hardcoding things here. + for i, bboxes in enumerate(bboxes_per_frame): + bboxes_str = self.convert_bboxes_to_string(bboxes_per_frame) + output_str += bboxes_str + output_str += f"" + output_str += "" # TODO: Do we need to explicitly add EOS token here? If yes, we need to do this so it adds the actual eos_token str of the model not just this hardcoded string. + return output_str + + +class CaptionTransform(AbstractTransform): def __init__(self, aligned_captions=True, no_aug=False): self.aligned_captions = aligned_captions self.no_aug = no_aug @@ -913,8 +1117,174 @@ def postprocess(self, sample): return sample -class CaptionEmbTransform(AbstractTransform): +class VideoTranscriptTransform(AbstractTransform): + def __init__(self, aligned_captions=True, no_aug=False): + # TODO: what are these args? Do we still need them? + self.aligned_captions = aligned_captions + self.no_aug = no_aug + + def load(self, path) -> List[dict]: + # TODO: decide on the representation of the input description (how do we know which frames each caption maps to?) 
+ # Caption can either be stored as .txt or .json.gz (in which case it's a list of dicts) + """ + For now, assume we have something like a jsonlines of transcripts for each clip/sequence of frames in a video.: + [ + { + "transcript": "here's a transcript", + "start_frame_index": 0, + "end_frame_index": 5, + }, + { + "transcript": "here's another transcript", + "start_frame_index": 10, + "end_frame_index": 13, + } # Note that the transcript need not be consecutive in all frames, e.g., you can see a skip from frames 5-10 with no transcripts. + ] + """ + with jsonlines.open(path, "r") as jsonl_f: + transcripts_per_clip = [obj for obj in jsonl_f] + + return transcripts_per_clip + + def preprocess(self, sample): + return sample + + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ) -> List[dict]: + # TODO: decide what augmentaitons might be appropriate for descriptions here? + print( + "WARNING: no augmentations implemented for transcripts yet. Decide whether to augment/what these should be and then remove this warning." + ) + return val + def postprocess(self, sample: List[dict]) -> str: + """ + Given a list of {text, start_frame, end_frame} dicts, we want to return a string in the format: + <0th_start_frame_token>0th transcript text blah blah<0th_end_frame_token><1st_start_frame_token>1st transcript text blah blah<1st_end_frame_token>.... + Example: given + [ + { + "transcript": "here's a transcript", + "start_frame_index": 0, + "end_frame_index": 5, + }, + { + "transcript": "here's another transcript", + "start_frame_index": 10, + "end_frame_index": 13, + } # Note that the transcript need not be consecutive in all frames, e.g., you can see a skip from frames 5-10 with no transcripts. + ] + + We should have + here's a transcripthere's another transcript + """ + output_str = "" + for transcript_dict in sample: + start_frame_token = f"" + end_frame_token = f"" + output_str += start_frame_token + transcript_dict["transcript"] + end_frame_token + output_str += "" # TODO: don't hardcode here, use the actual eos_token str of the model. + return output_str + + +class VideoDescriptionTransform(AbstractTransform): + # TODO: maybe do some inheritance situation with VideoDescriptionTransform and VideoTranscriptTransform. But also maybe not necessary since there's just these two classes. + # Implementation-wise they are very similar (Except that descriptions have the consecutive frame constraint), but differ conceptually because transcripts are like dialogue/etc. and descriptions are "what's happening"/every seq of frames can have a description. + def __init__(self, aligned_captions=True, no_aug=False): + # TODO: what are these args? Do we still need them? + self.aligned_captions = aligned_captions + self.no_aug = no_aug + + def load(self, path) -> List[dict]: + # TODO: decide on the representation of the input description (how do we know which frames each caption maps to?) 
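        # Unlike transcripts, descriptions are assumed to tile the clip with back-to-back frame
        # ranges (the end frame of one description is the start frame of the next); preprocess()
        # below enforces this.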
+ # Caption can either be stored as .txt or .json.gz (in which case it's a list of dicts) + """ + For now, assume we have something like a jsonlines of descriptions for each clip/sequence of frames in a video.: + [ + { + "description": "here's a description", + "start_frame_index": 0, + "end_frame_index": 5, + }, + { + "description": "here's another description", + "start_frame_index": 5, + "end_frame_index": 12, + } # Note that the description NEEDS to be consecutive in all frames, e.g., the end frame of one description is the start frame of the next. + ] + """ + with jsonlines.open(path, "r") as jsonl_f: + descs_per_clip = [obj for obj in jsonl_f] + + return descs_per_clip + + def preprocess(self, sample: List[dict]): + # Check that each clip's description frame indices are consecutive. + for i in range(len(sample) - 1): + desc_dict = sample[i] + next_desc_dict = sample[i + 1] + if next_desc_dict["start_frame_index"] != desc_dict["end_frame_index"]: + raise ValueError( + "Frame indices of descriptions are not consecutive, which is a definitional requirement of descriptions. Double check how the descriptions are created upstream." + ) + return sample + + def image_augment( + self, + val, + crop_coords: Tuple, + flip: bool, + orig_size: Tuple, + target_size: Tuple, + rand_aug_idx: Optional[int], + resample_mode: str = None, + ) -> List[dict]: + # TODO: decide what augmentaitons might be appropriate for descriptions here? + print( + "WARNING: no augmentations implemented for descriptions yet. Decide whether to augment/what these should be and then remove this warning." + ) + return val + + def postprocess(self, sample: List[dict]) -> str: + """ + Given a list of {text, start_frame, end_frame} dicts, we want to return a string in the format: + <0th_start_frame_token>0th transcript text blah blah<0th_end_frame_token><1st_start_frame_token>1st transcript text blah blah<1st_end_frame_token>.... + Example: given + [ + { + "description": "here's a description", + "start_frame_index": 0, + "end_frame_index": 5, + }, + { + "description": "here's another description", + "start_frame_index": 5, + "end_frame_index": 12, + } # Note that the description NEEDS to be consecutive in all frames, e.g., the end frame of one description is the start frame of the next. + ] + + We should have + here's a descriptionhere's another description + """ + output_str = "" + for transcript_dict in sample: + start_frame_token = f"" + end_frame_token = f"" + output_str += start_frame_token + transcript_dict["transcript"] + end_frame_token + output_str += "" # TODO: don't hardcode here, use the actual eos_token str of the model. 
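        # NOTE: the dicts loaded by load() use a "description" key, so this loop presumably
        # needs to read e.g. desc_dict["description"] rather than transcript_dict["transcript"]
        # before it will run end to end.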
+ return output_str + + +class CaptionEmbTransform(AbstractTransform): + # CaptionEmbTransform for the caption embeddings def __init__(self, aligned_captions=True, no_aug=False): self.aligned_captions = aligned_captions self.no_aug = no_aug @@ -969,7 +1339,6 @@ def postprocess(self, sample): class MetadataTransform(AbstractTransform): - def __init__( self, special_vmin: int = 0, @@ -1150,7 +1519,6 @@ def postprocess(self, metadata): class HumanPoseTransform(AbstractTransform): - def __init__(self, coord_bins=1000, only_pose=False, return_raw=False): self.coord_bins = coord_bins self.return_raw = return_raw @@ -1351,7 +1719,6 @@ def postprocess(self, humanposes): class ColorPaletteTransform(AbstractTransform): - def __init__(self, coord_bins=1000, return_raw=False): self.coord_bins = coord_bins self.return_raw = return_raw @@ -1414,7 +1781,6 @@ def postprocess(self, palettes): class SAMInstanceTokTransform(AbstractTransform): - def __init__(self, image_size=224, points_per_side=7, point_order="random"): self.H, self.W = to_2tuple(image_size) self.points_per_h, self.points_per_w = to_2tuple(points_per_side) @@ -1477,7 +1843,7 @@ def convert_target_tokens_to_string(self, target_tokens): result_text.append("none") else: for tok, bbox in target_tokens[point]: - result_text.append(f"polygon") + result_text.append("polygon") # Add bounding box coordinates to the string ymin, xmin, ymax, xmax = bbox.astype(np.int32) @@ -1533,7 +1899,6 @@ def postprocess(self, sample): class CropSettingsTransform(AbstractTransform): - def load(self, path): sample = np.load(path) return sample @@ -1558,7 +1923,6 @@ def postprocess(self, sample): class IdentityTransform(AbstractTransform): - def load(self, path): raise NotImplementedError("IdentityTransform does not support loading") @@ -1582,7 +1946,6 @@ def postprocess(self, sample): class JSONTransform(AbstractTransform): - def load(self, path): if path.endswith(".json"): with open(path, "r") as f: @@ -1609,3 +1972,4 @@ def image_augment( def postprocess(self, sample): return sample +# CaptionEmbTransform for the caption embeddings \ No newline at end of file diff --git a/fourm/data/multimodal_dataset_folder.py b/fourm/data/multimodal_dataset_folder.py index 07ccf96..d7880ce 100644 --- a/fourm/data/multimodal_dataset_folder.py +++ b/fourm/data/multimodal_dataset_folder.py @@ -23,9 +23,9 @@ from fourm.data.modality_transforms import AbstractTransform, get_transform_key -IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp', '.jpx', '.npy', '.npz') +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".jpx", ".npy", ".npz") -UNIFIED_EXTENSIONS = IMG_EXTENSIONS + ('.json', '.txt', '.json.gz') +UNIFIED_EXTENSIONS = IMG_EXTENSIONS + (".json", ".txt", ".json.gz") def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: @@ -54,15 +54,15 @@ def is_image_file(filename: str) -> bool: def make_dataset( - directory: str, - class_to_idx: Dict[str, int], - extensions: Optional[Tuple[str, ...]] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, - cache_path: Optional[str] = None, + directory: str, + class_to_idx: Dict[str, int], + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + cache_path: Optional[str] = None, ) -> List[Tuple[str, int]]: if cache_path is not None and os.path.exists(cache_path): # Load cached file paths from disk if it exists - with open(cache_path, 'rb') as f: + with open(cache_path, 
"rb") as f: return pickle.load(f) instances = [] directory = os.path.expanduser(directory) @@ -71,8 +71,10 @@ def make_dataset( if both_none or both_something: raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") if extensions is not None: + def is_valid_file(x: str) -> bool: return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + is_valid_file = cast(Callable[[str], bool], is_valid_file) for target_class in sorted(class_to_idx.keys()): class_index = class_to_idx[target_class] @@ -88,7 +90,7 @@ def is_valid_file(x: str) -> bool: if cache_path is not None: # Cache all file paths s.t. setting up the dataloader is instant in the future os.makedirs(os.path.dirname(cache_path), exist_ok=True) - with open(cache_path, 'wb') as f: + with open(cache_path, "wb") as f: pickle.dump(instances, f) return instances @@ -126,16 +128,15 @@ class DatasetFolder(VisionDataset): """ def __init__( - self, - root: str, - loader: Callable[[str], Any], - extensions: Optional[Tuple[str, ...]] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, + self, + root: str, + loader: Callable[[str], Any], + extensions: Optional[Tuple[str, ...]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, ) -> None: - super(DatasetFolder, self).__init__(root, transform=transform, - target_transform=target_transform) + super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) classes, class_to_idx = self._find_classes(self.root) samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) if len(samples) == 0: @@ -205,7 +206,7 @@ class MultiModalDatasetFolder(VisionDataset): root/modality_a/class_y/xxy.ext root/modality_a/class_z/xxz.ext - root/modality_b/class_x/xxx.ext + root/modality_b/class_x/xxx.ext # TODO: what is meant by a class here? 
root/modality_b/class_y/xxy.ext root/modality_b/class_z/xxz.ext @@ -236,18 +237,18 @@ class MultiModalDatasetFolder(VisionDataset): """ def __init__( - self, - root: str, - modalities: List[str], - modality_paths: Dict[str, str], - modality_transforms: Dict[str, AbstractTransform], - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, - max_samples: Optional[int] = None, - pre_shuffle: bool = False, - cache: bool = False, - return_path: bool = False, + self, + root: str, + modalities: List[str], + modality_paths: Dict[str, str], + modality_transforms: Dict[str, AbstractTransform], + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + max_samples: Optional[int] = None, + pre_shuffle: bool = False, + cache: bool = False, + return_path: bool = False, ) -> None: super(MultiModalDatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) self.modalities = modalities @@ -264,17 +265,22 @@ def __init__( samples = { mod: make_dataset( - os.path.join(self.root, f'{self.modality_paths[mod]}'), - class_to_idx, - extensions, + os.path.join(self.root, f"{self.modality_paths[mod]}"), + class_to_idx, + extensions, is_valid_file, - cache_path=os.path.join(self.root, 'dataloader_cache', f'{self.modality_paths[mod]}.pkl') if cache else None) + cache_path=os.path.join(self.root, "dataloader_cache", f"{self.modality_paths[mod]}.pkl") + if cache + else None, + ) for mod in self.modalities } - + for mod, mod_samples in samples.items(): if len(mod_samples) == 0: - msg = "Found 0 logs in subfolders of: {}\n".format(os.path.join(self.root, f'{self.modality_paths[mod]}')) + msg = "Found 0 logs in subfolders of: {}\n".format( + os.path.join(self.root, f"{self.modality_paths[mod]}") + ) if extensions is not None: msg += "Supported extensions are: {}".format(",".join(extensions)) raise RuntimeError(msg) @@ -292,7 +298,7 @@ def __init__( permutation = np.random.permutation(total_samples) for task in samples: self.samples[task] = [self.samples[task][i] for i in permutation][:max_samples] - + if pre_shuffle: total_samples = len(list(self.samples.values())[0]) np.random.seed(100) @@ -320,11 +326,11 @@ def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: classes.sort() class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} return classes, class_to_idx - + def get_class_and_file(self, path: str) -> Tuple[str, str]: - """ Extracts the class and file name from a path. 
""" - class_id, file_name = path.split('/')[-2:] - file_name = file_name.split('.')[0] + """Extracts the class and file name from a path.""" + class_id, file_name = path.split("/")[-2:] + file_name = file_name.split(".")[0] return class_id, file_name def __getitem__(self, index: int) -> Tuple[Any, Any]: @@ -350,12 +356,12 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: if self.target_transform is not None: target = self.target_transform(target) - sample_dict['class_idx'] = target + sample_dict["class_idx"] = target - if self.return_path and not index in self.cache: + if self.return_path and not index in self.cache: # noqa: E713 class_id, file_name = self.get_class_and_file(path) - sample_dict['class_id'] = class_id - sample_dict['file_name'] = file_name + sample_dict["class_id"] = class_id + sample_dict["file_name"] = file_name return sample_dict diff --git a/fourm/data/notes.txt b/fourm/data/notes.txt new file mode 100644 index 0000000..4617771 --- /dev/null +++ b/fourm/data/notes.txt @@ -0,0 +1,6 @@ +1. Do we need modality transforms for video? (assuming we tokenize the video to (seq_len, time)) + - First define modality info and transforms, if needed (e.g. ) + - Make a clever masking temporal schema + - then pass to encoder embedding to convert to (hs, seq_len, time) + - TODO: rn encoder embeddings only deal with (seq_len,) objects. Need to modify/add new class so that it can deal with (seq_len, time) objects +2. \ No newline at end of file diff --git a/fourm/data/unified_datasets.py b/fourm/data/unified_datasets.py index 7369731..28ce53b 100644 --- a/fourm/data/unified_datasets.py +++ b/fourm/data/unified_datasets.py @@ -136,6 +136,7 @@ def build_fm_transfer_dataset( def _keyless_map(data, f, handler=reraise_exception): """Map samples without adding __key__.""" for sample in data: + import pdb; pdb.set_trace() try: result = f(sample) except Exception as exn: @@ -390,7 +391,7 @@ def build_wds_fm_pretraining_dataloader( if batch_size is not None: # Perform multi-threaded dataloading - return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None) + return wds.WebLoader(datapipe, num_workers=0, batch_size=None) else: return datapipe @@ -459,7 +460,7 @@ def build_huggingface_pretraining_dataloader( UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter), UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer, input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range) - ]) + ]) # This executes the masking datapipe = wds.DataPipeline( dataset, @@ -552,6 +553,6 @@ def build_mixture_dataloader(data_iters, weights, modality_info, batch_size, num wds.batched(batch_size, collation_fn=default_collate, partial=False), ).with_epoch(epoch_size // (num_gpus * num_workers * batch_size)) # Pre-define iterator length - mixture_loader = wds.WebLoader(mixture_pipe, num_workers=num_workers, batch_size=None) + mixture_loader = wds.WebLoader(mixture_pipe, num_workers=0, batch_size=None) return mixture_loader diff --git a/fourm/models/decoder_embeddings.py b/fourm/models/decoder_embeddings.py index 3045793..471d858 100644 --- a/fourm/models/decoder_embeddings.py +++ b/fourm/models/decoder_embeddings.py @@ -20,6 +20,7 @@ from .fm_utils import build_1d_sincos_posemb, build_2d_sincos_posemb, pair +# Pass position + modality tokens to decoder so it knows what modality to decode to class SequenceDecoderEmbedding(nn.Module): """Embedding module for sequence inputs, like captions or a sequence of objects. 
diff --git a/fourm/models/encoder_embeddings.py b/fourm/models/encoder_embeddings.py index d1d42a1..ec20b00 100644 --- a/fourm/models/encoder_embeddings.py +++ b/fourm/models/encoder_embeddings.py @@ -20,6 +20,7 @@ from .fm_utils import build_1d_sincos_posemb, build_2d_sincos_posemb, pair class SequenceEncoderEmbedding(nn.Module): + # NOTE(Kev): embedding anything that's a sequence, such as text, or bounding boxes (it's a sequence of coordinates) output shape: (HS, seq_len) """Embedding module for encoding sequence inputs, like captions or a sequence of objects. Args: @@ -121,6 +122,7 @@ def forward(self, d : Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return d class ImageTokenEncoderEmbedding(nn.Module): + # Used for embedding anything that's gridlike (n, n), such as tokenized RGB, clip, etc. """Embedding module for tokenized spatial inputs. Args: @@ -202,7 +204,7 @@ def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # Map to embedding x = self.token_emb(ids) - # Create positional embedding + modality embedding + # Create positional embedding + modality embedding # here is where the modality gets encoded x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B) d['x'] = x @@ -212,6 +214,7 @@ def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: class ImageEncoderEmbedding(nn.Module): + # Encode Raw RGB input - split the image into patches and encode each patch into the embedding size """Embedding module for spatial inputs, like images or feature maps. Creates tokens from patches over the image. @@ -310,6 +313,7 @@ def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: class SequenceEmbEncoderEmbedding(nn.Module): + # input shape: (t5 emb size, seq length) -> (4m HS, seq length) """Adapter for sequence emb inputs, like T5-XXL, CLIP text embeddings. Args: diff --git a/run_training_4m.py b/run_training_4m.py index 5088172..ef065fd 100755 --- a/run_training_4m.py +++ b/run_training_4m.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse + +# FOR DEBUGGING with pdb: +# 1. make sure the wds.WebLoader(num_workers=0) in a) build_wds_fm_pretraining_dataloader and b) build_mixture_dataloader +# python -m torch.distributed.launch --nproc_per_node 1 --use-env run_training_4m.py --config /store/swissai/a08/kdu/ml-4m/cfgs/default/4m/models/video/4m-b_mod3.yamlimport argparse import datetime import json import math diff --git a/save_vq_tokens.py b/save_vq_tokens.py index 9374399..1312ca6 100755 --- a/save_vq_tokens.py +++ b/save_vq_tokens.py @@ -297,13 +297,14 @@ def main(args): print('Tokenization time {}'.format(total_time_str)) +# TODO: modify for video case if __name__ == '__main__': parser = argparse.ArgumentParser(prog="VQ token saver") parser.add_argument( "--tokenizer_id", type=str, default='cc12m/rgb_ViTB-UNetP4_16k_224-448', help="ID of tokenizer to load." - ) + ) # TODO: says which tokenizer to use, need to download the weights from HF and save to this directory. parser.add_argument( "--tokenizers_root", type=str, default='./tokenizer_ckpts', help="Path where tokenizer checkpoints are saved." 
@@ -311,7 +312,7 @@ def main(args): parser.add_argument( "--data_root", type=str, default='/path/to/dataset', help="Path to dataset root" - ) + ) # TODO: path to the dataset root (the data needs to be stored as the actual images here, not in webdataset format) parser.add_argument( "--split", type=str, default='train', help="train or val"