diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 7b087b93..4ef56e41 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -31,7 +31,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.6, 3.7, 3.8]
+        python-version: [3.8]
     steps:
     - uses: actions/checkout@v2
diff --git a/.gitignore b/.gitignore
index 5728acb7..146bae56 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ __pycache__
 .env*
 wandb
 *.pex
-.pexing
\ No newline at end of file
+.pexing
+**/dataset/*
\ No newline at end of file
diff --git a/README.md b/README.md
index 56130f69..4072dc6c 100644
--- a/README.md
+++ b/README.md
@@ -11,20 +11,149 @@ Checkout the [design doc](https://docs.google.com/document/d/1_TD2KQLkEegszq4Eip
 
 pip install video2dataset
 
+## Examples
+
+
+## Usage
+
+First, get a list of video urls. For example:
+```
+echo 'https://www.youtube.com/watch?v=0WfKzVqdQqo' >> myvidlist.txt
+```
+
+Then, run the tool:
+
+```
+video2dataset --url_list=myvidlist.txt --output_folder=output_folder
+```
+
+The tool will then automatically download the urls and store them in the following format:
+* output_folder
+    * 00000
+        * 000000000.mp4
+        * 000000001.mp4
+        * 000000002.mp4
+
+or in this format if choosing webdataset:
+* output_folder
+    * 00000.tar containing:
+        * 000000000.mp4
+        * 000000001.mp4
+        * 000000002.mp4
+
+with each number being the position in the list. The subfolders avoid having too many files in a single folder.
+
+If **captions** are provided, they will be saved as 0.txt, 1.txt, ...
+
+This can then easily be fed into machine learning training or any other use case.
+
+Also .json files named 0.json, 1.json, ... are saved with these keys:
+* url
+* caption
+* key of the form 000010005: the first 5 digits are the shard id, the last 4 are the index in the shard
+* status: whether the download succeeded
+* error_message
+
+Also a .parquet file will be saved with the same name as the subfolder/tar files, containing the same metadata.
+It can be used to analyze the results efficiently.
+
+.json files will also be saved with the same name suffixed by _stats; they contain stats collected during downloading (download time, number of successes, ...)
+
 ## Python examples
 
-Checkout these examples to call this as a lib:
-* [example.py](examples/example.py)
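+You can also call video2dataset directly from Python; the function is a thin wrapper over the same options as the command line tool. A minimal sketch (the argument values below are only illustrative):
+
+```
+from video2dataset import video2dataset
+
+if __name__ == "__main__":
+    video2dataset(
+        url_list="myvidlist.txt",
+        input_format="txt",
+        output_folder="output_folder",
+        output_format="webdataset",
+        processes_count=4,
+        number_sample_per_shard=1000,
+        enable_wandb=False,
+    )
+```
+
+The `if __name__ == "__main__":` guard matters here because the default multiprocessing distributor spawns worker processes.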
 
 ## API
 
-This module exposes a single function `hello_world` which takes the same arguments as the command line tool:
+This module exposes a single function `video2dataset` which takes the same arguments as the command line tool:
+
+* **url_list** A file with the list of urls of videos to download. It can be a folder of such files. (*required*)
+* **output_folder** The path to the output folder. (default *"videos"*)
+* **processes_count** The number of processes used for downloading the videos. This is important to be high for performance. (default *1*)
+* **encode_format** encode format (default *mp4*)
+* **output_format** decides how to save the videos (default *files*)
+    * **files** saves as a set of subfolders containing the videos
+    * **webdataset** saves as tars containing the videos
+    * **parquet** saves as parquet files containing the videos as bytes
+    * **tfrecord** saves as tfrecord files containing the videos as bytes
+    * **dummy** does not save. Useful for benchmarks
+* **input_format** decides how to load the urls (default *txt*)
+    * **txt** loads the urls as a text file of urls, one per line
+    * **csv** loads the urls and optional caption as a csv
+    * **tsv** loads the urls and optional caption as a tsv
+    * **tsv.gz** loads the urls and optional caption as a compressed (gzip) tsv.gz
+    * **json** loads the urls and optional caption as a json
+    * **parquet** loads the urls and optional caption as a parquet
+* **url_col** the name of the url column for parquet and csv (default *url*)
+* **caption_col** the name of the caption column for parquet and csv (default *None*)
+* **number_sample_per_shard** the number of samples that will be downloaded in one shard (default *10000*)
+* **save_additional_columns** list of additional columns to take from the csv/parquet files and save in metadata files (default *None*)
+* **timeout** maximum time (in seconds) to wait when trying to download a video (default *60*)
+* **enable_wandb** whether to enable wandb logging (default *False*)
+* **wandb_project** name of W&B project used (default *video2dataset*)
+* **oom_shard_count** the order of magnitude of the number of shards, used only to decide what zero padding to use to name the shard files (default *5*)
+* **distributor** choose how to distribute the downloading (default *multiprocessing*)
+    * **multiprocessing** use a multiprocessing pool to spawn processes
+    * **pyspark** use a pyspark session to create workers on a spark cluster (see details below)
+* **subjob_size** the number of shards to download in each subjob, when the distributor supports subjobs; a subjob can be a pyspark job for example (default *1000*)
+* **incremental_mode** Can be "incremental" or "overwrite". For "incremental", video2dataset will download all the shards that were not downloaded yet; for "overwrite", video2dataset will recursively delete the output folder and start from zero (default *incremental*)
+
+## Incremental mode
+
+If a first download got interrupted for any reason, you can run again with --incremental_mode "incremental" (this is the default), using the same output folder, the same number_sample_per_shard and the same input urls, and video2dataset will complete the download.
+
+## Output format choice
+
+video2dataset supports several formats. There are trade-offs for each choice:
+* files: this is the simplest one, videos are simply saved as files. It's good for up to 1M samples on a local file system. Beyond that performance issues appear very fast. Handling more than a million files in a standard filesystem does not work well.
+* webdataset: webdataset format saves samples in tar files, and thanks to the [webdataset](https://webdataset.github.io/webdataset/) library the resulting dataset can be loaded fast in pytorch, tensorflow and jax (see the reading example below this list). Choose this for most use cases. It works well on any filesystem.
+* parquet: parquet is a columnar format that allows fast filtering. It's particularly easy to read using pyarrow and pyspark. Choose this if the rest of your data ecosystem is based on pyspark. [petastorm](https://github.com/uber/petastorm) can be used to read the data but it's not as easy to use as webdataset.
+* tfrecord: tfrecord is a protobuf based format. It's particularly easy to use from tensorflow with [tf data](https://www.tensorflow.org/guide/data). Use this if you plan to use the dataset only in the tensorflow ecosystem. The tensorflow writer does not use fsspec and as a consequence supports only a limited set of filesystems: local, hdfs, s3 and gcs. It is also less efficient than the webdataset writer when writing to filesystems other than local, losing some 30% of performance.
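+
+For example, a dataset written with `--output_format webdataset` can be read back with the [webdataset](https://webdataset.github.io/webdataset/) library. A minimal sketch (the shard range is illustrative, and the `txt` entry is only present when a caption column was provided):
+
+```
+import json
+
+import webdataset as wds
+
+# 00000.tar, 00001.tar, ... are the shards produced by video2dataset
+dataset = wds.WebDataset("output_folder/{00000..00004}.tar")
+
+for sample in dataset:
+    video_bytes = sample["mp4"]        # raw video bytes
+    meta = json.loads(sample["json"])  # url, caption, key, status, error_message, ...
+    caption = sample.get("txt", b"").decode()
+    break
+```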
+
+## File system support
+
+Thanks to [fsspec](https://filesystem-spec.readthedocs.io/en/latest/), video2dataset supports reading and writing files in [many file systems](https://github.com/fsspec/filesystem_spec/blob/6233f315548b512ec379323f762b70764efeb92c/fsspec/registry.py#L87).
+To use it, simply use the prefix of your filesystem before the path. For example `hdfs://`, `s3://`, `http://`, or `gcs://`.
+Some of these file systems require installing an additional package (for example s3fs for s3, gcsfs for gcs).
+See the fsspec docs for all the details.
+
+If you need specific configuration for your filesystem, you can use the [fsspec configuration system](https://filesystem-spec.readthedocs.io/en/latest/features.html#configuration), which makes it possible to create a file such as `.config/fsspec/s3.json` containing information like:
+```
+{
+    "s3": {
+        "client_kwargs": {
+            "endpoint_url": "https://some_endpoint",
+            "aws_access_key_id": "your_user",
+            "aws_secret_access_key": "your_password"
+        }
+    }
+}
+```
+This may be necessary when using s3-compatible file systems such as [minio](https://min.io/). That kind of configuration also works for all other fsspec-supported file systems.
+
+## Distribution modes
+
+video2dataset supports several distributors.
+* multiprocessing which spawns a process pool and uses these local processes for downloading
+* pyspark which spawns workers in a spark pool to do the downloading
+
+multiprocessing is a good option for downloading on one machine, and as such it is the default.
+Pyspark lets video2dataset use many nodes, so download speed scales with the number of machines.
+It can be particularly useful when downloading datasets with more than a billion videos.
+
+### pyspark configuration
+
+In order to use video2dataset with pyspark, you will need to:
+1. `pip install pyspark`
+2. use the `--distributor pyspark` option
+3. tweak the `--subjob_size 1000` option: this is the number of shards to process in each subjob. Increasing it means more preparation time to put the feather files in the temporary dir; decreasing it means sending fewer shards at a time to the pyspark job.
+
+By default a local spark session will be created.
+You may want to create a custom spark session depending on your specific spark cluster, as sketched in the example below.
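+
+For example, with pyspark 3, video2dataset will pick up an already active session instead of creating a local one. A minimal sketch (the master url, memory setting and paths are placeholders to adapt to your cluster):
+
+```
+from pyspark.sql import SparkSession
+
+from video2dataset import video2dataset
+
+spark = (
+    SparkSession.builder
+    .config("spark.executor.memory", "16G")  # placeholder memory setting
+    .master("spark://your-master-node:7077")  # placeholder master url
+    .appName("video2dataset")
+    .getOrCreate()
+)
+
+video2dataset(
+    url_list="myvidlist.txt",
+    output_folder="s3://your-bucket/video-dataset",  # placeholder output location
+    output_format="webdataset",
+    distributor="pyspark",
+)
+```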
 
-* **message** the message to print. (*required*)
 
 ## For development
 
-Either locally, or in [gitpod](https://gitpod.io/#https://github.com/rom1504/video2dataset) (do `export PIP_USER=false` there)
+Either locally, or in [gitpod](https://gitpod.io/#https://github.com/iejMac/video2dataset) (do `export PIP_USER=false` there)
 
 Setup a virtualenv:
 
@@ -47,3 +176,6 @@ make test
 You can use `make black` to reformat the code
 
 `python -m pytest -x -s -v tests -k "dummy"` to run a specific test
+
+## Benchmarks
+
diff --git a/benchmark/run_benchmark.sh b/benchmark/run_benchmark.sh
index 6902c1f8..6191872e 100755
--- a/benchmark/run_benchmark.sh
+++ b/benchmark/run_benchmark.sh
@@ -1,4 +1,12 @@
 #!/bin/bash
 
-mkdir -p dataset
-video2dataset benchmark_vids.parquet --dest="dataset" --output-format="files" --metadata-columns="videoID,title,description,start,end"
+rm -rf dataset
+
+video2dataset --url_list="benchmark_vids.parquet" \
+--input_format="parquet" \
+--output_folder="dataset" \
+--output_format="files" \
+--url_col="videoLoc" \
+--caption_col="title" \
+--save_additional_columns='[videoID,description,start,end]' \
+--enable_wandb=True
diff --git a/requirements-test.txt b/requirements-test.txt
index fac95eb8..2b424ffb 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -4,3 +4,4 @@ pylint==2.13.4
 pytest-cov==3.0.0
 pytest-xdist==2.5.0
 pytest==7.0.1
+types-requests
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index fa629401..39574d21 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,7 @@ ffmpeg-python
 yt_dlp
 pyarrow
 fsspec
-webdataset
\ No newline at end of file
+webdataset
+wandb
+pandas
+tqdm
\ No newline at end of file
diff --git a/setup.py b/setup.py
index a4e4d6a4..591242ae 100644
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,7 @@ def _read_reqs(relpath):
         description="Easily create large video dataset from video urls",
         long_description=long_description,
         long_description_content_type="text/markdown",
-        entry_points={"console_scripts": ["video2dataset=video2dataset.cli:main"]},
+        entry_points={"console_scripts": ["video2dataset=video2dataset.main:main"]},
         author="Maciej Kilian",
         author_email="kilianmaciej6@gmail.com",
         url="https://github.com/iejMac/video2dataset",
diff --git a/tests/test_main.py b/tests/test_main.py
index 5871ed8e..7bda61da 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1 +1,2 @@
-import pytest
+def test_hello_world():
+    print("hi")
diff --git a/video2dataset/__init__.py b/video2dataset/__init__.py
index 0b39a057..cf7f9b9b 100644
--- a/video2dataset/__init__.py
+++ b/video2dataset/__init__.py
@@ -1,3 +1,3 @@
 """video2dataset"""
 
-from .video2dataset import video2dataset
+from .main import video2dataset
diff --git a/video2dataset/cli.py b/video2dataset/cli.py
deleted file mode 100644
index c1ee0be8..00000000
--- a/video2dataset/cli.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""cli entry point"""
-
-import fire
-
-from video2dataset import video2dataset
-
-
-def main():
-    """Main entry point"""
-    fire.Fire(video2dataset)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/video2dataset/downloader.py b/video2dataset/data_reader.py
similarity index 77%
rename from video2dataset/downloader.py
rename to video2dataset/data_reader.py
index 45be3f31..fd66624d 100644
--- a/video2dataset/downloader.py
+++ b/video2dataset/data_reader.py
@@ -54,10 +54,21 @@ def handle_url(url):
     else:
         print("Warning: Incorrect URL type")
         return None, None, ""
-
+
     return file.name, file, name
 
-class Downloader:
-    def __init__(self):
-        pass
+class VideoDataReader:
+    """Video data reader 
provide data for a video""" + + def __init__(self) -> None: + pass + + def __call__(self, row, timeout, retries): + key, url = row + file_name, file, _ = handle_url(url) + with open(file_name, "rb") as vid_file: + vid_bytes = vid_file.read() + if file is not None: # for python files that need to be closed + file.close() + return key, vid_bytes, None diff --git a/video2dataset/data_writer.py b/video2dataset/data_writer.py new file mode 100644 index 00000000..6f8c6436 --- /dev/null +++ b/video2dataset/data_writer.py @@ -0,0 +1,307 @@ +""""writer module handle writing the images to disk""" + +import json +import os + +import fsspec +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import webdataset as wds + + +class BufferedParquetWriter: + """Write samples to parquet files incrementally with a buffer""" + + def __init__(self, output_file, schema, buffer_size=100): + self.buffer_size = buffer_size + self.schema = schema + self._initiatlize_buffer() + fs, output_path = fsspec.core.url_to_fs(output_file) + + self.output_fd = fs.open(output_path, "wb") + self.parquet_writer = pq.ParquetWriter(self.output_fd, schema) + + def _initiatlize_buffer(self): + self.current_buffer_size = 0 + self.buffer = {k: [] for k in self.schema.names} + + def _add_sample_to_buffer(self, sample): + for k in self.schema.names: + self.buffer[k].append(sample[k]) + self.current_buffer_size += 1 + + def write(self, sample): + if self.current_buffer_size >= self.buffer_size: + self.flush() + self._add_sample_to_buffer(sample) + + def flush(self): + """Write the buffer to disk""" + if self.current_buffer_size == 0: + return + + df = pa.Table.from_pydict(self.buffer, self.schema) + self.parquet_writer.write_table(df) + self._initiatlize_buffer() + + def close(self): + self.flush() + if self.parquet_writer is not None: + self.parquet_writer.close() + self.parquet_writer = None + self.output_fd.close() + + +class ParquetSampleWriter: + """ParquetSampleWriter is a image+caption writer to parquet""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_format, + ): + self.oom_shard_count = oom_shard_count + self.encode_format = encode_format + schema = schema.append(pa.field(encode_format, pa.binary())) + shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + output_file = f"{output_folder}/{shard_name}.parquet" + self.buffered_parquet_writer = BufferedParquetWriter(output_file, schema, 100) + self.save_caption = save_caption + + def write(self, img_str, key, caption, meta): + """Keep sample in memory then write to disk when close() is called""" + if img_str is not None: + sample = {"key": key, self.encode_format: img_str} + if self.save_caption: + sample["txt"] = str(caption) if caption is not None else "" + else: + sample = {"key": key, self.encode_format: None} + if self.save_caption: + sample["txt"] = None + sample.update(meta) + self.buffered_parquet_writer.write(sample) + + def close(self): + self.buffered_parquet_writer.close() + + +class WebDatasetSampleWriter: + """WebDatasetSampleWriter is a image+caption writer to webdataset""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_format, + ): + self.oom_shard_count = oom_shard_count + shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + 
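+        # shard_name is the zero-padded shard id (e.g. "00042" when oom_shard_count is 5); it names both the .tar and .parquet outputs of this shard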
self.shard_id = shard_id + fs, output_path = fsspec.core.url_to_fs(output_folder) + self.tar_fd = fs.open(f"{output_path}/{shard_name}.tar", "wb") + self.tarwriter = wds.TarWriter(self.tar_fd) + self.save_caption = save_caption + self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) + self.encode_format = encode_format + + def write(self, img_str, key, caption, meta): + """write sample to tars""" + if img_str is not None: + sample = {"__key__": key, self.encode_format: img_str} + if self.save_caption: + sample["txt"] = str(caption) if caption is not None else "" + # some meta data may not be JSON serializable + for k, v in meta.items(): + if isinstance(v, np.ndarray): + meta[k] = v.tolist() + sample["json"] = json.dumps(meta, indent=4) + self.tarwriter.write(sample) + self.buffered_parquet_writer.write(meta) + + def close(self): + self.buffered_parquet_writer.close() + self.tarwriter.close() + self.tar_fd.close() + + +class TFRecordSampleWriter: + """TFRecordSampleWriter is a image+caption writer to TFRecord""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_format, + ): + try: + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + import tensorflow_io as _ # pylint: disable=import-outside-toplevel + from tensorflow.python.lib.io.tf_record import TFRecordWriter # pylint: disable=import-outside-toplevel + from tensorflow.python.training.training import ( # pylint: disable=import-outside-toplevel + BytesList, + Example, + Feature, + Features, + FloatList, + Int64List, + ) + + self._BytesList = BytesList # pylint: disable=invalid-name + self._Int64List = Int64List # pylint: disable=invalid-name + self._FloatList = FloatList # pylint: disable=invalid-name + self._Example = Example # pylint: disable=invalid-name + self._Features = Features # pylint: disable=invalid-name + self._Feature = Feature # pylint: disable=invalid-name + except ImportError as e: + raise ModuleNotFoundError( + "tfrecords require tensorflow and tensorflow_io to be installed." + "Run `pip install tensorflow tensorflow_io`." 
+ ) from e + + self.oom_shard_count = oom_shard_count + shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + self.shard_id = shard_id + self.tf_writer = TFRecordWriter(f"{output_folder}/{shard_name}.tfrecord") + self.save_caption = save_caption + self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) + self.encode_format = encode_format + + def write(self, img_str, key, caption, meta): + """Write a sample using tfrecord writer""" + if img_str is not None: + sample = { + "key": self._bytes_feature(key.encode()), + self.encode_format: self._bytes_feature(img_str), + } + if self.save_caption: + sample["txt"] = self._bytes_feature(str(caption) if caption is not None else "") + for k, v in meta.items(): + sample[k] = self._feature(v) + tf_example = self._Example(features=self._Features(feature=sample)) + self.tf_writer.write(tf_example.SerializeToString()) + self.buffered_parquet_writer.write(meta) + + def close(self): + self.buffered_parquet_writer.close() + self.tf_writer.close() + + def _feature(self, value): + """Convert to proper feature type""" + if isinstance(value, list): + return self._list_feature(value) + elif isinstance(value, int): + return self._int64_feature(value) + elif isinstance(value, float): + return self._float_feature(value) + else: + return self._bytes_feature(value) + + def _bytes_feature(self, value): + """Returns a bytes_list from a string / byte.""" + if value is None: + value = "" + if isinstance(value, str): + value = value.encode() + return self._Feature(bytes_list=self._BytesList(value=[value])) + + def _float_feature(self, value): + """Returns a float_list from a float / double.""" + return self._Feature(float_list=self._FloatList(value=[value])) + + def _int64_feature(self, value): + """Returns an int64_list from a bool / enum / int / uint.""" + return self._Feature(int64_list=self._Int64List(value=[value])) + + def _list_feature(self, value): + """Returns an list of int64_list, float_list, bytes_list.""" + if isinstance(value[0], int): + return self._Feature(int64_list=self._Int64List(value=value)) + elif isinstance(value[0], float): + return self._Feature(float_list=self._FloatList(value=value)) + else: + for i, bytes_feature in enumerate(value): + if bytes_feature is None: + value[i] = "" + if isinstance(bytes_feature, str): + value[i] = bytes_feature.encode() + return self._Feature(bytes_list=self._BytesList(value=value)) + + +class FilesSampleWriter: + """FilesSampleWriter is a caption+image writer to files""" + + def __init__( + self, + shard_id, + output_folder, + save_caption, + oom_shard_count, + schema, + encode_format, + ): + self.oom_shard_count = oom_shard_count + shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + self.shard_id = shard_id + self.fs, self.subfolder = fsspec.core.url_to_fs(f"{output_folder}/{shard_name}") + if not self.fs.exists(self.subfolder): + self.fs.mkdir(self.subfolder) + self.save_caption = save_caption + self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) + self.encode_format = encode_format + + def write(self, img_str, key, caption, meta): + """Write sample to disk""" + if img_str is not None: + filename = f"{self.subfolder}/{key}.{self.encode_format}" + with self.fs.open(filename, "wb") as f: + 
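+                # despite its name, img_str holds the encoded video bytes (mp4 by default) for this sample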
f.write(img_str) + if self.save_caption: + caption = str(caption) if caption is not None else "" + caption_filename = f"{self.subfolder}/{key}.txt" + with self.fs.open(caption_filename, "w") as f: + f.write(str(caption)) + + # some meta data may not be JSON serializable + for k, v in meta.items(): + if isinstance(v, np.ndarray): + meta[k] = v.tolist() + j = json.dumps(meta, indent=4) + meta_filename = f"{self.subfolder}/{key}.json" + with self.fs.open(meta_filename, "w") as f: + f.write(j) + self.buffered_parquet_writer.write(meta) + + def close(self): + self.buffered_parquet_writer.close() + + +class DummySampleWriter: + """Does not write""" + + def __init__(self, shard_id, output_folder, save_caption, oom_shard_count, schema, encode_format): + pass + + def write(self, img_str, key, caption, meta): + pass + + def close(self): + pass diff --git a/video2dataset/distributor.py b/video2dataset/distributor.py new file mode 100644 index 00000000..0153577a --- /dev/null +++ b/video2dataset/distributor.py @@ -0,0 +1,98 @@ +"""distributor defines the distribution strategies for img2dataset""" + +from contextlib import contextmanager +from multiprocessing import get_context +from itertools import islice, chain + +from tqdm import tqdm + + +def retrier(runf, failed_shards, max_shard_retry): + # retry failed shards max_shard_retry times + for i in range(max_shard_retry): + if len(failed_shards) == 0: + break + print(f"Retrying {len(failed_shards)} shards, try {i+1}") + failed_shards = runf(failed_shards) + if len(failed_shards) != 0: + print( + f"Retried {max_shard_retry} times, but {len(failed_shards)} shards " + "still failed. You may restart the same command to retry again." + ) + + +def multiprocessing_distributor(processes_count, worker, input_sharder, _, max_shard_retry): + """Distribute the work to the processes using multiprocessing""" + ctx = get_context("spawn") + with ctx.Pool(processes_count, maxtasksperchild=5) as process_pool: + + def run(gen): + failed_shards = [] + for (status, row) in tqdm(process_pool.imap_unordered(worker, gen)): + if status is False: + failed_shards.append(row) + return failed_shards + + failed_shards = run(input_sharder) + + retrier(run, failed_shards, max_shard_retry) + + process_pool.terminate() + process_pool.join() + del process_pool + + +def pyspark_distributor(processes_count, worker, input_sharder, subjob_size, max_shard_retry): + """Distribute the work to the processes using pyspark""" + + with _spark_session(processes_count) as spark: + + def batcher(iterable, batch_size): + iterator = iter(iterable) + for first in iterator: + yield list(chain([first], islice(iterator, batch_size - 1))) + + def run(gen): + failed_shards = [] + for batch in batcher(gen, subjob_size): + rdd = spark.sparkContext.parallelize(batch, len(batch)) + for (status, row) in rdd.map(worker).collect(): + if status is False: + failed_shards.append(row) + return failed_shards + + failed_shards = run(input_sharder) + + retrier(run, failed_shards, max_shard_retry) + + +@contextmanager +def _spark_session(processes_count: int): + """Create and close a spark session if none exist""" + + from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel + import pyspark # pylint: disable=import-outside-toplevel + + spark_major_version = int(pyspark.version.__version__[0]) + if spark_major_version >= 3: + spark = SparkSession.getActiveSession() + else: + spark = pyspark.sql.SparkSession._instantiatedSession # pylint: disable=protected-access + + if spark is None: + print("No 
pyspark session found, creating a new one!") + owned = True + spark = ( + SparkSession.builder.config("spark.driver.memory", "16G") + .master("local[" + str(processes_count) + "]") + .appName("spark-stats") + .getOrCreate() + ) + else: + owned = False + + try: + yield spark + finally: + if owned: + spark.stop() diff --git a/video2dataset/input_sharder.py b/video2dataset/input_sharder.py new file mode 100644 index 00000000..7d5a28f6 --- /dev/null +++ b/video2dataset/input_sharder.py @@ -0,0 +1,179 @@ +"""Reader is module to read the url list and return shards""" + +from multiprocessing.pool import ThreadPool +import math +import fsspec +import time +import pyarrow.parquet as pq +import pyarrow.csv as csv_pq +import pyarrow as pa +import pandas as pd + + +class InputSharder: + """ + The reader class reads an url list and returns shards + It provides an iter method + It provides attributes: + - column_list: the list of columns to read + - input_format: the format of the input file + - url_col: the column name of the url + - caption_col: the column name of the caption + - save_additional_columns: the list of additional columns to save + - number_sample_per_shard: the number of samples per shard + - done_shards: a set of already done shards + """ + + def __init__( + self, + url_list, + input_format, + url_col, + caption_col, + save_additional_columns, + number_sample_per_shard, + done_shards, + tmp_path, + ) -> None: + self.input_format = input_format + self.url_col = url_col + self.caption_col = caption_col + self.save_additional_columns = save_additional_columns + self.number_sample_per_shard = number_sample_per_shard + self.done_shards = done_shards + + fs, url_path = fsspec.core.url_to_fs(url_list) + self.fs = fs + self.tmp_path = tmp_path + + if fs.isdir(url_path): + self.input_files = sorted(fs.glob(url_path + "/*." 
+ input_format)) + if len(self.input_files) == 0: + raise Exception(f"No file found at path {url_path} with extension {input_format}") + else: + self.input_files = [url_path] + + if self.input_format == "txt": + self.column_list = ["url"] + elif self.input_format in ["json", "csv", "tsv", "tsv.gz", "parquet"]: + self.column_list = self.save_additional_columns if self.save_additional_columns is not None else [] + if self.caption_col is not None: + self.column_list = self.column_list + ["caption", "url"] + else: + self.column_list = self.column_list + ["url"] + else: + raise ValueError(f"Invalid input format {self.input_format}") + + def _save_to_arrow(self, input_file, start_shard_id): + """Read the input file and save to arrow files in a temporary directory""" + if self.input_format in ["txt", "json", "csv", "tsv"]: + with self.fs.open(input_file, mode="rb") as file: + if self.input_format == "txt": + df = csv_pq.read_csv(file, read_options=csv_pq.ReadOptions(column_names=["url"])) + elif self.input_format == "json": + df = pa.Table.from_pandas(pd.read_json(file)) + elif self.input_format == "csv": + df = csv_pq.read_csv(file) + elif self.input_format == "tsv": + df = csv_pq.read_csv(file, parse_options=csv_pq.ParseOptions(delimiter="\t")) + else: + raise ValueError(f"Unknown input format {self.input_format}") + elif self.input_format == "tsv.gz": + with self.fs.open(input_file, encoding="utf-8", mode="rb", compression="gzip") as file: + df = csv_pq.read_csv(file, parse_options=csv_pq.ParseOptions(delimiter="\t")) + elif self.input_format == "parquet": + with self.fs.open(input_file, mode="rb") as file: + columns_to_read = [self.url_col] + if self.caption_col is not None: + columns_to_read += [self.caption_col] + if self.save_additional_columns is not None: + columns_to_read += self.save_additional_columns + df = pq.read_table(file, columns=columns_to_read) + else: + raise ValueError(f"Unknown input format {self.input_format}") + + column_names = df.column_names + if self.caption_col is not None: + column_names = [c if c != self.caption_col else "caption" for c in column_names] + column_names = [c if c != self.url_col else "url" for c in column_names] + + df = df.rename_columns(column_names) + + number_samples = df.num_rows + + number_shards = math.ceil(df.num_rows / self.number_sample_per_shard) + shards_to_write = [ + (start_shard_id + shard_id, shard_id) + for shard_id in range(number_shards) + if start_shard_id + shard_id not in self.done_shards + ] + if len(shards_to_write) == 0: + return [], number_shards + + def write_shard(t): + full_shard_id, shard_id = t + begin_shard = shard_id * self.number_sample_per_shard + end_shard = min(number_samples, (1 + shard_id) * self.number_sample_per_shard) + df_shard = df.slice(begin_shard, end_shard - begin_shard).select(self.column_list) + tmp_file = self.tmp_path + f"/{full_shard_id}.feather" + for i in range(10): + try: + fs, tmp_path = fsspec.core.url_to_fs(tmp_file) + with fs.open(tmp_path, "wb") as file: + with pa.ipc.new_file(file, df_shard.schema) as writer: + writer.write_table(df_shard) + return (full_shard_id, tmp_file) + except Exception as e: # pylint: disable=broad-except + if i != 9: + print("retrying to write to file due to error:", e) + time.sleep(1) + else: + raise e + # can't reach here + raise Exception("Failed to write to file.") + + for i in range(10): + shards = [] + # thread pool to make it faster to write files to low latency file systems (ie s3, hdfs) + try: + with ThreadPool(32) as thread_pool: + for shard in 
thread_pool.imap_unordered(write_shard, shards_to_write): + shards.append(shard) + break + except Exception as e: # pylint: disable=broad-except + if i != 9: + print("retrying whole sharding to write to files due to error:", e) + time.sleep(2 * i) + else: + raise e + + shards.sort(key=lambda k: k[0]) + + del df + + return shards, number_shards + + def __iter__(self): + """ + Iterate over shards, yield shards of size number_sample_per_shard or less for the last one + Each shard is a tuple (shard_id, shard) + shard is a tuple (sample id, sample) + sample is a tuple of the columns + """ + start_shard_id = 0 + for i, input_file in enumerate(self.input_files): + print("Sharding file number " + str(i + 1) + " of " + str(len(self.input_files)) + " called " + input_file) + + shards, number_shards = self._save_to_arrow(input_file, start_shard_id) + print("File sharded in " + str(len(shards)) + " shards") + print( + "Downloading starting now, check your bandwidth speed (with bwm-ng)" + "your cpu (with htop), and your disk usage (with iotop)!" + ) + + for shard_id, arrow_file in shards: + yield ( + shard_id, + arrow_file, + ) + start_shard_id += number_shards diff --git a/video2dataset/logger.py b/video2dataset/logger.py new file mode 100644 index 00000000..4a1796e7 --- /dev/null +++ b/video2dataset/logger.py @@ -0,0 +1,298 @@ +"""logging utils for the downloader""" + +import wandb +import time +from collections import Counter +import fsspec +import json +import multiprocessing +import queue +import traceback + + +class CappedCounter: + """Maintain a counter with a capping to avoid memory issues""" + + def __init__(self, max_size=10**5): + self.max_size = max_size + self.counter = Counter() + + def increment(self, key): + if len(self.counter) >= self.max_size: + self._keep_most_frequent() + self.counter[key] += 1 + + def _keep_most_frequent(self): + self.counter = Counter(dict(self.counter.most_common(int(self.max_size / 2)))) + + def most_common(self, k): + return self.counter.most_common(k) + + def update(self, counter): + self.counter.update(counter.counter) + if len(self.counter) >= self.max_size: + self._keep_most_frequent() + + def dump(self): + return self.counter + + @classmethod + def load(cls, d, max_size=10**5): + c = CappedCounter(max_size) + c.counter = Counter(d) + return c + + +class Logger: + """logger which logs when number of calls reaches a value or a time interval has passed""" + + def __init__(self, min_interval=0): + """Log only every if min_interval (seconds) have elapsed since last log""" + # wait for all processes to return + self.processes_returned = 0 + # min time (in seconds) before logging a new table (avoids too many logs) + self.min_interval = min_interval + self.last = time.perf_counter() + # keep track of whether we logged the last call + self.last_call_logged = False + self.last_args = None + self.last_kwargs = None + + def __call__(self, *args, **kwargs): + self.processes_returned += 1 + if time.perf_counter() - self.last > self.min_interval: + self.do_log(*args, **kwargs) + self.last = time.perf_counter() + self.last_call_logged = True + else: + self.last_call_logged = False + self.last_args = args + self.last_kwargs = kwargs + + def do_log(self, *args, **kwargs): + raise NotImplementedError() + + def sync(self): + """Ensure last call is logged""" + if not self.last_call_logged and self.last_args is not None: + self.do_log(*self.last_args, **self.last_kwargs) + # reset for next file + self.processes_returned = 0 + + +class SpeedLogger(Logger): + """Log performance 
metrics""" + + def __init__(self, prefix, enable_wandb, **logger_args): + super().__init__(**logger_args) + self.prefix = prefix + self.start_time = float("+inf") + self.end_time = float("-inf") + self.count = 0 + self.success = 0 + self.failed_to_download = 0 + self.failed_to_resize = 0 + self.enable_wandb = enable_wandb + + def __call__( + self, count, success, failed_to_download, failed_to_resize, start_time, end_time + ): # pylint: disable=arguments-differ + self.count += count + self.success += success + self.failed_to_download += failed_to_download + self.failed_to_resize += failed_to_resize + self.start_time = min(start_time, self.start_time) + self.end_time = max(end_time, self.end_time) + super().__call__( + self.count, self.success, self.failed_to_download, self.failed_to_resize, self.start_time, self.end_time + ) + + def do_log( + self, count, success, failed_to_download, failed_to_resize, start_time, end_time + ): # pylint: disable=arguments-differ + duration = end_time - start_time + vid_per_sec = count / duration + success_ratio = 1.0 * success / count + failed_to_download_ratio = 1.0 * failed_to_download / count + failed_to_resize_ratio = 1.0 * failed_to_resize / count + + print( + " - ".join( + [ + f"{self.prefix:<7}", + f"success: {success_ratio:.3f}", + f"failed to download: {failed_to_download_ratio:.3f}", + f"failed to resize: {failed_to_resize_ratio:.3f}", + f"images per sec: {vid_per_sec:.0f}", + f"count: {count}", + ] + ) + ) + + if self.enable_wandb: + wandb.log( + { + f"{self.prefix}/vid_per_sec": vid_per_sec, + f"{self.prefix}/success": success_ratio, + f"{self.prefix}/failed_to_download": failed_to_download_ratio, + f"{self.prefix}/failed_to_resize": failed_to_resize_ratio, + f"{self.prefix}/count": count, + } + ) + + +class StatusTableLogger(Logger): + """Log status table to W&B, up to `max_status` most frequent items""" + + def __init__(self, max_status=100, min_interval=60, enable_wandb=False, **logger_args): + super().__init__(min_interval=min_interval, **logger_args) + # avoids too many errors unique to a specific website (SSL certificates, etc) + self.max_status = max_status + self.enable_wandb = enable_wandb + + def do_log(self, status_dict, count): # pylint: disable=arguments-differ + if self.enable_wandb: + status_table = wandb.Table( + columns=["status", "frequency", "count"], + data=[[k, 1.0 * v / count, v] for k, v in status_dict.most_common(self.max_status)], + ) + wandb.run.log({"status": status_table}) + + +def write_stats( + output_folder, + shard_id, + count, + successes, + failed_to_download, + failed_to_resize, + start_time, + end_time, + status_dict, + oom_shard_count, +): + """Write stats to disk""" + stats = { + "count": count, + "successes": successes, + "failed_to_download": failed_to_download, + "failed_to_resize": failed_to_resize, + "duration": end_time - start_time, + "start_time": start_time, + "end_time": end_time, + "status_dict": status_dict.dump(), + } + fs, output_path = fsspec.core.url_to_fs(output_folder) + shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string + shard_id=shard_id, oom_shard_count=oom_shard_count + ) + json_file = f"{output_path}/{shard_name}_stats.json" + with fs.open(json_file, "w") as f: + json.dump(stats, f, indent=4) + + +# https://docs.python.org/3/library/multiprocessing.html +# logger process that reads stats files regularly, aggregates and send to wandb / print to terminal +class LoggerProcess(multiprocessing.context.SpawnProcess): + """Logger process that reads 
stats files regularly, aggregates and send to wandb / print to terminal""" + + def __init__(self, output_folder, enable_wandb, wandb_project, config_parameters, log_interval=5): + super().__init__() + self.log_interval = log_interval + self.enable_wandb = enable_wandb + self.output_folder = output_folder + self.stats_files = set() + self.wandb_project = wandb_project + self.done_shards = set() + self.config_parameters = config_parameters + ctx = multiprocessing.get_context("spawn") + self.q = ctx.Queue() + + def run(self): + """Run logger process""" + + fs, output_path = fsspec.core.url_to_fs(self.output_folder, use_listings_cache=False) + + if self.enable_wandb: + self.current_run = wandb.init(project=self.wandb_project, config=self.config_parameters, anonymous="allow") + else: + self.current_run = None + self.total_speed_logger = SpeedLogger("total", enable_wandb=self.enable_wandb) + self.status_table_logger = StatusTableLogger(enable_wandb=self.enable_wandb) + last_check = 0 + total_status_dict = CappedCounter() + while True: + time.sleep(0.1) + try: + self.q.get(False) + last_one = True + except queue.Empty as _: + last_one = False + if not last_one and time.perf_counter() - last_check < self.log_interval: + continue + + try: + # read stats files + stats_files = fs.glob(output_path + "/*.json") + + # filter out files that have an id smaller that are already done + stats_files = [f for f in stats_files if int(f.split("/")[-1].split("_")[0]) not in self.done_shards] + + # get new stats files + new_stats_files = set(stats_files) - self.stats_files + if len(new_stats_files) == 0: + if last_one: + self.finish() + return + + # read new stats files + for stats_file in new_stats_files: + with fs.open(stats_file, "r") as f: + try: + stats = json.load(f) + SpeedLogger("worker", enable_wandb=self.enable_wandb)( + count=stats["count"], + success=stats["successes"], + failed_to_download=stats["failed_to_download"], + failed_to_resize=stats["failed_to_resize"], + start_time=stats["start_time"], + end_time=stats["end_time"], + ) + self.total_speed_logger( + count=stats["count"], + success=stats["successes"], + failed_to_download=stats["failed_to_download"], + failed_to_resize=stats["failed_to_resize"], + start_time=stats["start_time"], + end_time=stats["end_time"], + ) + status_dict = CappedCounter.load(stats["status_dict"]) + total_status_dict.update(status_dict) + self.status_table_logger(total_status_dict, self.total_speed_logger.count) + except Exception as err: # pylint: disable=broad-except + print(f"failed to parse stats file {stats_file}", err) + + self.stats_files.add(stats_file) + last_check = time.perf_counter() + + if last_one: + self.finish() + return + except Exception as e: # pylint: disable=broad-except + traceback.print_exc() + print("logger error", e) + self.finish() + return + + def finish(self): + """Finish logger process""" + self.total_speed_logger.sync() + self.status_table_logger.sync() + if self.current_run is not None: + self.current_run.finish() + + def join(self, timeout=None): + """Stop logger process""" + self.q.put("stop") + super().join() + self.q.close() diff --git a/video2dataset/main.py b/video2dataset/main.py new file mode 100644 index 00000000..6089af66 --- /dev/null +++ b/video2dataset/main.py @@ -0,0 +1,156 @@ +"""Create dataset from video links and metadata.""" + + +from typing import List, Optional +import fire + +from .logger import LoggerProcess +from .data_writer import ( + WebDatasetSampleWriter, + FilesSampleWriter, + ParquetSampleWriter, + 
TFRecordSampleWriter, + DummySampleWriter, +) +from .input_sharder import InputSharder +from .distributor import multiprocessing_distributor, pyspark_distributor +import fsspec +import sys +import signal +import os +from .worker import Worker + + +def video2dataset( + url_list: str, + output_folder: str = "videos", + processes_count: int = 1, + output_format: str = "files", + input_format: str = "txt", + url_col: str = "url", + caption_col: Optional[str] = None, + number_sample_per_shard: int = 10000, + save_additional_columns: Optional[List[str]] = None, + enable_wandb: bool = False, + wandb_project: str = "video2dataset", + oom_shard_count: int = 5, + distributor: str = "multiprocessing", + subjob_size: int = 1000, + retries: int = 0, + incremental_mode: str = "incremental", + max_shard_retry: int = 1, + timeout: int = 60, +): + """ + create video dataset from video links + """ + + config_parameters = dict(locals()) + + def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + output_folder = make_path_absolute(output_folder) + url_list = make_path_absolute(url_list) + + logger_process = LoggerProcess(output_folder, enable_wandb, wandb_project, config_parameters) + + tmp_path = output_folder + "/_tmp" + fs, tmp_dir = fsspec.core.url_to_fs(tmp_path) + if not fs.exists(tmp_dir): + fs.mkdir(tmp_dir) + + def signal_handler(signal_arg, frame): # pylint: disable=unused-argument + try: + fs.rm(tmp_dir, recursive=True) + except Exception as _: # pylint: disable=broad-except + pass + logger_process.terminate() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + + save_caption = caption_col is not None + + fs, output_path = fsspec.core.url_to_fs(output_folder) + + if not fs.exists(output_path): + fs.mkdir(output_path) + done_shards = set() + else: + if incremental_mode == "incremental": + done_shards = set(int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.json")) + elif incremental_mode == "overwrite": + fs.rm(output_path, recursive=True) + fs.mkdir(output_path) + done_shards = set() + else: + raise ValueError(f"Unknown incremental mode {incremental_mode}") + + logger_process.done_shards = done_shards + logger_process.start() + + input_sharder = InputSharder( + url_list, + input_format, + url_col, + caption_col, + save_additional_columns, + number_sample_per_shard, + done_shards, + tmp_path, + ) + + if output_format == "webdataset": + sample_writer_class = WebDatasetSampleWriter + elif output_format == "parquet": + sample_writer_class = ParquetSampleWriter # type: ignore + elif output_format == "files": + sample_writer_class = FilesSampleWriter # type: ignore + elif output_format == "tfrecord": + sample_writer_class = TFRecordSampleWriter # type: ignore + elif output_format == "dummy": + sample_writer_class = DummySampleWriter # type: ignore + else: + raise ValueError(f"Invalid output format {output_format}") + + worker = Worker( + sample_writer_class=sample_writer_class, + save_caption=save_caption, + output_folder=output_folder, + column_list=input_sharder.column_list, + timeout=timeout, + number_sample_per_shard=number_sample_per_shard, + oom_shard_count=oom_shard_count, + encode_format="mp4", + retries=retries, + ) + + print("Starting the downloading of this file") + if distributor == "multiprocessing": + distributor_fn = multiprocessing_distributor + elif distributor == "pyspark": + distributor_fn = pyspark_distributor + else: + raise ValueError(f"Distributor {distributor} not 
supported") + + distributor_fn( + processes_count, + worker, + input_sharder, + subjob_size, + max_shard_retry, + ) + logger_process.join() + fs.rm(tmp_dir, recursive=True) + + +def main(): + fire.Fire(video2dataset) + + +if __name__ == "__main__": + main() diff --git a/video2dataset/reader.py b/video2dataset/reader.py deleted file mode 100644 index 0dd72d46..00000000 --- a/video2dataset/reader.py +++ /dev/null @@ -1,68 +0,0 @@ -"""handles input parsing.""" -import pyarrow.parquet as pq -import pyarrow.csv as csv_pq -import pyarrow as pa - - -class Reader: - """Parses input into required data. - - Necessary columns (reader will always look for these columns in parquet and csv): - * videoLoc - location of video either on disc or URL - * videoID - unique ID of each video, if not provided, ID = index - - Additional special columns: - * caption - will be saved in separate key.txt file - - anything else - put in key.json metadata file - """ - - def __init__(self, src, meta_columns=None): - """ - Input: - - src: - str: path to mp4 file - str: youtube link - str: path to txt file with multiple mp4's or youtube links - list[str]: list with multiple mp4's or youtube links - - meta_columns: - list[str]: columns of useful metadata to save with videos - """ - self.columns = ["videoID", "videoLoc"] - no_dupl_temp = [] - for c in self.columns: - if c in meta_columns: - no_dupl_temp.append(c) - meta_columns.remove(c) - - self.meta_columns = meta_columns if meta_columns is not None else [] - - if isinstance(src, str): - if src.endswith(".txt"): - df = csv_pq.read_csv(src, read_options=csv_pq.ReadOptions(column_names=["videoLoc"])) - df = df.add_column(0, "videoID", [list(range(df.num_rows))]) # add ID's - elif src.endswith(".csv"): - df = csv_pq.read_csv(src) - elif src.endswith(".parquet"): - with open(src, "rb") as f: - columns_to_read = self.columns + meta_columns - df = pq.read_table(f, columns=columns_to_read) - else: # singular video (mp4 or link) - src = [src] - if isinstance(src, list): - df = pa.Table.from_arrays([src], names=["videoLoc"]) - df = df.add_column(0, "videoID", [list(range(df.num_rows))]) # add ID's - - for c in no_dupl_temp: - self.meta_columns.append(c) - self.df = df - - def get_data(self): - vids = self.df["videoLoc"].to_pylist() - ids = self.df["videoID"] - meta = dict( # pylint: disable=consider-using-dict-comprehension - [(meta, self.df[meta]) for meta in self.meta_columns] - ) - return vids, ids, meta diff --git a/video2dataset/video2dataset.py b/video2dataset/video2dataset.py deleted file mode 100644 index 905b96f1..00000000 --- a/video2dataset/video2dataset.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Create dataset from video links and metadata.""" - - -from .reader import Reader -from .writer import FileWriter, WebDatasetWriter -from .downloader import handle_url - - -def video2dataset( - src, - dest="", - output_format="webdataset", - metadata_columns="", -): - """ - create video dataset from video links - - src: - str: path to table of data with video links and metdata - dest: - str: where to save dataset to - output_format: - str: webdataset, files - metadata_columns: - str: a comma separated list of metadata column names to look for in src - """ - if isinstance(metadata_columns, str): - metadata_columns = [metadata_columns] if metadata_columns != "" else [] - metadata_columns = list(metadata_columns) if isinstance(metadata_columns, tuple) else metadata_columns - reader = Reader(src, metadata_columns) - vids, ids, meta = reader.get_data() - - starting_shard_id = 0 - 
shard_sample_count = 10000 - - if output_format == "files": - writer = FileWriter(dest) - elif output_format == "webdataset": - writer = WebDatasetWriter(dest, 9, "mp4", maxcount=shard_sample_count, shard_id=starting_shard_id) - - for i in range(len(vids)): - print(f"{i}/{len(vids)}") - vid = vids[i] - vid_id = ids[i] - vid_meta = {} - for k in meta: - vid_meta[k] = meta[k][i].as_py() - - # NOTE: Right now assuming video is url (maybe add support for local mp4 - load_vid, file, dst_name = handle_url(vid) - with open(load_vid, "rb") as vid_file: - vid_bytes = vid_file.read() - video = vid_bytes - - writer.write(video, vid_id, vid_meta) - - if file is not None: # for python files that need to be closed - file.close() diff --git a/video2dataset/worker.py b/video2dataset/worker.py new file mode 100644 index 00000000..50f8a5f2 --- /dev/null +++ b/video2dataset/worker.py @@ -0,0 +1,179 @@ +"""the downloader module handles the downloading""" + +import math +import time +import pyarrow as pa +import traceback + +import fsspec + +from video2dataset.data_reader import VideoDataReader +from .logger import CappedCounter +from .logger import write_stats + + +def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count): + true_key = (10**oom_sample_per_shard) * shard_id + key + key_format = oom_sample_per_shard + oom_shard_count + str_key = "{true_key:0{key_format}d}".format( # pylint: disable=consider-using-f-string + key_format=key_format, true_key=true_key + ) + return str_key + + +class Worker: + """The downloader class gets calls with shards, download them then call the writer to write them down""" + + def __init__( + self, + sample_writer_class, + save_caption, + output_folder, + column_list, + timeout, + number_sample_per_shard, + oom_shard_count, + encode_format, + retries, + ) -> None: + self.sample_writer_class = sample_writer_class + self.save_caption = save_caption + self.output_folder = output_folder + self.column_list = column_list + self.timeout = timeout + self.number_sample_per_shard = number_sample_per_shard + self.oom_shard_count = oom_shard_count + self.encode_format = encode_format + self.retries = retries + self.data_reader = VideoDataReader() + + def __call__( + self, + row, + ): + try: + self.download_shard(row) + return (True, row) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"shard {row[0]} failed with error {err}") + return (False, row) + + def download_shard( + self, + row, + ): + """Function to start an image downloading in one process""" + + shard_id, shard_file = row + start_time = time.time() + + fs, shard_path = fsspec.core.url_to_fs(shard_file) + with fs.open(shard_path, "rb") as f: + df = pa.ipc.open_file(f).read_all() + schema = df.schema + schema = ( + schema.append(pa.field("key", pa.string())) + .append(pa.field("status", pa.string())) + .append(pa.field("error_message", pa.string())) + ) + + pydict = df.select(self.column_list).to_pydict() + shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list)))) + del pydict + del df + + status_dict = CappedCounter() + + count = len(shard_to_dl) + successes = 0 + failed_to_download = 0 + failed_to_resize = 0 + url_indice = self.column_list.index("url") + caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None + key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl] + + def data_generator(): + for e in key_url_list: + yield e + + loader = data_generator() + + # give schema to writer + sample_writer = 
self.sample_writer_class( + shard_id, + self.output_folder, + self.save_caption, + self.oom_shard_count, + schema, + self.encode_format, + ) + oom_sample_per_shard = math.ceil(math.log10(self.number_sample_per_shard)) + for key in loader: + key, vid_stream, error_message = self.data_reader(key, timeout=self.timeout, retries=self.retries) + try: + _, sample_data = shard_to_dl[key] + str_key = compute_key(key, shard_id, oom_sample_per_shard, self.oom_shard_count) + meta = { + **{self.column_list[i]: sample_data[i] for i in range(len(self.column_list))}, + "key": str_key, + "status": None, + "error_message": error_message, + } + if error_message is not None: + failed_to_download += 1 + status = "failed_to_download" + status_dict.increment(error_message) + meta["status"] = status + sample_writer.write( + None, + str_key, + sample_data[caption_indice] if caption_indice is not None else None, + meta, + ) + continue + if error_message is not None: + failed_to_resize += 1 + status = "failed_to_resize" + status_dict.increment(error_message) + meta["status"] = status + meta["error_message"] = error_message + sample_writer.write( + None, + str_key, + sample_data[caption_indice] if caption_indice is not None else None, + meta, + ) + continue + successes += 1 + status = "success" + status_dict.increment(status) + + meta["status"] = status + + sample_writer.write( + vid_stream, + str_key, + sample_data[caption_indice] if caption_indice is not None else None, + meta, + ) + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"Sample {key} failed to download: {err}") + + sample_writer.close() + + end_time = time.time() + write_stats( + self.output_folder, + shard_id, + count, + successes, + failed_to_download, + failed_to_resize, + start_time, + end_time, + status_dict, + self.oom_shard_count, + ) + fs.rm(shard_path) diff --git a/video2dataset/writer.py b/video2dataset/writer.py deleted file mode 100644 index f314c2a0..00000000 --- a/video2dataset/writer.py +++ /dev/null @@ -1,92 +0,0 @@ -"""save embeddings.""" -import os -import json - -import fsspec -import numpy as np -import webdataset as wds - -from io import BytesIO - - -class FileWriter: - """Writes output as files.""" - - def __init__(self, output_folder): - self.output_folder = output_folder - - self.fs, self.output_folder = fsspec.core.url_to_fs(output_folder) - - def write(self, video, key, metadata=None): - """write sample to file.""" - key = str(key) - - save_pth = os.path.join(self.output_folder, key + ".mp4") - with self.fs.open(save_pth, "wb") as f: - f.write(video) - - if metadata is not None: - if "caption" in metadata: - caption = str(metadata.pop("caption")) - caption_filename = os.path.join(self.output_folder, key + ".txt") - with self.fs.open(caption_filename, "w") as f: - f.write(caption) - if len(metadata) > 0: - j = json.dumps(metadata, indent=4) - meta_filename = os.path.join(self.output_folder, key + ".json") - with self.fs.open(meta_filename, "w") as f: - f.write(j) - - def close(self): - pass - - -class WebDatasetWriter: - """Writes output in WebDataset format.""" - - def __init__(self, output_folder, oom_shard_count, encode_format, maxcount=10000, shard_id=0): - self.output_folder = output_folder - self.oom_shard_count = oom_shard_count - self.encode_format = encode_format - self.maxcount = maxcount - self.shard_id = shard_id - - self.count = 0 - - self.tarwriter = None - self.tar_fd = None - - self.create_shard() - - def create_shard(self): - """create new shard in sequential order.""" - 
self.close() - shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string - shard_id=self.shard_id, oom_shard_count=self.oom_shard_count - ) - fs, output_path = fsspec.core.url_to_fs(self.output_folder) - self.tar_fd = fs.open(f"{output_path}/{shard_name}.tar", "wb") - self.tarwriter = wds.TarWriter(self.tar_fd) - - def write(self, video, key, metadata=None): - """write sample to current shard.""" - key = str(key) - if self.count >= self.maxcount: - self.shard_id += 1 - self.count = 0 - self.create_shard() - - sample = {"__key__": key, self.encode_format: video} - if metadata is not None: - if "caption" in metadata: - sample["txt"] = str(metadata.pop("caption")) - if len(metadata) > 0: - sample["json"] = json.dumps(metadata, indent=4) - - self.tarwriter.write(sample) - self.count += 1 - - def close(self): - if self.tarwriter is not None: - self.tarwriter.close() - self.tar_fd.close()