diff --git a/.gitignore b/.gitignore
index e2f7103..61abb8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@ wandb/
 tokenizer_ckpts/
 *pkl
 *.egg-info
 *.log
+build/**
diff --git a/save_vq_tokens.py b/pseudolabeling/save_vq_tokens.py
similarity index 54%
rename from save_vq_tokens.py
rename to pseudolabeling/save_vq_tokens.py
index 9374399..d7dcb40 100755
--- a/save_vq_tokens.py
+++ b/pseudolabeling/save_vq_tokens.py
@@ -35,9 +35,22 @@ import fourm.utils.clip as clip
 
-FEATURE_TASKS = ['CLIP-B16']
-IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".jpx", ".gif")
-
+FEATURE_TASKS = ["CLIP-B16"]
+IMG_EXTENSIONS = (
+    ".jpg",
+    ".jpeg",
+    ".png",
+    ".ppm",
+    ".bmp",
+    ".pgm",
+    ".tif",
+    ".tiff",
+    ".webp",
+    ".jpx",
+    ".gif",
+)
+
+
 def find_image_extension(root_dir):
     for root, dirs, files in os.walk(root_dir):
         for file in files:
@@ -45,23 +58,26 @@ def find_image_extension(root_dir):
             return os.path.splitext(file)[1]
     return None
 
+
 class SaveVQDataset(Dataset):
-    def __init__(self,
-                 root: str,
-                 tokens_dir: str,
-                 crop_settings_dir: str,
-                 task: str,
-                 n_crops: int = 10,
-                 min_crop_scale: float = 0.2,
-                 input_size: int = 224,
-                 mask_value: Optional[float] = None,
-                 task_transforms: dict = MODALITY_TRANSFORMS_DIVAE,
-                 resample_mode: str = 'bilinear',
-                 corrupt_samples_log: Optional[str] = None,
-                 dryrun: bool = False,
-                 force_load_crop: bool = False):
+    def __init__(
+        self,
+        root: str,
+        tokens_dir: str,
+        crop_settings_dir: str,
+        task: str,
+        n_crops: int = 10,
+        min_crop_scale: float = 0.2,
+        input_size: int = 224,
+        mask_value: Optional[float] = None,
+        task_transforms: dict = MODALITY_TRANSFORMS_DIVAE,
+        resample_mode: str = "bilinear",
+        corrupt_samples_log: Optional[str] = None,
+        dryrun: bool = False,
+        force_load_crop: bool = False,
+    ):
         super().__init__()
-        
+
         self.data_root = root
         self.tokens_root = os.path.join(root, tokens_dir)
         self.crop_settings_root = os.path.join(root, crop_settings_dir)
@@ -76,65 +92,77 @@ def __init__(self,
         self.dryrun = dryrun
         self.force_load_crop = force_load_crop
-        
+
         self.loader = lambda path: Image.open(path)
-        
+
         self.classes, self.class_to_idx = find_classes(os.path.join(root, task))
         if corrupt_samples_log is not None:
             task_ext = find_image_extension(os.path.join(root, task))
             self.samples = self.get_corrupt_samples(corrupt_samples_log, task_ext)
         else:
-            self.samples = make_dataset(os.path.join(root, task), self.class_to_idx, IMG_EXTENSIONS, None)
-        
+            self.samples = make_dataset(
+                os.path.join(root, task), self.class_to_idx, IMG_EXTENSIONS, None
+            )
+
         self.center_crop_augmenter = CenterCropImageAugmenter(
             target_size=self.input_size, hflip=0.0, main_domain=task
         )
         self.random_crop_augmenter = RandomCropImageAugmenter(
-            target_size=self.input_size, hflip=0.5,
+            target_size=self.input_size,
+            hflip=0.5,
             crop_scale=(min_crop_scale, 1.0),
             crop_ratio=(0.75, 1.3333),
-            main_domain=task
+            main_domain=task,
         )
 
     def get_corrupt_samples(self, corrupt_samples_log, task_ext):
         # Load the log file from find_corrupted_pseudolabels.py
-        with open(corrupt_samples_log, 'r') as f:
+        with open(corrupt_samples_log, "r") as f:
            corrupt_samples = f.readlines()
-        
+
         # Remove the error message that was thrown and empty characters
-        corrupt_samples = [sample.split(':')[-1].strip() for sample in corrupt_samples]
-        
+        corrupt_samples = [sample.split(":")[-1].strip() for sample in corrupt_samples]
+
         # Extract the folder and file names
-        corrupt_samples = [sample.split('/')[-2:] for sample in corrupt_samples]
-        
+        corrupt_samples = [sample.split("/")[-2:] for sample in corrupt_samples]
+
         # Construct path
         corrupt_samples = [
-            (os.path.join(self.data_root, self.task, s[0], s[1].replace('.npy', task_ext)), self.class_to_idx[s[0]])
+            (
+                os.path.join(
+                    self.data_root, self.task, s[0], s[1].replace(".npy", task_ext)
+                ),
+                self.class_to_idx[s[0]],
+            )
             for s in corrupt_samples
         ]
-        
+
         return corrupt_samples
-    
+
     def __len__(self):
         return len(self.samples)
 
-    def __getitem__(self, index):        
+    def __getitem__(self, index):
         path, _ = self.samples[index]
         img = self.loader(path)
-        img = img.convert("RGB") if self.task in ['rgb', 'normal'] else img
-        
-        class_id, file_id = path.split('/')[-2:]
-        file_id = file_id.split('.')[0]
+        img = img.convert("RGB") if self.task in ["rgb", "normal"] else img
+
+        class_id, file_id = path.split("/")[-2:]
+        file_id = file_id.split(".")[0]
 
         if self.mask_value is not None:
-            mask_path = os.path.join(self.data_root, 'mask_valid', class_id, f'{file_id}.png')
+            mask_path = os.path.join(
+                self.data_root, "mask_valid", class_id, f"{file_id}.png"
+            )
             mask = Image.open(mask_path)
 
-        tokens_path = os.path.join(self.tokens_root, class_id, f'{file_id}.npy')
+        tokens_path = os.path.join(self.tokens_root, class_id, f"{file_id}.npy")
         if not self.dryrun:
             os.makedirs(os.path.dirname(tokens_path), exist_ok=True)
-        crop_settings_path = os.path.join(self.crop_settings_root, class_id, f'{file_id}.npy')
+        crop_settings_path = os.path.join(
+            self.crop_settings_root, class_id, f"{file_id}.npy"
+        )
 
         # Create or load crop settings
         if os.path.exists(crop_settings_path) or self.force_load_crop:
@@ -151,7 +179,9 @@ def __getitem__(self, index):
 
             # Subsequent crops are random
             for _ in range(1, self.n_crops):
-                crop_coords, h_flip, _, _, _ = self.random_crop_augmenter({self.task: img}, None)
+                crop_coords, h_flip, _, _, _ = self.random_crop_augmenter(
+                    {self.task: img}, None
+                )
                 settings.append((*crop_coords, 1 if h_flip else 0))
 
             settings = np.array(settings)
@@ -162,38 +192,55 @@ def __getitem__(self, index):
         # Perform augmentations and optionally mask images
         imgs = []
         for i, j, h, w, h_flip in settings:
-
             img_mod = self.task_transforms[self.task].preprocess(img.copy())
             img_mod = self.task_transforms[self.task].image_augment(
-                img_mod, (i,j,h,w), h_flip, None,
-                (self.input_size, self.input_size), None, self.resample_mode
+                img_mod,
+                (i, j, h, w),
+                h_flip,
+                None,
+                (self.input_size, self.input_size),
+                None,
+                self.resample_mode,
             )
             img_mod = self.task_transforms[self.task].postprocess(img_mod)
 
             if self.mask_value is not None:
-                mask_valid = self.task_transforms['mask_valid'].preprocess(mask.copy())
-                mask_valid = self.task_transforms['mask_valid'].image_augment(
-                    mask_valid, (i,j,h,w), h_flip, None,
-                    (self.input_size, self.input_size), None, None
+                mask_valid = self.task_transforms["mask_valid"].preprocess(mask.copy())
+                mask_valid = self.task_transforms["mask_valid"].image_augment(
+                    mask_valid,
+                    (i, j, h, w),
+                    h_flip,
+                    None,
+                    (self.input_size, self.input_size),
+                    None,
+                    None,
                 )
-                mask_valid = self.task_transforms['mask_valid'].postprocess(mask_valid)
-                img_mod[~repeat(mask_valid, '1 h w -> c h w', c=img_mod.shape[0])] = self.mask_value
-                mask_valid = mask_valid.float() * 2 - 1 # Valid regions -> 1, Masked-out regions -> -1
-                img_mod = torch.cat([img_mod, mask_valid], dim=0) # Concat image with mask
-                
+                mask_valid = self.task_transforms["mask_valid"].postprocess(mask_valid)
+                img_mod[~repeat(mask_valid, "1 h w -> c h w", c=img_mod.shape[0])] = (
+                    self.mask_value
+                )
+                mask_valid = (
+                    mask_valid.float() * 2 - 1
+                )  # Valid regions -> 1, Masked-out regions -> -1
+                img_mod = torch.cat(
+                    [img_mod, mask_valid], dim=0
+                )  # Concat image with mask
+
             imgs.append(img_mod)
 
         imgs = torch.stack(imgs)
 
         return imgs, tokens_path
 
+
 def get_feature_extractor(args):
-    if args.task == 'CLIP-B16':
-        teacher_model, _ = clip.load("ViT-B/16", device='cpu', jit=False)
+    if args.task == "CLIP-B16":
+        teacher_model, _ = clip.load("ViT-B/16", device="cpu", jit=False)
         teacher_model = teacher_model.visual
         return teacher_model.eval()
     else:
         return None
 
+
 def main(args):
     utils.init_distributed_mode(args)
     device = torch.device(args.device)
@@ -203,7 +250,9 @@ def main(args):
     np.random.seed(seed)
     random.seed(seed)
 
-    model = get_image_tokenizer(args.tokenizer_id, tokenizers_root=args.tokenizers_root, encoder_only=True)
+    model = get_image_tokenizer(
+        args.tokenizer_id, tokenizers_root=args.tokenizers_root, encoder_only=True
+    )
     feature_extractor = get_feature_extractor(args)
 
     num_tasks = utils.get_world_size()
@@ -211,16 +260,31 @@ def main(args):
     global_rank = utils.get_rank()
     sampler_rank = global_rank
 
-    loader_task = 'rgb' if args.task in FEATURE_TASKS else args.task
-    dataset = SaveVQDataset(root=os.path.join(args.data_root, args.split), crop_settings_dir='crop_settings',
-                            tokens_dir=f'{args.task}_{args.folder_suffix}', task=loader_task,
-                            min_crop_scale=args.min_crop_scale, n_crops=args.n_crops,
-                            input_size=args.input_size, mask_value=args.mask_value,
-                            resample_mode=args.resample_mode, corrupt_samples_log=args.corrupt_samples_log, force_load_crop=args.force_load_crop)
-
-    sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=sampler_rank, shuffle=False)
-    data_loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=args.batch_size_dataloader,
-                                              num_workers=args.num_workers, drop_last=False)
+    loader_task = "rgb" if args.task in FEATURE_TASKS else args.task
+    dataset = SaveVQDataset(
+        root=os.path.join(args.data_root, args.split),
+        crop_settings_dir="crop_settings",
+        tokens_dir=f"{args.task}_{args.folder_suffix}",
+        task=loader_task,
+        min_crop_scale=args.min_crop_scale,
+        n_crops=args.n_crops,
+        input_size=args.input_size,
+        mask_value=args.mask_value,
+        resample_mode=args.resample_mode,
+        corrupt_samples_log=args.corrupt_samples_log,
+        force_load_crop=args.force_load_crop,
+    )
+
+    sampler = torch.utils.data.DistributedSampler(
+        dataset, num_replicas=num_tasks, rank=sampler_rank, shuffle=False
+    )
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=args.batch_size_dataloader,
+        num_workers=args.num_workers,
+        drop_last=False,
+    )
 
     model.to(device)
     if feature_extractor is not None:
@@ -235,7 +299,6 @@ def main(args):
         pbar = None
 
     for imgs_batch, tokens_paths in data_loader:
-
         # Filter out already saved images
         imgs_batch_filtered, tokens_paths_filtered = [], []
         for imgs, tokens_path in zip(imgs_batch, tokens_paths):
@@ -248,139 +311,169 @@ def main(args):
             continue
         imgs_batch = torch.stack(imgs_batch_filtered)
         tokens_paths = tokens_paths_filtered
-        
-        
+
         # Merge batch and number of augmentation dimensions
-        if 'semseg' in args.task:
-            imgs_batch = rearrange(imgs_batch, 'b n h w -> (b n) h w')
+        if "semseg" in args.task:
+            imgs_batch = rearrange(imgs_batch, "b n h w -> (b n) h w")
         else:
-            imgs_batch = rearrange(imgs_batch, 'b n c h w -> (b n) c h w')
-        
+            imgs_batch = rearrange(imgs_batch, "b n c h w -> (b n) c h w")
+
         # For efficiency, process images with batch size that might be different from loader batch size or num augmentations
         sub_batches = imgs_batch.split(args.batch_size, dim=0)
-        
+
         all_tokens = []
-        
+
         for sub_batch in sub_batches:
             sub_batch = sub_batch.to(device)
-            
+
             with torch.no_grad():
-                if 'CLIP' in args.task:
+                if "CLIP" in args.task:
                     B, C, H, W = sub_batch.shape
                     P_H, P_W = feature_extractor.conv1.kernel_size
                     N_H, N_W = H // P_H, W // P_W
-                    sub_batch = feature_extractor(sub_batch, return_final_tokens_no_cls=True)
-                    sub_batch = rearrange(sub_batch, 'b (nh nw) d -> b d nh nw', nh=N_H, nw=N_W)
+                    sub_batch = feature_extractor(
+                        sub_batch, return_final_tokens_no_cls=True
+                    )
+                    sub_batch = rearrange(
+                        sub_batch, "b (nh nw) d -> b d nh nw", nh=N_H, nw=N_W
+                    )
 
                 tokens = model.tokenize(sub_batch)
                 tokens = rearrange(tokens, "b h w -> b (h w)")
                 tokens = tokens.detach().cpu().numpy().astype(np.int16)
                 all_tokens.append(tokens)
-        
+
         all_tokens = np.concatenate(all_tokens)
-        all_tokens = rearrange(all_tokens, '(b n) d -> b n d', n=args.n_crops)
-        
+        all_tokens = rearrange(all_tokens, "(b n) d -> b n d", n=args.n_crops)
+
         for tokens, tokens_path in zip(all_tokens, tokens_paths):
             if args.dryrun:
-                print(f'Dryrun: rank {global_rank} -> {tokens_path}')
+                print(f"Dryrun: rank {global_rank} -> {tokens_path}")
             else:
                 np.save(tokens_path, tokens)
 
             if pbar is not None:
                 pbar.update(1)
 
-    #torch.distributed.barrier()
+    # torch.distributed.barrier()
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print('Tokenization time {}'.format(total_time_str))
+    print("Tokenization time {}".format(total_time_str))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(prog="VQ token saver")
     parser.add_argument(
-        "--tokenizer_id", type=str, default='cc12m/rgb_ViTB-UNetP4_16k_224-448',
-        help="ID of tokenizer to load."
+        "--tokenizer_id",
+        type=str,
+        default="cc12m/rgb_ViTB-UNetP4_16k_224-448",
+        help="ID of tokenizer to load.",
    )
     parser.add_argument(
-        "--tokenizers_root", type=str, default='./tokenizer_ckpts',
-        help="Path where tokenizer checkpoints are saved."
+        "--tokenizers_root",
+        type=str,
+        default="./tokenizer_ckpts",
+        help="Path where tokenizer checkpoints are saved.",
     )
     parser.add_argument(
-        "--data_root", type=str, default='/path/to/dataset',
-        help="Path to dataset root"
+        "--data_root", type=str, default="/path/to/dataset", help="Path to dataset root"
     )
+    parser.add_argument("--split", type=str, default="train", help="train or val")
     parser.add_argument(
-        "--split", type=str, default='train',
-        help="train or val"
-    )
-    parser.add_argument(
-        "--n_crops", type=int, default='1',
+        "--n_crops",
+        type=int,
+        default=1,
         help="Number of crops to save. If 1, only a center crop will be saved. \
-            If > 1, first image will be center cropped, the subsequent ones will be randomly cropped."
+            If > 1, first image will be center cropped, the subsequent ones will be randomly cropped.",
     )
     parser.add_argument(
-        "--min_crop_scale", type=float, default=0.8,
-        help="Minimum crop scale (Only for n_crops > 1)"
+        "--min_crop_scale",
+        type=float,
+        default=0.8,
+        help="Minimum crop scale (Only for n_crops > 1)",
     )
+    parser.add_argument("--input_size", type=int, default=224, help="Image size")
+    parser.add_argument("--task", type=str, default="rgb", help="Task name")
     parser.add_argument(
-        "--input_size", type=int, default=224,
-        help="Image size"
+        "--mask_value",
+        type=float,
+        default=None,
+        help="Optionally set masked-out regions to this value after data augs (default: %(default)s)",
     )
     parser.add_argument(
-        "--task", type=str, default='rgb',
-        help="Task name"
+        "--resample_mode",
+        type=str,
+        default=None,
+        help="PIL resample mode for resizing loaded images. One out of ['bilinear', 'bicubic', 'nearest', None]. (default: %(default)s)",
     )
     parser.add_argument(
-        "--mask_value", type=float, default=None,
-        help="Optionally set masked-out regions to this value after data augs (default: %(default)s)"
+        "--corrupt_samples_log",
+        type=str,
+        default=None,
+        help="Path to log file with corrupted samples from find_corrupted_pseudolabels.py. \
+            If provided, only corrupted samples will be re-tokenized.",
     )
     parser.add_argument(
-        "--resample_mode", type=str, default=None,
-        help="PIL resample mode for resizing loaded images. One out of ['bilinear', 'bicubic', 'nearest', None]. (default: %(default)s)"
+        "--verbose",
+        action="store_true",
+        default=False,
+        help="Set to enable progress bar",
     )
     parser.add_argument(
-        "--corrupt_samples_log", type=str, default=None,
-        help="Path to log file with corrupted samples from find_corrupted_pseudolabels.py. \
-            If provided, only corrupted samples will be re-tokenized."
+        "--dryrun",
+        action="store_true",
+        default=False,
+        help="Set to do a dry run that creates the tokens and prints the paths without saving them to disk.",
     )
     parser.add_argument(
-        "--verbose", action='store_true', default=False,
-        help="Set to enable progress bar"
+        "--device", default="cuda", help="Device to use for tokenization"
     )
+    parser.add_argument("--seed", default=0, type=int, help="Random seed")
     parser.add_argument(
-        "--dryrun", action='store_true', default=False,
-        help="Set to do a dry run that creates the tokens and prints the paths without saving them to disk."
+        "--folder_suffix",
+        type=str,
+        default="dvae_BUa_224",
+        help="Suffix to add to the folder under which the tokens are saved.",
     )
-    parser.add_argument('--device', default='cuda', help='Device to use for tokenization')
-    parser.add_argument('--seed', default=0, type=int, help='Random seed')
+    parser.add_argument("--num_workers", default=16, type=int)
     parser.add_argument(
-        "--folder_suffix", type=str,
-        default='dvae_BUa_224',
-        help="Suffix to add to the folder under which the tokens are saved."
+        "--pin_mem",
+        action="store_true",
+        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
     )
-    parser.add_argument('--num_workers', default=16, type=int)
-    parser.add_argument('--pin_mem', action='store_true',
-                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
-    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
-                        help='')
+    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem", help="")
     parser.set_defaults(pin_mem=True)
-    parser.add_argument('--batch_size_dataloader', default=64, type=int,
-                        help='Dataloader batch size (default: %(default)s)')
-    parser.add_argument('--batch_size', default=64, type=int,
-                        help='Batch size per GPU (default: %(default)s)')
+    parser.add_argument(
+        "--batch_size_dataloader",
+        default=64,
+        type=int,
+        help="Dataloader batch size (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--batch_size",
+        default=64,
+        type=int,
+        help="Batch size per GPU (default: %(default)s)",
    )
 
     # Distributed parameters
-    parser.add_argument('--world_size', default=1, type=int,
-                        help='number of distributed processes')
-    parser.add_argument('--local_rank', default=-1, type=int)
-    parser.add_argument('--dist_on_itp', action='store_true')
-    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
-
-    parser.add_argument('--force_load_crop', action='store_true',
-                        help='Make sure to load crops locally, otherwise break the code.')
+    parser.add_argument(
+        "--world_size", default=1, type=int, help="number of distributed processes"
+    )
+    parser.add_argument("--local_rank", default=-1, type=int)
+    parser.add_argument("--dist_on_itp", action="store_true")
+    parser.add_argument(
+        "--dist_url", default="env://", help="url used to set up distributed training"
+    )
+
+    parser.add_argument(
+        "--force_load_crop",
+        action="store_true",
+        help="Force loading existing crop settings from disk; fail if they are missing.",
+    )
 
     args = parser.parse_args()
     print("Force loading existing crop settings: {}".format(args.force_load_crop))
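Review note: a minimal sketch of the artifacts this script writes, handy for spot-checking that the rename did not change behavior. The example paths and the 14x14 token grid are assumptions (they depend on the chosen tokenizer and `--input_size`), not taken from the diff:

```
# Hypothetical reader for the outputs of pseudolabeling/save_vq_tokens.py.
# Assumes --n_crops 3 and a tokenizer that maps a 224x224 crop to a 14x14 grid.
import numpy as np

# Tokens land under <data_root>/<split>/<task>_<folder_suffix>/<class>/<file_id>.npy
tokens = np.load("train/rgb_dvae_BUa_224/class_x/vid_0001.npy")
assert tokens.dtype == np.int16 and tokens.shape == (3, 14 * 14)  # (n_crops, h*w)

# Matching crop settings: one (top, left, height, width, h_flip) row per crop,
# with row 0 the deterministic center crop and the rest random crops.
settings = np.load("train/crop_settings/class_x/vid_0001.npy")
assert settings.shape == (3, 5)
```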
"test"], [train_files, val_files, test_files] + ): + split_path = os.path.join(args.output_dir, dataset, dset_dir) + print(f"Move {dset_path} -----------> {split_path}") + os.makedirs(split_path, exist_ok=True) # Create class directory in split + for file in files: + if args.copy: + shutil.copy(os.path.join(dset_path, file), os.path.join(split_path, file)) + else: + shutil.move( + os.path.join(dset_path, file), os.path.join(split_path, file) + ) + + +if __name__ == "__main__": + """ + Given a source directory containing the data for multiple modalities, e.g., + + ``` + |--source_dir/ + | |--modality_a/ + | |--modality_b/ + | |--modality_c/ + ``` + + move the files into a specified output_dir/ with the structure: + ``` + |--source_dir/ + | |--train/ + | | |--modality_a/ + | | |--modality_b/ + | | |--modality_c/ + | |--val/ + | | |--modality_a/ + | | |--modality_b/ + | | |--modality_c/ + | |--test/ + | | |--modality_a/ + | | |--modality_b/ + | | |--modality_c/ + ``` + """ + parser = argparse.ArgumentParser( + description="Partition datasets into train, val, and test splits." + ) + parser.add_argument( + "--source_dir", + type=str, + required=True, + help="Path to the source directory containing dataset folders.", + ) + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to the output directory to store the splits. (--output_dir/dataset/split)", + ) + parser.add_argument( + "--train_ratio", + type=float, + default=0.7, + help="Ratio of data for the training set.", + ) + parser.add_argument( + "--test_ratio", + type=float, + default=0.2, + help="Ratio of data for the test set (remaining will be validation).", + ) + parser.add_argument( + "--shuffle", + type=bool, + default=False, + help="Whether to shuffle shards befores splitting. Otherwise, train is 0, 1, 2, etc.", + ) + parser.add_argument( + "--copy", + type=bool, + default=False, + help="Whether to copy the files instead of move. 
diff --git a/pseudolabeling/v2d_to_metadata.py b/pseudolabeling/v2d_to_metadata.py
new file mode 100644
index 0000000..16764fe
--- /dev/null
+++ b/pseudolabeling/v2d_to_metadata.py
@@ -0,0 +1,149 @@
+import argparse
+import json
+import os
+import shutil
+import tarfile
+import tempfile
+from tqdm import tqdm
+from datetime import timedelta
+
+# FIXME: may need adaptation
+METADATA_MAPPING = {
+    "webpage_url": "url",
+    "title": "title",
+    "duration": "duration",
+    "channel": "channel",
+    "fps": "fps",
+    "tags": "tags",
+    "resolution": "resolution",
+    "aspect_ratio": "aspect_ratio",
+}
+
+
+def process_tar_files(source_directory, target_directory, dataset, skip_existing=True):
+    """Extract, process, and re-package JSON files in TAR archives."""
+
+    os.makedirs(target_directory, exist_ok=True)
+
+    for filename in tqdm(os.listdir(source_directory)):
+        if filename.endswith(".tar"):
+            target_tar_path = os.path.join(target_directory, filename)
+            print(target_tar_path)
+
+            if skip_existing and os.path.exists(target_tar_path):
+                print(f"Skipping already processed file: {target_tar_path}")
+                continue
+
+            source_tar_path = os.path.join(source_directory, filename)
+            with tarfile.open(source_tar_path, "r") as tar:
+                temp_dir = tempfile.mkdtemp()
+                try:
+                    tar.extractall(path=temp_dir)
+
+                    # process json files
+                    for root, dirs, files in os.walk(temp_dir):
+                        for file in files:
+                            if file.endswith(".json"):
+                                process_json_file(
+                                    os.path.join(root, file), temp_dir, dataset
+                                )
+
+                    with tarfile.open(target_tar_path, "w") as out_tar:
+                        for root, dirs, files in os.walk(temp_dir):
+                            for file in files:
+                                if file.endswith(".json"):
+                                    out_tar.add(os.path.join(root, file), arcname=file)
+                finally:
+                    shutil.rmtree(temp_dir)
+
+
+def process_json_file(json_file_path, output_dir, dataset):
+    """Reads and processes a single JSON file to convert it to the required format."""
+    with open(json_file_path, "r", encoding="utf-8") as file:
+        data = json.load(file)
+    # remove the original json; a filtered replacement is written to output_dir below
+    os.remove(json_file_path)
+    video_key = os.path.splitext(os.path.basename(json_file_path))[0]
+
+    json_content = {}
+
+    if data["status"] != "success":
+        # errored while downloading
+        print(data["status"])
+        return
+    elif "subtitles" not in data["yt_meta_dict"]:
+        print("NO SUBTITLES: ", data)
+        # indeed, there are some videos without subtitles (no speech)
+        return
+    if (
+        data["yt_meta_dict"]["subtitles"].keys() != {"en"}
+        and len(data["yt_meta_dict"]["subtitles"].keys()) > 0
+    ):
+        # XXX: for now, only English-subtitled videos are supported; raise on anything else.
+        raise ValueError(
+            f"Non-English subtitles found: {data['yt_meta_dict']['subtitles'].keys()}"
+        )
+    for key, value in METADATA_MAPPING.items():
+        if value in data["yt_meta_dict"]["info"]:
+            json_content[key] = data["yt_meta_dict"]["info"][value]
+
+    json_content["dataset"] = dataset
+    json_filename = f"{video_key}.json"
+    with open(os.path.join(output_dir, json_filename), "w") as outfile:
+        json.dump(json_content, outfile, indent=4)
+
+
+def main(args):
+    if "filtered_raw" not in args.input_dir:
+        raise ValueError(f"Expected input dir to be a subdir of `filtered_raw/`, instead received {args.input_dir}.")
+
+    output_dir = (
+        args.output_dir
+        if args.output_dir is not None
+        else os.path.join(args.input_dir.replace("filtered_raw", "4m"), "video_metadata")
+    )
+    process_tar_files(
+        source_directory=args.input_dir,
+        target_directory=output_dir,
+        dataset=args.dataset,
+        skip_existing=args.skip_existing,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Process tarfiles from `filtered_raw` format containing JSONs and extract relevant metadata into the `video_metadata` modality."
+    )
+
+    parser.add_argument(
+        "-I",
+        "--input_dir",
+        type=str,
+        default="/store/swissai/a08/data/filtered_raw/howto100m/v2d_5000/",
+        # default="/cluster/work/cotterell/mm_swissai/raw/v2d_500/howto100m",
+        help="A `filtered_raw` dir containing the JSON files to process.",
+    )
+    parser.add_argument(
+        "-O",
+        "--output_dir",
+        type=str,
+        default=None,
+        help="Output dir to save the pseudolabeled metadata.",
+    )
+    parser.add_argument(
+        "-S",
+        "--skip_existing",
+        action="store_true",
+        help="Skip tarfiles already processed (exist in the target directory).",
+    )
+    # TODO: is this also in filestructure or do we have to provide it like this?
+    parser.add_argument(
+        "-D",
+        "--dataset",
+        type=str,
+        required=True,
+        help="Which dataset tar is coming from (HDVILA/HowTo100M)",
+    )
+
+    args = parser.parse_args()
+    main(args)
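Review note: a hedged walk-through of what `process_json_file` keeps, with invented field values (the real `yt_meta_dict["info"]` fields come from the video2dataset/yt-dlp output, and only a subset of `METADATA_MAPPING` is shown):

```
# Invented example: one <video_key>.json from a shard, after a successful download.
info = {"url": "https://youtu.be/xyz", "title": "Sample", "duration": 63, "fps": 30}

METADATA_MAPPING = {"webpage_url": "url", "title": "title", "duration": "duration", "fps": "fps"}
json_content = {k: info[v] for k, v in METADATA_MAPPING.items() if v in info}
json_content["dataset"] = "howto100m"
# -> {'webpage_url': 'https://youtu.be/xyz', 'title': 'Sample',
#     'duration': 63, 'fps': 30, 'dataset': 'howto100m'}
```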
diff --git a/pseudolabeling/v2d_to_transcript.py b/pseudolabeling/v2d_to_transcript.py
new file mode 100644
index 0000000..f24d53b
--- /dev/null
+++ b/pseudolabeling/v2d_to_transcript.py
@@ -0,0 +1,157 @@
+import argparse
+import json
+import os
+import shutil
+import tarfile
+import tempfile
+from tqdm import tqdm
+from datetime import timedelta
+
+
+def timestamp_to_frames(timestamp, fps):
+    """Converts a timestamp in seconds (e.g. '12.34') into a frame count."""
+    total_seconds = float(timestamp)
+    # TODO: right-exclusive, left-inclusive.
+    return round(total_seconds * fps)
+
+
+def process_tar_files(source_directory, target_directory, skip_existing=True):
+    """Extract, process, and re-package JSON files in TAR archives."""
+    os.makedirs(target_directory, exist_ok=True)
+
+    for filename in tqdm(os.listdir(source_directory)):
+        if filename.endswith(".tar"):
+            target_tar_path = os.path.join(target_directory, filename)
+            print(target_tar_path)
+
+            if skip_existing and os.path.exists(target_tar_path):
+                print(f"Skipping already processed file: {target_tar_path}")
+                continue
+
+            source_tar_path = os.path.join(source_directory, filename)
+            with tarfile.open(source_tar_path, "r") as tar:
+                temp_dir = tempfile.mkdtemp()
+                try:
+                    tar.extractall(path=temp_dir)
+
+                    # process json files
+                    for root, dirs, files in os.walk(temp_dir):
+                        for file in files:
+                            if file.endswith(".json"):
+                                process_json_file(os.path.join(root, file), temp_dir)
+
+                    with tarfile.open(target_tar_path, "w") as out_tar:
+                        for root, dirs, files in os.walk(temp_dir):
+                            for file in files:
+                                if file.endswith(".jsonl"):
+                                    out_tar.add(os.path.join(root, file), arcname=file)
+                finally:
+                    shutil.rmtree(temp_dir)
+
+
+def process_json_file(json_file_path, output_dir):
+    """Reads and processes a single JSON file to convert it to the required format."""
+    with open(json_file_path, "r", encoding="utf-8") as file:
+        data = json.load(file)
+    video_key = os.path.splitext(os.path.basename(json_file_path))[0]
+
+    if data["status"] != "success":
+        # errored while downloading
+        return
+    elif "subtitles" not in data["yt_meta_dict"]:
+        print(data)
+        # TODO: what to do with videos that have no subtitles? When can this occur?
+        return
+    if data["yt_meta_dict"]["subtitles"].keys() != {"en"}:
+        # XXX: for now, we decided to only exclude non-English videos
+        return
+    subtitles = data["whisper_alignment"]["segments"]
+    fps = data["yt_meta_dict"]["info"]["fps"]
+
+    json_content = []
+    for subtitle in subtitles:
+        start_frame = timestamp_to_frames(subtitle["start"], fps)
+        end_frame = timestamp_to_frames(subtitle["end"], fps)
+        sentence = subtitle["text"]
+        word_timestamps = []
+        for word in subtitle["words"]:
+            word_timestamps.append(
+                {
+                    "word": word["word"],
+                    "start": timestamp_to_frames(word["start"], fps)
+                    if "start" in word.keys()
+                    else None,
+                    "end": timestamp_to_frames(word["end"], fps)
+                    if "end" in word.keys()
+                    else None,
+                }
+            )
+
+        json_content.append(
+            {
+                "sentence": sentence,
+                "start": start_frame,
+                "end": end_frame,
+                "words": word_timestamps,
+            }
+        )
+
+    jsonl_filename = f"{video_key}.jsonl"
+    # NOTE: despite the .jsonl extension, this writes one indented JSON array per video.
+    with open(os.path.join(output_dir, jsonl_filename), "w") as outfile:
+        json.dump(json_content, outfile, indent=4)
+
+
+def main(args):
+    if "filtered_raw" not in args.input_dir:
+        raise ValueError(f"Expected input dir to be a subdir of `filtered_raw/`, instead received {args.input_dir}.")
+
+    current_folder = os.path.join(args.input_dir, args.whisper_dir)
+    output_dir = (
+        args.output_dir
+        if args.output_dir is not None
+        else os.path.join(args.input_dir.replace("filtered_raw", "4m"), "video_transcript")
+    )
+    print(f"Processing {current_folder}.")
+    process_tar_files(
+        source_directory=current_folder,
+        target_directory=output_dir,
+        skip_existing=args.skip_existing,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Process tarfiles containing JSONs and convert to structured JSONL format."
+ ) + parser.add_argument( + "-I", + "--input_dir", + type=str, + default="/store/swissai/a08/data/filtered_raw/howto100m/v2d_5000/", + # default="/cluster/work/cotterell/mm_swissai/raw/v2d_500/howto100m", + help="A `filtered_raw` dir containing the JSON files to process.", + ) + parser.add_argument( + "-O", + "--output_dir", + type=str, + default=None, + help="Output dir to save the pseudolabeled transcripts.", + ) + parser.add_argument( + "-W", + "--whisper_dir", + type=str, + default="whisperx", + help="Dir containing the WhisperX transcripts.", + ) + parser.add_argument( + "-S", + "--skip_existing", + default=False, # FIXME + help="Skip tarfiles already processed (exist in the target directory).", + ) + + args = parser.parse_args() + main(args)