From 2e397ec009ef952ac5f8c27fcba62cc091a4bcb2 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:40:36 +1300 Subject: [PATCH] :alembic: Benchmark experiments loading ERA5 Zarr data using kvikIO (#4) Setting up a LightningDataModule with a torchdata DataPipe and looping through one epoch (23 mini-batches) as a timing benchmark experiment using the kvikIO engine. The DataPipe does loading from Zarr, slicing with xbatcher, batching, and collating to torch.Tensor objects. Timer uses Python's time.perf_counter. --- 1_benchmark_kvikIOzarr.py | 182 ++++++++++++++++++++++++++++++++++++++ README.md | 9 ++ 2 files changed, 191 insertions(+) create mode 100644 1_benchmark_kvikIOzarr.py diff --git a/1_benchmark_kvikIOzarr.py b/1_benchmark_kvikIOzarr.py new file mode 100644 index 0000000..bf9ae1d --- /dev/null +++ b/1_benchmark_kvikIOzarr.py @@ -0,0 +1,182 @@ +""" +Experiments on WeatherBench2, loading with kvikIO or Zarr engine. + + python 1_benchmark_kvikIOzarr.py + +References: +- https://weatherbench2.readthedocs.io/en/latest/index.html +""" +import time + +import cupy +import lightning as L +import torch +import torchdata +import torchdata.dataloader2 +import tqdm +import xarray as xr +import zen3geo + + +# %% +def sel_datavars( + dataset: xr.Dataset, + data_vars: list = ["geopotential", "u_component_of_wind", "v_component_of_wind"], +) -> xr.Dataset: + """ + Select specific data variables from an xarray.Dataset object. + """ + return dataset.get(key=data_vars) # .sel(level=500) + + +def xarray_to_tensor_collate_fn( + samples: torchdata.datapipes.DataChunk, +) -> (torch.Tensor, torch.Tensor, list[dict]): + """ + Converts individual xarray.Dataset objects to torch.Tensor (float32 dtype), + and stacks them all into a single torch.Tensor. Also outputs a metadata list + of dictionaries that contains the timestamp of the xarray.Dataset's data + variable.
+ """ + tensor_t0: torch.Tensor = torch.stack( + tensors=[ + torch.as_tensor(data=sample.isel(time=0).to_array().data, device="cuda") + for sample in samples + ] + ) + tensor_t1: torch.Tensor = torch.stack( + tensors=[ + torch.as_tensor(data=sample.isel(time=1).to_array().data, device="cuda") + for sample in samples + ] + ) + metadata: list[dict] = [ + { + "unixtime0": torch.as_tensor( + data=sample.time.isel(time=0).astype(dtype="int64").data, device="cuda" + ), + "unixtime1": torch.as_tensor( + data=sample.time.isel(time=1).astype(dtype="int64").data, device="cuda" + ), + } + for sample in samples + ] + + return tensor_t0, tensor_t1, metadata + + +# %% +class WeatherBench2DataModule(L.LightningDataModule): + """ + LightningDataModule to load WeatherBench2 data from Zarr. + """ + + def __init__( + self, + zarr_store: str = "2020-full_37-6h-0p25deg-chunk-1_zuv500.zarr", + # zarr_store: str = "gs://weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25deg-chunk-1.zarr-v2", + engine: str = "kvikio", + batch_size: int = 32, + ): + """ + Go from a Zarr datacube to 6-hourly time-slice chips! + + Also does mini-batching. + + Parameters + ---------- + zarr_store : str + A path or URL to a Zarr store to read from. E.g. ``store1.zarr``. + See list of available WeatherBench2 Zarr stores at + https://weatherbench2.readthedocs.io/en/latest/data-guide.html + + engine : str + The engine to use in `xr.open_dataset`. E.g. ``kvikio``, ``zarr``. + Default is ``kvikio``. + + batch_size : int + Size of each mini-batch. Default is 32. + + Returns + ------- + datapipe : torchdata.datapipes.iter.IterDataPipe + A torch DataPipe that can be passed into a torch DataLoader. + """ + super().__init__() + self.zarr_store: str = zarr_store + self.engine: str = engine + self.batch_size: int = batch_size + + print(f"Loading data using {self.engine} engine") + + def setup(self, stage: str | None = None) -> torchdata.datapipes.iter.IterDataPipe: + """ + Data operations to perform on every GPU.
+ Split data into training and test sets, etc. + + Returns + ------- + datapipes : IterDataPipe + The torch DataPipe object to iterate over the training set. + """ + # Step 0 - Iterate through all the Zarr stores + dp_source: torchdata.datapipes.iter.IterDataPipe = ( + torchdata.datapipes.iter.IterableWrapper(iterable=[self.zarr_store]) + ) + + # Step 1 - Create WeatherBench2 xarray.Dataset chips (full metadata) + dp_weather_chips: torchdata.datapipes.iter.IterDataPipe = ( + # Step 1.0 - Open each Zarr store using xarray + dp_source.read_from_xpystac( + engine=self.engine, chunks=None, consolidated=False + ) + # Step 1.1 - Select desired data variables at 500hPa + .map(fn=sel_datavars) + # Step 1.2 - Slice datacube along time-dimension into 12-hour chunks (2x 6-hourly) + .slice_with_xbatcher( + input_dims={"latitude": 721, "longitude": 1440, "time": 2}, + preload_batch=False, + ) + ) + + # Step 2 - Train/validation split each chip based on geography + # TODO + + # Step 3 - Batch and split ERA5 chips into Machine Learning format + self.datapipe_train = ( + # Step 3.1 - Create mini-batches (default is 32) + dp_weather_chips.batch(batch_size=self.batch_size) + # Step 3.2 - Convert xarray.Dataset to torch.Tensor and stack + .collate(collate_fn=xarray_to_tensor_collate_fn) + ) + + def train_dataloader(self) -> torchdata.dataloader2.DataLoader2: + """ + Loads the data used in the training loop.
+ """ + return torchdata.dataloader2.DataLoader2(datapipe=self.datapipe_train) + + +# %% +if __name__ == "__main__": + # Optimize torch performance + torch.set_float32_matmul_precision(precision="medium") + + # Setup data + datamodule: L.LightningDataModule = WeatherBench2DataModule(engine="kvikio") + datamodule.setup() + train_dataloader = datamodule.train_dataloader() + + # Start timing + tic = time.perf_counter() + + # Training loop + for i, batch in tqdm.tqdm(iterable=enumerate(train_dataloader), total=23): + input, target, metadata = batch + # Compute Mean Squared Error loss between t=0 and t=1, just for fun + loss: torch.Tensor = torch.functional.F.mse_loss(input=input, target=target) + print(f"Batch {i}, MSE Loss: {loss}") + + # Stop timing + toc = time.perf_counter() + print(f"Total: {toc - tic:0.4f} seconds") diff --git a/README.md b/README.md index a88b0e1..bb3508f 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,15 @@ This will save a one year subset of the WeatherBench2 ERA5 dataset to your local disk. It will include data at pressure level 500hPa, with the variables 'geopotential', 'u_component_of_wind', and 'v_component_of_wind' only. +To run the benchmark experiment loading with the kvikIO engine, run: + + python 1_benchmark_kvikIOzarr.py + +This will print out a progress bar showing the ERA5 data being loaded in +mini-batches (simulating a neural network training loop), and a total count +of the time taken to finish. One 'epoch' should take under 15 seconds on an +Ampere generation (e.g. RTX A2000) NVIDIA GPU. + # References ## Links