Skip to content

Commit

Permalink
⚗️ Benchmark experiments loading ERA5 Zarr data using kvikIO (#4)
Browse files Browse the repository at this point in the history
Setting up a LightningDataModule with a torchdata DataPipe and looping through one epoch (23 mini-batches) as a timing benchmark experiment using the kvikIO engine. The DataPipe does loading from Zarr, slicing with xbatcher, batching, and collating to torch.Tensor objects. Timer uses Python's time.perf_counter.
  • Loading branch information
weiji14 authored Oct 10, 2023
1 parent 2958bf3 commit 2e397ec
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 0 deletions.
182 changes: 182 additions & 0 deletions 1_benchmark_kvikIOzarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""
Experiments on WeatherBench2, loading with kvikIO or Zarr engine.
python 1_benchmark_kvikIOzarr.py
References:
- https://weatherbench2.readthedocs.io/en/latest/index.html
"""
import time

import cupy
import lightning as L
import torch
import torchdata
import torchdata.dataloader2
import tqdm
import xarray as xr
import zen3geo


# %%
def sel_datavars(
dataset: xr.Dataset,
data_vars: list = ["geopotential", "u_component_of_wind", "v_component_of_wind"],
) -> xr.Dataset:
"""
Select specific data variables from an xarray.Dataset object.
"""
return dataset.get(key=data_vars) # .sel(level=500)


def xarray_to_tensor_collate_fn(
samples: torchdata.datapipes.DataChunk,
) -> (torch.Tensor, torch.Tensor, list[dict]):
"""
Converts individual xarray.Dataset objects to torch.Tensor (float32 dtype),
and stack them all into a single torch.Tensor. Also outputs a metadata list
of dictionaries that contains the timestamp of the xarray.Dataset's data
variable.
"""
tensor_t0: torch.Tensor = torch.stack(
tensors=[
torch.as_tensor(data=sample.isel(time=0).to_array().data, device="cuda")
for sample in samples
]
)
tensor_t1: torch.Tensor = torch.stack(
tensors=[
torch.as_tensor(data=sample.isel(time=1).to_array().data, device="cuda")
for sample in samples
]
)
metadata: list[dict] = [
{
"unixtime0": torch.as_tensor(
data=sample.time.isel(time=0).astype(dtype="int64").data, device="cuda"
),
"unixtime1": torch.as_tensor(
data=sample.time.isel(time=1).astype(dtype="int64").data, device="cuda"
),
}
for sample in samples
]

return tensor_t0, tensor_t1, metadata


# %%
class WeatherBench2DataModule(L.LightningDataModule):
"""
LightningDataModule to load WeatherBench2 data from Zarr.
"""

def __init__(
self,
zarr_store: str = "2020-full_37-6h-0p25deg-chunk-1_zuv500.zarr",
# zarr_store: str = "gs://weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25deg-chunk-1.zarr-v2",
engine: str = "kvikio",
batch_size: int = 32,
):
"""
Go from a Zarr datacube to 6-hourly time-slice chips!
Also does mini-batching.
Parameters
----------
zarr_stores : str
A path or URL to a Zarr stores to read from. E.g. ``store1.zarr``.
See list of available WeatherBench2 Zarr stores at
https://weatherbench2.readthedocs.io/en/latest/data-guide.html
engine : str
The engine to use in `xr.open_dataset`. E.g. ``kvikio``, ``zarr``.
Default is ``kvikio``.
batch_size : int
Size of each mini-batch. Default is 32.
Returns
-------
datapipe : torchdata.datapipes.iter.IterDataPipe
A torch DataPipe that can be passed into a torch DataLoader.
"""
super().__init__()
self.zarr_store: str = zarr_store
self.engine: str = engine
self.batch_size: int = batch_size

print(f"Loading data using {self.engine} engine")

def setup(self, stage: str | None = None) -> torchdata.datapipes.iter.IterDataPipe:
"""
Data operations to perform on every GPU.
Split data into training and test sets, etc.
Returns
-------
datapipes : IterDataPipe
The torch DataPipe object to iterate over the training set.
"""
# Step 0 - Iterate through all the Zarr stores
dp_source: torchdata.datapipes.iter.IterDataPipe = (
torchdata.datapipes.iter.IterableWrapper(iterable=[self.zarr_store])
)

# Step 1 - Create WeatherBench2 xarray.Dataset chips (full metadata)
dp_weather_chips: torchdata.datapipes.iter.IterDataPipe = (
# Step 1.0 - Open each Zarr store using xarray
dp_source.read_from_xpystac(
engine=self.engine, chunks=None, consolidated=False
)
# Step 1.1 - Select desired data variables at 500hPa
.map(fn=sel_datavars)
# Step 1.2 - Slice datacube along time-dimension into 12 hour chunks (2x 6-hourly)
.slice_with_xbatcher(
input_dims={"latitude": 721, "longitude": 1440, "time": 2},
preload_batch=False,
)
)

# Step 2 - Train/validation split each chip based on geography
# TODO

# Step 3 - Batch and split ERA5 chips into Machine Learning format
self.datapipe_train = (
# Step 3.1 - Create mini-batches (default is 32)
dp_weather_chips.batch(batch_size=self.batch_size)
# Step 3.2 - Convert xarray.Dataset to torch.Tensor and stack
.collate(collate_fn=xarray_to_tensor_collate_fn)
)

def train_dataloader(self) -> torchdata.dataloader2.DataLoader2:
"""
Loads the data used in the training loop.
"""
return torchdata.dataloader2.DataLoader2(datapipe=self.datapipe_train)


# %%
if __name__ == "__main__":
# Optimize torch performance
torch.set_float32_matmul_precision(precision="medium")

# Setup data
datamodule: L.LightningDataModule = WeatherBench2DataModule(engine="kvikio")
datamodule.setup()
train_dataloader = datamodule.train_dataloader()

# Start timing
tic = time.perf_counter()

# Training loop
for i, batch in tqdm.tqdm(iterable=enumerate(train_dataloader), total=23):
input, target, metadata = batch
# Compute Mean Squared Error loss between t=0 and t=1, just for fun
loss: torch.Tensor = torch.functional.F.mse_loss(input=input, target=target)
print(f"Batch {i}, MSE Loss: {loss}")

# Stop timing
toc = time.perf_counter()
print(f"Total: {toc - tic:0.4f} seconds")
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ This will save a one year subset of the WeatherBench2 ERA5 dataset to your
local disk. It will include data at pressure level 500hPa, with the variables
'geopotential', 'u_component_of_wind', and 'v_component_of_wind' only.

To run the benchmark experiment loading with the kvikIO engine, run:

python 1_benchmark_kvikIOzarr.py

This will print out a progress bar showing the ERA5 data being loaded in
mini-batches (simulating a neural network training loop), and a total count
for the time taken to finish. One 'epoch' should take under 15 seconds on an
Ampere generation (e.g. RTX A2000) NVIDIA GPU.

# References

## Links
Expand Down

0 comments on commit 2e397ec

Please sign in to comment.