Skip to content

Commit

Permalink
feat: sample builder ergonomics (#28)
Browse files Browse the repository at this point in the history
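- add `gather_every` (stride: keep every n-th element within each window) and `length` (keep only windows containing exactly that many elements) to FixedWindowSampleBuilder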
- remove RollingWindowSampleBuilder
  • Loading branch information
egorchakov authored Dec 3, 2024
1 parent 34c2f16 commit a61a037
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 78 deletions.
2 changes: 1 addition & 1 deletion config/_templates/dataloader/torch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ collate_fn:
num_workers: 1
pin_memory: false
persistent_workers: true
multiprocessing_context: forkserver
multiprocessing_context: spawn
10 changes: 1 addition & 9 deletions config/_templates/dataset/yaak.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,7 @@ inputs:
index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx
every: 6i
period: 6i

- _target_: pipefunc.PipeFunc
renames:
input: samples
output_name: samples_filtered
func:
_target_: rbyte.io.DataFrameFilter
predicate: |
array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6
length: 6

kwargs:
meta:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rbyte"
version = "0.10.0"
version = "0.10.1"
description = "Multimodal PyTorch dataset library"
authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }]
maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }]
Expand Down
9 changes: 2 additions & 7 deletions src/rbyte/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from importlib.metadata import version

from .dataset import Dataset
from .sample import FixedWindowSampleBuilder, RollingWindowSampleBuilder
from .sample import FixedWindowSampleBuilder

__version__ = version(__package__ or __name__)

__all__ = [
"Dataset",
"FixedWindowSampleBuilder",
"RollingWindowSampleBuilder",
"__version__",
]
__all__ = ["Dataset", "FixedWindowSampleBuilder", "__version__"]
3 changes: 1 addition & 2 deletions src/rbyte/sample/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .fixed_window import FixedWindowSampleBuilder
from .rolling_window import RollingWindowSampleBuilder

__all__ = ["FixedWindowSampleBuilder", "RollingWindowSampleBuilder"]
__all__ = ["FixedWindowSampleBuilder"]
17 changes: 12 additions & 5 deletions src/rbyte/sample/fixed_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import polars as pl
from polars._typing import ClosedInterval
from pydantic import validate_call
from pydantic import PositiveInt, validate_call


@final
Expand All @@ -19,18 +19,26 @@ class FixedWindowSampleBuilder:
__name__ = __qualname__

@validate_call
def __init__(
def __init__( # noqa: PLR0913
self,
*,
index_column: str,
every: str | timedelta,
period: str | timedelta | None = None,
closed: ClosedInterval = "left",
gather_every: PositiveInt = 1,
length: PositiveInt | None = None,
) -> None:
self._index_column = pl.col(index_column)
self._every = every
self._period = period
self._closed: ClosedInterval = closed
self._gather_every = gather_every
self._length_filter = (
(self._index_column.list.len() > 0)
if length is None
else (self._index_column.list.len() == length)
)

def __call__(self, input: pl.DataFrame) -> pl.DataFrame:
return (
Expand All @@ -44,8 +52,7 @@ def __call__(self, input: pl.DataFrame) -> pl.DataFrame:
label="datapoint",
start_by="datapoint",
)
.agg(pl.all())
.filter(self._index_column.list.len() > 0)
.sort(_index_column)
.agg(pl.all().gather_every(self._gather_every))
.filter(self._length_filter)
.drop(_index_column)
)
48 changes: 0 additions & 48 deletions src/rbyte/sample/rolling_window.py

This file was deleted.

5 changes: 0 additions & 5 deletions src/rbyte/scripts/visualize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from multiprocessing.context import ForkServerContext
from typing import Any, cast

import hydra
import torch.multiprocessing as mp
from hydra.utils import instantiate
from omegaconf import DictConfig
from structlog import get_logger
Expand All @@ -19,9 +17,6 @@ def main(config: DictConfig) -> None:
logger = cast(Logger[Any], instantiate(config.logger))
dataloader = cast(DataLoader[Any], instantiate(config.dataloader))

if isinstance(dataloader.multiprocessing_context, ForkServerContext): # pyright: ignore[reportUnknownMemberType]
mp.set_forkserver_preload(["rbyte"])

for batch_idx, batch in enumerate(tqdm(dataloader)):
logger.log(batch_idx, batch)

Expand Down

0 comments on commit a61a037

Please sign in to comment.