From cc5b915740811f9ceb06d3907fe8071cc55d9a14 Mon Sep 17 00:00:00 2001 From: Evgenii Gorchakov Date: Thu, 28 Nov 2024 12:36:34 +0100 Subject: [PATCH] feat: pipeline-based sample building (#25) --- .github/workflows/ci.yaml | 2 +- .github/workflows/release.yaml | 2 +- .pre-commit-config.yaml | 4 +- config/_templates/dataset/carla.yaml | 88 +++--- config/_templates/dataset/mimicgen.yaml | 62 +++-- config/_templates/dataset/nuscenes/mcap.yaml | 133 +++++---- config/_templates/dataset/nuscenes/rrd.yaml | 124 +++++---- config/_templates/dataset/yaak.yaml | 254 ++++++++++-------- config/_templates/dataset/zod.yaml | 180 +++++++------ config/_templates/logger/rerun/carla.yaml | 2 - .../logger/rerun/nuscenes/mcap.yaml | 9 +- .../_templates/logger/rerun/nuscenes/rrd.yaml | 10 +- justfile | 20 ++ pyproject.toml | 20 +- src/rbyte/config/base.py | 20 +- src/rbyte/dataset.py | 38 +-- src/rbyte/io/__init__.py | 41 +-- src/rbyte/io/_json/__init__.py | 4 +- src/rbyte/io/_json/dataframe_builder.py | 47 ++++ src/rbyte/io/_json/table_reader.py | 85 ------ src/rbyte/io/_mcap/__init__.py | 4 +- .../{table_reader.py => dataframe_builder.py} | 95 +++---- src/rbyte/io/_mcap/tensor_source.py | 6 +- src/rbyte/io/_numpy/tensor_source.py | 6 +- src/rbyte/io/dataframe/__init__.py | 13 + src/rbyte/io/{table => dataframe}/aligner.py | 71 ++--- src/rbyte/io/dataframe/concater.py | 29 ++ src/rbyte/io/dataframe/filter.py | 14 + .../transforms => dataframe}/fps_resampler.py | 25 +- src/rbyte/io/dataframe/indexer.py | 18 ++ src/rbyte/io/hdf5/__init__.py | 4 +- src/rbyte/io/hdf5/dataframe_builder.py | 52 ++++ src/rbyte/io/hdf5/table_reader.py | 97 ------- src/rbyte/io/path/__init__.py | 4 +- src/rbyte/io/path/dataframe_builder.py | 51 ++++ src/rbyte/io/path/table_reader.py | 75 ------ src/rbyte/io/rrd/__init__.py | 4 +- src/rbyte/io/rrd/dataframe_builder.py | 79 ++++++ src/rbyte/io/rrd/table_reader.py | 106 -------- src/rbyte/io/table/__init__.py | 6 - src/rbyte/io/table/base.py | 28 -- src/rbyte/io/table/builder.py | 90 ------- src/rbyte/io/table/concater.py | 42 --- src/rbyte/io/table/transforms/__init__.py | 4 - src/rbyte/io/table/transforms/base.py | 8 - src/rbyte/io/yaak/__init__.py | 4 +- .../{table_reader.py => dataframe_builder.py} | 47 ++-- src/rbyte/io/yaak/message_iterator.py | 6 +- src/rbyte/sample/base.py | 8 - src/rbyte/sample/fixed_window.py | 23 +- src/rbyte/sample/rolling_window.py | 23 +- src/rbyte/utils/__init__.py | 3 + src/rbyte/utils/{mcap => _mcap}/__init__.py | 0 .../{mcap => _mcap}/decoders/__init__.py | 0 .../decoders/json_decoder_factory.py | 0 .../decoders/protobuf_decoder_factory.py | 0 src/rbyte/utils/_pipefunc.py | 33 +++ .../utils/{dataframe/misc.py => dataframe.py} | 0 src/rbyte/utils/dataframe/__init__.py | 4 - src/rbyte/utils/dataframe/cache.py | 42 --- src/rbyte/utils/{functional.py => tensor.py} | 2 +- tests/test_dataloader.py | 34 +-- 62 files changed, 1072 insertions(+), 1233 deletions(-) create mode 100644 src/rbyte/io/_json/dataframe_builder.py delete mode 100644 src/rbyte/io/_json/table_reader.py rename src/rbyte/io/_mcap/{table_reader.py => dataframe_builder.py} (59%) create mode 100644 src/rbyte/io/dataframe/__init__.py rename src/rbyte/io/{table => dataframe}/aligner.py (62%) create mode 100644 src/rbyte/io/dataframe/concater.py create mode 100644 src/rbyte/io/dataframe/filter.py rename src/rbyte/io/{table/transforms => dataframe}/fps_resampler.py (50%) create mode 100644 src/rbyte/io/dataframe/indexer.py create mode 100644 src/rbyte/io/hdf5/dataframe_builder.py delete mode 100644 
src/rbyte/io/hdf5/table_reader.py create mode 100644 src/rbyte/io/path/dataframe_builder.py delete mode 100644 src/rbyte/io/path/table_reader.py create mode 100644 src/rbyte/io/rrd/dataframe_builder.py delete mode 100644 src/rbyte/io/rrd/table_reader.py delete mode 100644 src/rbyte/io/table/__init__.py delete mode 100644 src/rbyte/io/table/base.py delete mode 100644 src/rbyte/io/table/builder.py delete mode 100644 src/rbyte/io/table/concater.py delete mode 100644 src/rbyte/io/table/transforms/__init__.py delete mode 100644 src/rbyte/io/table/transforms/base.py rename src/rbyte/io/yaak/{table_reader.py => dataframe_builder.py} (61%) delete mode 100644 src/rbyte/sample/base.py rename src/rbyte/utils/{mcap => _mcap}/__init__.py (100%) rename src/rbyte/utils/{mcap => _mcap}/decoders/__init__.py (100%) rename src/rbyte/utils/{mcap => _mcap}/decoders/json_decoder_factory.py (100%) rename src/rbyte/utils/{mcap => _mcap}/decoders/protobuf_decoder_factory.py (100%) create mode 100644 src/rbyte/utils/_pipefunc.py rename src/rbyte/utils/{dataframe/misc.py => dataframe.py} (100%) delete mode 100644 src/rbyte/utils/dataframe/__init__.py delete mode 100644 src/rbyte/utils/dataframe/cache.py rename src/rbyte/utils/{functional.py => tensor.py} (95%) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5753901..a1b2432 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -42,7 +42,7 @@ jobs: only: ytt - name: setup uv - uses: astral-sh/setup-uv@v3 + uses: astral-sh/setup-uv@v4 with: version: "latest" enable-cache: true diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index bcf9ac2..12085ec 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -33,7 +33,7 @@ jobs: fetch-depth: 0 - name: build and inspect - uses: hynek/build-and-inspect-python-package@v2.9.0 + uses: hynek/build-and-inspect-python-package@v2 with: attest-build-provenance-github: "true" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bcaac53..d898dfa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,14 +11,14 @@ repos: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.4 + rev: v0.8.0 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.21.1 + rev: 1.22.0 hooks: - id: basedpyright diff --git a/config/_templates/dataset/carla.yaml b/config/_templates/dataset/carla.yaml index 5fbdb3d..05f940b 100644 --- a/config/_templates/dataset/carla.yaml +++ b/config/_templates/dataset/carla.yaml @@ -75,39 +75,65 @@ inputs: #@ end - table_builder: - _target_: rbyte.io.table.TableBuilder - _convert_: all - readers: - ego_logs: - path: ${data_dir}/(@=input_id@)/ego_logs.json - reader: - _target_: rbyte.io.JsonTableReader - _recursive_: false - fields: - records: - _idx_: - control.brake: - control.throttle: - control.steer: - state.velocity.value: - state.acceleration.value: + samples: + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + bound: + path: ${data_dir}/(@=input_id@)/ego_logs.json + output_name: ego_logs + func: + _target_: rbyte.io.JsonDataFrameBuilder + fields: + records: + control.brake: + control.throttle: + control.steer: + state.velocity.value: + state.acceleration.value: - transforms: - - _target_: rbyte.io.FpsResampler - source_fps: 20 - target_fps: 30 + - _target_: pipefunc.PipeFunc + renames: + input: ego_logs + output_name: 
data + func: + _target_: rbyte.io.DataFrameConcater + method: vertical - merger: - _target_: rbyte.io.TableConcater - method: vertical + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: data_resampled + func: + _target_: rbyte.io.DataFrameFpsResampler + fps_in: 20 + fps_out: 30 - filter: | - `control.throttle` > 0.5 + - _target_: pipefunc.PipeFunc + renames: + input: data_resampled + output_name: data_indexed + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ - #@ end + - _target_: pipefunc.PipeFunc + renames: + input: data_indexed + output_name: data_filtered + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `control.throttle` > 0.5 -sample_builder: - _target_: rbyte.RollingWindowSampleBuilder - index_column: _idx_ - period: 1i + - _target_: pipefunc.PipeFunc + renames: + input: data_filtered + output_name: samples + func: + _target_: rbyte.RollingWindowSampleBuilder + index_column: _idx_ + period: 1i + #@ end diff --git a/config/_templates/dataset/mimicgen.yaml b/config/_templates/dataset/mimicgen.yaml index fffc463..2c321e9 100644 --- a/config/_templates/dataset/mimicgen.yaml +++ b/config/_templates/dataset/mimicgen.yaml @@ -12,8 +12,8 @@ #@ ] --- _target_: rbyte.Dataset -_convert_: all _recursive_: false +_convert_: all inputs: #@ for input_id, input_keys in inputs.items(): #@ for input_key in input_keys: @@ -28,27 +28,45 @@ inputs: key: (@=input_key@)/(@=frame_key@) #@ end - table_builder: - _target_: rbyte.io.TableBuilder - _convert_: all - readers: - hdf5: - path: "${data_dir}/(@=input_id@).hdf5" - reader: - _target_: rbyte.io.Hdf5TableReader - _recursive_: false - fields: - (@=input_key@): - _idx_: - obs/robot0_eef_pos: + samples: + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + bound: + path: "${data_dir}/(@=input_id@).hdf5" + output_name: data + func: + _target_: rbyte.io.Hdf5DataFrameBuilder + fields: + (@=input_key@): + obs/robot0_eef_pos: + + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: data_indexed + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ + + - _target_: pipefunc.PipeFunc + renames: + input: data_indexed + output_name: data_concated + func: + _target_: rbyte.io.DataFrameConcater + method: vertical + + - _target_: pipefunc.PipeFunc + renames: + input: data_concated + output_name: samples + func: + _target_: rbyte.RollingWindowSampleBuilder + index_column: _idx_ + period: 1i - merger: - _target_: rbyte.io.TableConcater - method: vertical #@ end #@ end - -sample_builder: - _target_: rbyte.RollingWindowSampleBuilder - index_column: _idx_ - period: 1i diff --git a/config/_templates/dataset/nuscenes/mcap.yaml b/config/_templates/dataset/nuscenes/mcap.yaml index 3097f15..beb0380 100644 --- a/config/_templates/dataset/nuscenes/mcap.yaml +++ b/config/_templates/dataset/nuscenes/mcap.yaml @@ -11,15 +11,15 @@ #@ } --- _target_: rbyte.Dataset -_convert_: all _recursive_: false +_convert_: all inputs: #@ for input_id in inputs: (@=input_id@): sources: #@ for camera, topic in camera_topics.items(): (@=camera@): - index_column: mcap/(@=topic@)/_idx_ + index_column: (@=topic@)/_idx_ source: _target_: rbyte.io.McapTensorSource path: "${data_dir}/(@=input_id@).mcap" @@ -33,66 +33,85 @@ inputs: fastupsample: true #@ end - table_builder: - _target_: rbyte.io.TableBuilder - _convert_: all - readers: - mcap: - path: "${data_dir}/(@=input_id@).mcap" - reader: - _target_: rbyte.io.McapTableReader - _recursive_: false - decoder_factories: - - 
rbyte.utils.mcap.ProtobufDecoderFactory - - rbyte.utils.mcap.JsonDecoderFactory + samples: + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + bound: + path: ${data_dir}/(@=input_id@).mcap + output_name: data + func: + _target_: rbyte.io.McapDataFrameBuilder + decoder_factories: + - rbyte.utils._mcap.ProtobufDecoderFactory + - rbyte.utils._mcap.JsonDecoderFactory + fields: + #@ for topic in camera_topics.values(): + (@=topic@): + log_time: + _target_: polars.Datetime + time_unit: ns + #@ end - fields: - #@ for topic in camera_topics.values(): - (@=topic@): - _idx_: - log_time: - _target_: polars.Datetime - time_unit: ns - #@ end + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: - /odom: - log_time: - _target_: polars.Datetime - time_unit: ns - vel.x: + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: data_indexed + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ - merger: - _target_: rbyte.io.TableAligner - separator: "/" - merge: - mcap: - #@ topic = camera_topics.values()[0] - (@=topic@): - key: log_time + - _target_: pipefunc.PipeFunc + renames: + input: data_indexed + output_name: data_aligned + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + #@ topic = camera_topics.values()[0] + (@=topic@): + key: log_time - #@ for topic in camera_topics.values()[1:]: - (@=topic@): - key: log_time - columns: - _idx_: - method: asof - tolerance: 40ms - strategy: nearest - #@ end + #@ for topic in camera_topics.values()[1:]: + (@=topic@): + key: log_time + columns: + _idx_: + method: asof + tolerance: 40ms + strategy: nearest + #@ end - /odom: - key: log_time - columns: - vel.x: - method: interp + /odom: + key: log_time + columns: + vel.x: + method: interp - filter: | - `mcap//odom/vel.x` >= 8 + - _target_: pipefunc.PipeFunc + renames: + input: data_aligned + output_name: data_filtered + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `/odom/vel.x` >= 8 - cache: + - _target_: pipefunc.PipeFunc + renames: + input: data_filtered + output_name: samples + func: + _target_: rbyte.RollingWindowSampleBuilder + index_column: (@=camera_topics.values()[0]@)/_idx_ + period: 1i #@ end - -sample_builder: - _target_: rbyte.RollingWindowSampleBuilder - index_column: mcap/(@=camera_topics.values()[0]@)/_idx_ - period: 1i diff --git a/config/_templates/dataset/nuscenes/rrd.yaml b/config/_templates/dataset/nuscenes/rrd.yaml index c824b6e..308412a 100644 --- a/config/_templates/dataset/nuscenes/rrd.yaml +++ b/config/_templates/dataset/nuscenes/rrd.yaml @@ -11,15 +11,15 @@ #@ } --- _target_: rbyte.Dataset -_convert_: all _recursive_: false +_convert_: all inputs: #@ for input_id in inputs: (@=input_id@): sources: #@ for camera, entity in camera_entities.items(): (@=camera@): - index_column: rrd/(@=entity@)/_idx_ + index_column: (@=entity@)/_idx_ source: _target_: rbyte.io.RrdFrameSource path: "${data_dir}/(@=input_id@).rrd" @@ -33,57 +33,79 @@ inputs: fastupsample: true #@ end - table_builder: - _target_: rbyte.io.TableBuilder - _convert_: all - readers: - rrd: - path: "${data_dir}/(@=input_id@).rrd" - reader: - _target_: rbyte.io.RrdTableReader - _recursive_: false - index: timestamp - contents: - #@ for entity in camera_entities.values(): - (@=entity@): - - _idx_ - #@ end + samples: + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + bound: + path: "${data_dir}/(@=input_id@).rrd" + 
output_name: data + func: + _target_: rbyte.io.RrdDataFrameBuilder + index: timestamp + contents: + #@ for entity in camera_entities.values(): + (@=entity@): + #@ end - /world/ego_vehicle/LIDAR_TOP: - - Position3D + /world/ego_vehicle/LIDAR_TOP: + - Position3D - merger: - _target_: rbyte.io.TableAligner - separator: / - merge: - rrd: - #@ entity = camera_entities.values()[0] - (@=entity@): - key: timestamp + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: data_indexed + func: + _target_: rbyte.io.DataFrameIndexer + name: _idx_ - #@ for entity in camera_entities.values()[1:]: - (@=entity@): - key: timestamp - columns: - _idx_: - method: asof - strategy: nearest - tolerance: 40ms - #@ end + - _target_: pipefunc.PipeFunc + renames: + input: data_indexed + output_name: data_aligned + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + #@ entity = camera_entities.values()[0] + (@=entity@): + key: timestamp - /world/ego_vehicle/LIDAR_TOP: - key: timestamp - columns: - Position3D: - method: asof - strategy: nearest - tolerance: 40ms - - filter: | - `rrd//world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50' - #@ end + #@ for entity in camera_entities.values()[1:]: + (@=entity@): + key: timestamp + columns: + _idx_: + method: asof + strategy: nearest + tolerance: 60ms + #@ end -sample_builder: - _target_: rbyte.RollingWindowSampleBuilder - index_column: rrd/(@=camera_entities.values()[0]@)/_idx_ - period: 1i + /world/ego_vehicle/LIDAR_TOP: + key: timestamp + columns: + Position3D: + method: asof + strategy: nearest + tolerance: 60ms + + - _target_: pipefunc.PipeFunc + renames: + input: data_aligned + output_name: data_filtered + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `/world/ego_vehicle/CAM_FRONT/timestamp` between '2018-07-24 03:28:48' and '2018-07-24 03:28:50' + + - _target_: pipefunc.PipeFunc + renames: + input: data_filtered + output_name: samples + func: + _target_: rbyte.RollingWindowSampleBuilder + index_column: (@=camera_entities.values()[0]@)/_idx_ + period: 1i + #@ end diff --git a/config/_templates/dataset/yaak.yaml b/config/_templates/dataset/yaak.yaml index 7cffab1..a083f30 100644 --- a/config/_templates/dataset/yaak.yaml +++ b/config/_templates/dataset/yaak.yaml @@ -22,118 +22,152 @@ inputs: index_column: "meta/ImageMetadata.(@=source_id@)/frame_idx" source: _target_: rbyte.io.FfmpegFrameSource - _recursive_: true path: "${data_dir}/(@=input_id@)/(@=source_id@).pii.mp4" resize_shorter_side: 324 #@ end - table_builder: - _target_: rbyte.io.TableBuilder - _convert_: all - readers: - meta: - path: ${data_dir}/(@=input_id@)/metadata.log - reader: - _target_: rbyte.io.YaakMetadataTableReader - _recursive_: false - fields: - rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - frame_idx: polars.Int32 - camera_name: - _target_: polars.Enum - categories: - - cam_front_center - - cam_front_left - - cam_front_right - - cam_left_forward - - cam_right_forward - - cam_left_backward - - cam_right_backward - - cam_rear - - rbyte.io.yaak.proto.can_pb2.VehicleMotion: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - speed: polars.Float32 - gear: - _target_: polars.Enum - categories: ["0", "1", "2", "3"] - - mcap: - path: ${data_dir}/(@=input_id@)/ai.mcap - reader: - _target_: rbyte.io.McapTableReader - _recursive_: false - decoder_factories: [rbyte.utils.mcap.ProtobufDecoderFactory] - fields: - /ai/safety_score: - clip.end_timestamp: 
- _target_: polars.Datetime - time_unit: ns - - score: polars.Float32 - - merger: - _target_: rbyte.io.TableAligner - separator: "/" - merge: - meta: - ImageMetadata.(@=cameras[0]@): - key: time_stamp - - #@ for camera in cameras[1:]: - ImageMetadata.(@=camera@): - key: time_stamp - columns: - frame_idx: - method: asof - tolerance: 20ms - strategy: nearest - #@ end - - VehicleMotion: - key: time_stamp - columns: - speed: - method: interp - gear: - method: asof - tolerance: 100ms - strategy: nearest - - mcap: - /ai/safety_score: - key: clip.end_timestamp - columns: - clip.end_timestamp: - method: asof - tolerance: 500ms - strategy: nearest - score: - method: asof - tolerance: 500ms - strategy: nearest - - filter: | - `meta/VehicleMotion/gear` == '3' - - cache: - _target_: rbyte.utils.dataframe.DataframeDiskCache - directory: /tmp/rbyte-cache - size_limit: 1GiB - #@ end + samples: + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + bound: + path: ${data_dir}/(@=input_id@)/metadata.log + output_name: meta_data + func: + _target_: rbyte.io.YaakMetadataDataFrameBuilder + fields: + rbyte.io.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + frame_idx: + _target_: polars.Int32 + + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + speed: + _target_: polars.Float32 + + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + - _target_: pipefunc.PipeFunc + bound: + path: ${data_dir}/(@=input_id@)/ai.mcap + output_name: mcap_data + func: + _target_: rbyte.io.McapDataFrameBuilder + decoder_factories: [rbyte.utils._mcap.ProtobufDecoderFactory] + fields: + /ai/safety_score: + clip.end_timestamp: + _target_: polars.Datetime + time_unit: ns + + score: + _target_: polars.Float32 -sample_builder: - _target_: rbyte.FixedWindowSampleBuilder - index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx - every: 6i - period: 6i - filter: | - array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6 - and array_mean(`meta/VehicleMotion/speed`) > 40 + - _target_: pipefunc.PipeFunc + func: + _target_: hydra.utils.get_method + path: rbyte.utils.make_dict + bound: + k0: meta + k1: mcap + renames: + v0: meta_data + v1: mcap_data + output_name: data + + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: data_aligned + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + meta: + ImageMetadata.(@=cameras[0]@): + key: time_stamp + + #@ for camera in cameras[1:]: + ImageMetadata.(@=camera@): + key: time_stamp + columns: + frame_idx: + method: asof + tolerance: 20ms + strategy: nearest + #@ end + + VehicleMotion: + key: time_stamp + columns: + speed: + method: interp + gear: + method: asof + tolerance: 100ms + strategy: nearest + + mcap: + /ai/safety_score: + key: clip.end_timestamp + columns: + clip.end_timestamp: + method: asof + tolerance: 500ms + strategy: nearest + score: + method: asof + tolerance: 500ms + strategy: nearest + + - _target_: pipefunc.PipeFunc + renames: + input: data_aligned + output_name: data_filtered + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + `meta/VehicleMotion/gear` == '3' + + - _target_: pipefunc.PipeFunc + renames: + input: 
data_filtered + output_name: samples + func: + _target_: rbyte.FixedWindowSampleBuilder + index_column: meta/ImageMetadata.(@=cameras[0]@)/frame_idx + every: 6i + period: 6i + + - _target_: pipefunc.PipeFunc + renames: + input: samples + output_name: samples_filtered + func: + _target_: rbyte.io.DataFrameFilter + predicate: | + array_length(`meta/ImageMetadata.(@=cameras[0]@)/time_stamp`) == 6 + #@ end diff --git a/config/_templates/dataset/zod.yaml b/config/_templates/dataset/zod.yaml index 9e56288..1a07f07 100644 --- a/config/_templates/dataset/zod.yaml +++ b/config/_templates/dataset/zod.yaml @@ -1,7 +1,7 @@ --- _target_: rbyte.Dataset -_convert_: all _recursive_: false +_convert_: all inputs: 000002_short: sources: @@ -24,92 +24,118 @@ inputs: path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{:%Y-%m-%dT%H:%M:%S.%f}Z.npy" select: ["x", "y", "z"] - table_builder: - _target_: rbyte.io.TableBuilder - _convert_: all - readers: - camera_front_blur: - path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" - reader: - _target_: rbyte.io.PathTableReader - _recursive_: false - fields: - timestamp: - _target_: polars.Datetime - time_unit: ns - - lidar_velodyne: - path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy" - reader: - _target_: rbyte.io.PathTableReader - _recursive_: false - fields: - timestamp: - _target_: polars.Datetime - time_unit: ns - - vehicle_data: - path: "${data_dir}/sequences/000002_short/vehicle_data.hdf5" - reader: - _target_: rbyte.io.Hdf5TableReader - _recursive_: false - fields: - ego_vehicle_controls: - timestamp/nanoseconds/value: + samples: + pipeline: + _target_: pipefunc.Pipeline + validate_type_annotations: false + functions: + - _target_: pipefunc.PipeFunc + bound: + path: "${data_dir}/sequences/000002_short/camera_front_blur/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.jpg" + output_name: camera_front_blur_data + func: + _target_: rbyte.io.PathDataFrameBuilder + fields: + timestamp: _target_: polars.Datetime time_unit: ns - acceleration_pedal/ratio/unitless/value: - steering_wheel_angle/angle/radians/value: - - satellite: - timestamp/nanoseconds/value: + - _target_: pipefunc.PipeFunc + bound: + path: "${data_dir}/sequences/000002_short/lidar_velodyne/000002_romeo_{timestamp:%Y-%m-%dT%H:%M:%S.%f}Z.npy" + output_name: lidar_velodyne_data + func: + _target_: rbyte.io.PathDataFrameBuilder + fields: + timestamp: _target_: polars.Datetime time_unit: ns - speed/meters_per_second/value: + - _target_: pipefunc.PipeFunc + bound: + path: "${data_dir}/sequences/000002_short/vehicle_data.hdf5" + output_name: vehicle_data + func: + _target_: rbyte.io.Hdf5DataFrameBuilder + fields: + ego_vehicle_controls: + timestamp/nanoseconds/value: + _target_: polars.Datetime + time_unit: ns + + acceleration_pedal/ratio/unitless/value: + steering_wheel_angle/angle/radians/value: + + satellite: + timestamp/nanoseconds/value: + _target_: polars.Datetime + time_unit: ns + + speed/meters_per_second/value: + + - _target_: pipefunc.PipeFunc + bound: + k0: camera_front_blur + k1: lidar_velodyne + k2: vehicle_data + renames: + v0: camera_front_blur_data + v1: lidar_velodyne_data + v2: vehicle_data + output_name: data + func: + _target_: hydra.utils.get_method + path: rbyte.utils.make_dict - merger: - _target_: rbyte.io.TableAligner - separator: "/" - merge: - camera_front_blur: - key: timestamp + - _target_: pipefunc.PipeFunc + renames: + input: data + output_name: 
data_aligned + func: + _target_: rbyte.io.DataFrameAligner + separator: / + fields: + camera_front_blur: + key: timestamp - lidar_velodyne: - key: timestamp - columns: - timestamp: - method: asof - strategy: nearest - tolerance: 100ms + lidar_velodyne: + key: timestamp + columns: + timestamp: + method: asof + strategy: nearest + tolerance: 100ms - vehicle_data: - ego_vehicle_controls: - key: timestamp/nanoseconds/value - columns: - timestamp/nanoseconds/value: - method: asof - strategy: nearest - tolerance: 100ms + vehicle_data: + ego_vehicle_controls: + key: timestamp/nanoseconds/value + columns: + timestamp/nanoseconds/value: + method: asof + strategy: nearest + tolerance: 100ms - acceleration_pedal/ratio/unitless/value: - method: asof - strategy: nearest - tolerance: 100ms + acceleration_pedal/ratio/unitless/value: + method: asof + strategy: nearest + tolerance: 100ms - steering_wheel_angle/angle/radians/value: - method: asof - strategy: nearest - tolerance: 100ms + steering_wheel_angle/angle/radians/value: + method: asof + strategy: nearest + tolerance: 100ms - satellite: - key: timestamp/nanoseconds/value - columns: - speed/meters_per_second/value: - method: interp + satellite: + key: timestamp/nanoseconds/value + columns: + speed/meters_per_second/value: + method: interp -sample_builder: - _target_: rbyte.FixedWindowSampleBuilder - index_column: camera_front_blur/timestamp - every: 300ms + - _target_: pipefunc.PipeFunc + renames: + input: data_aligned + output_name: samples + func: + _target_: rbyte.FixedWindowSampleBuilder + index_column: camera_front_blur/timestamp + every: 300ms diff --git a/config/_templates/logger/rerun/carla.yaml b/config/_templates/logger/rerun/carla.yaml index 58a8b2a..d8865fc 100644 --- a/config/_templates/logger/rerun/carla.yaml +++ b/config/_templates/logger/rerun/carla.yaml @@ -5,8 +5,6 @@ #@ ] --- _target_: rbyte.viz.loggers.RerunLogger -_recursive_: true -_convert_: all schema: #@ for camera in cameras: (@=camera@): diff --git a/config/_templates/logger/rerun/nuscenes/mcap.yaml b/config/_templates/logger/rerun/nuscenes/mcap.yaml index 98389ee..9351213 100644 --- a/config/_templates/logger/rerun/nuscenes/mcap.yaml +++ b/config/_templates/logger/rerun/nuscenes/mcap.yaml @@ -15,12 +15,12 @@ schema: #@ end #@ topic = camera_topics.values()[0] - mcap/(@=topic@)/_idx_: TimeSequenceColumn - mcap/(@=topic@)/log_time: TimeNanosColumn + (@=topic@)/_idx_: TimeSequenceColumn + (@=topic@)/log_time: TimeNanosColumn #@ for topic in camera_topics.values()[1:]: - mcap/(@=topic@)/_idx_: TimeSequenceColumn + (@=topic@)/_idx_: TimeSequenceColumn #@ end - mcap//odom/vel.x: Scalar + /odom/vel.x: Scalar spawn: true blueprint: @@ -36,4 +36,3 @@ blueprint: #@ end - _target_: rerun.blueprint.TimeSeriesView - origin: mcap/ diff --git a/config/_templates/logger/rerun/nuscenes/rrd.yaml b/config/_templates/logger/rerun/nuscenes/rrd.yaml index babdb29..2e1c59b 100644 --- a/config/_templates/logger/rerun/nuscenes/rrd.yaml +++ b/config/_templates/logger/rerun/nuscenes/rrd.yaml @@ -16,12 +16,12 @@ schema: #@ end #@ entity = camera_entities.values()[0] - rrd/(@=entity@)/_idx_: TimeSequenceColumn - rrd/(@=entity@)/timestamp: TimeNanosColumn + (@=entity@)/_idx_: TimeSequenceColumn + (@=entity@)/timestamp: TimeNanosColumn #@ for entity in camera_entities.values()[1:]: - rrd/(@=entity@)/_idx_: TimeSequenceColumn + (@=entity@)/_idx_: TimeSequenceColumn #@ end - #! 
rrd//world/ego_vehicle/LIDAR_TOP/Position3D: Points3D + /world/ego_vehicle/LIDAR_TOP/Position3D: Points3D blueprint: _target_: rerun.blueprint.Blueprint @@ -29,7 +29,7 @@ blueprint: - _target_: rerun.blueprint.Vertical contents: - _target_: rerun.blueprint.Spatial3DView - origin: rrd/ + origin: - _target_: rerun.blueprint.Horizontal contents: diff --git a/justfile b/justfile index 5efc81b..d56b02d 100644 --- a/justfile +++ b/justfile @@ -61,6 +61,26 @@ visualize *ARGS: generate-config hydra/job_logging=disabled \ {{ ARGS }} +[group('visualize')] +visualize-mimicgen: + just visualize dataset=mimicgen logger=rerun/mimicgen ++data_dir={{ justfile_directory() }}/tests/data/mimicgen + +[group('visualize')] +visualize-yaak: + just visualize dataset=yaak logger=rerun/yaak ++data_dir={{ justfile_directory() }}/tests/data/yaak + +[group('visualize')] +visualize-zod: + just visualize dataset=zod logger=rerun/zod ++data_dir={{ justfile_directory() }}/tests/data/zod + +[group('visualize')] +visualize-nuscenes-mcap: + just visualize dataset=nuscenes/mcap logger=rerun/nuscenes/mcap ++data_dir={{ justfile_directory() }}/tests/data/nuscenes/mcap + +[group('visualize')] +visualize-nuscenes-rrd: + just visualize dataset=nuscenes/rrd logger=rerun/nuscenes/rrd ++data_dir={{ justfile_directory() }}/tests/data/nuscenes/rrd + # rerun server and viewer rerun bind="0.0.0.0" port="9876" ws-server-port="9877" web-viewer-port="9090": RUST_LOG=debug uv run rerun \ diff --git a/pyproject.toml b/pyproject.toml index 26c1b9f..5c94036 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.8.0" +version = "0.9.0" description = "Multimodal PyTorch dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -8,8 +8,8 @@ dependencies = [ "tensordict>=0.6.2", "torch", "numpy", - "polars>=1.14.0", - "pydantic>=2.9.2", + "polars>=1.15.0", + "pydantic>=2.10.2", "more-itertools>=10.5.0", "hydra-core>=1.3.2", "optree>=0.13.1", @@ -20,6 +20,7 @@ dependencies = [ "structlog>=24.4.0", "xxhash>=3.5.0", "tqdm>=4.66.5", + "pipefunc>=0.40.2", ] readme = "README.md" requires-python = ">=3.12,<3.13" @@ -68,8 +69,8 @@ build-backend = "hatchling.build" [tool.uv] dev-dependencies = [ - "wat-inspector>=0.4.2", - "lovely-tensors>=0.1.17", + "wat-inspector>=0.4.3", + "lovely-tensors>=0.1.18", "pudb>=2024.1.2", "ipython>=8.29.0", "ipython-autoimport>=0.5", @@ -88,8 +89,8 @@ reproducible = true [tool.hatch.build.targets.sdist] include = ["src/rbyte"] -exclude = ["src/rbyte/io/table/yaak/idl-repo"] -artifacts = ["src/rbyte/io/table/yaak/proto/*_pb2.py*"] +exclude = ["src/rbyte/io/yaak/idl-repo"] +artifacts = ["src/rbyte/io/yaak/proto/*_pb2.py*"] [tool.hatch.build.targets.sdist.hooks.custom] enable-by-default = true @@ -97,13 +98,14 @@ require-runtime-features = ["build"] [tool.hatch.build.targets.wheel] packages = ["src/rbyte"] -artifacts = ["src/rbyte/io/table/yaak/proto/*_pb2.py*"] +artifacts = ["src/rbyte/io/yaak/proto/*_pb2.py*"] [tool.basedpyright] typeCheckingMode = "all" enableTypeIgnoreComments = true reportMissingTypeStubs = "none" reportAny = "none" +reportExplicitAny = "none" reportIgnoreCommentWithoutRule = "error" venvPath = "." 
@@ -127,7 +129,7 @@ skip-magic-trailing-comma = true preview = true select = ["ALL"] fixable = ["ALL"] -ignore = ["D", "CPY", "COM812", "F722", "PD901", "ISC001", "TD"] +ignore = ["A001", "A002", "D", "CPY", "COM812", "F722", "PD901", "ISC001", "TD"] [tool.ruff.lint.isort] split-on-trailing-comma = false diff --git a/src/rbyte/config/base.py b/src/rbyte/config/base.py index 4935432..c3cc0f9 100644 --- a/src/rbyte/config/base.py +++ b/src/rbyte/config/base.py @@ -1,9 +1,10 @@ from functools import cached_property -from typing import ClassVar, Literal, TypeVar +from typing import ClassVar, Literal from hydra.utils import instantiate from pydantic import BaseModel as _BaseModel -from pydantic import ConfigDict, Field, ImportString, field_serializer, model_validator +from pydantic import ConfigDict, Field, ImportString, model_validator +from pydantic import RootModel as _RootModel class BaseModel(_BaseModel): @@ -16,7 +17,13 @@ class BaseModel(_BaseModel): ) -T = TypeVar("T") +class RootModel[T](_RootModel[T]): + model_config: ClassVar[ConfigDict] = ConfigDict( + arbitrary_types_allowed=True, + frozen=True, + validate_assignment=True, + ignored_types=(cached_property,), + ) class HydraConfig[T](BaseModel): @@ -25,18 +32,13 @@ class HydraConfig[T](BaseModel): target: ImportString[type[T]] = Field(alias="_target_") recursive: bool = Field(alias="_recursive_", default=True) convert: Literal["none", "partial", "object", "all"] = Field( - alias="_convert_", default="none" + alias="_convert_", default="all" ) partial: bool = Field(alias="_partial_", default=False) def instantiate(self, **kwargs: object) -> T: return instantiate(self.model_dump(by_alias=True), **kwargs) - @field_serializer("target") - @staticmethod - def serialize_target(v: object) -> str: - return ImportString._serialize(v) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] # noqa: SLF001 - @model_validator(mode="before") @classmethod def validate_model(cls, data: object) -> object: diff --git a/src/rbyte/dataset.py b/src/rbyte/dataset.py index 97c106b..196cff1 100644 --- a/src/rbyte/dataset.py +++ b/src/rbyte/dataset.py @@ -5,6 +5,7 @@ import polars as pl import torch +from pipefunc import Pipeline from pydantic import Field, StringConstraints, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars @@ -14,9 +15,7 @@ from rbyte.batch import Batch, BatchMeta from rbyte.config import BaseModel, HydraConfig from rbyte.io.base import TensorSource -from rbyte.io.table.base import TableBuilder -from rbyte.sample.base import SampleBuilder -from rbyte.utils.functional import pad_sequence +from rbyte.utils.tensor import pad_sequence __all__ = ["Dataset"] @@ -32,9 +31,15 @@ class SourceConfig(BaseModel): index_column: str +class PipelineConfig(BaseModel): + pipeline: HydraConfig[Pipeline] + output_name: str | None = None + kwargs: dict[str, object] = Field(default_factory=dict) + + class InputConfig(BaseModel): sources: Mapping[Id, SourceConfig] = Field(min_length=1) - table_builder: HydraConfig[TableBuilder] + samples: PipelineConfig @unique @@ -50,24 +55,27 @@ class Column(StrEnum): class Dataset(TorchDataset[TensorDict]): @validate_call(config=BaseModel.model_config) def __init__( - self, - inputs: Annotated[Mapping[Id, InputConfig], Field(min_length=1)], - sample_builder: HydraConfig[SampleBuilder], + self, inputs: Annotated[Mapping[Id, InputConfig], Field(min_length=1)] ) -> None: logger.debug("initializing dataset") super().__init__() - 
_sample_builder = sample_builder.instantiate() samples: Mapping[str, pl.DataFrame] = {} for input_id, input_cfg in inputs.items(): with bound_contextvars(input_id=input_id): - table = input_cfg.table_builder.instantiate().build() - samples[input_id] = _sample_builder.build(table) + samples_cfg = input_cfg.samples + pipeline = samples_cfg.pipeline.instantiate() + output_name = ( + samples_cfg.output_name or pipeline.unique_leaf_node.output_name # pyright: ignore[reportUnknownMemberType] + ) + samples[input_id] = pipeline.run( + output_name=output_name, kwargs=samples_cfg.kwargs + ) logger.debug( "built samples", - rows=table.select(pl.len()).item(), - samples=samples[input_id].select(pl.len()).item(), + columns=samples[input_id].columns, + len=len(samples[input_id]), ) input_id_enum = pl.Enum(sorted(samples)) @@ -145,7 +153,7 @@ def __getitems__(self, indexes: Sequence[int]) -> Batch: # noqa: PLW3201 .agg(Column.source_config, Column.source_idxs) ) - tensors: Mapping[str, torch.Tensor] = { + tensor_data: Mapping[str, torch.Tensor] = { row[Column.source_id]: pad_sequence( [ self._get_source(source)[idxs] @@ -159,11 +167,11 @@ def __getitems__(self, indexes: Sequence[int]) -> Batch: # noqa: PLW3201 for row in sources.collect().iter_rows(named=True) } - table: Mapping[str, Sequence[object]] = samples.select( + sample_data: Mapping[str, Sequence[object]] = samples.select( pl.exclude(Column.sample_idx, Column.input_id).to_physical() ).to_dict(as_series=False) - data = TensorDict(tensors | table, batch_size=batch_size) # pyright: ignore[reportArgumentType] + data = TensorDict(tensor_data | sample_data, batch_size=batch_size) # pyright: ignore[reportArgumentType] meta = BatchMeta( sample_idx=samples[Column.sample_idx].to_torch(), # pyright: ignore[reportCallIssue] diff --git a/src/rbyte/io/__init__.py b/src/rbyte/io/__init__.py index 34e7d26..4662482 100644 --- a/src/rbyte/io/__init__.py +++ b/src/rbyte/io/__init__.py @@ -1,39 +1,46 @@ -from ._json import JsonTableReader +from ._json import JsonDataFrameBuilder from ._numpy import NumpyTensorSource -from .path import PathTableReader, PathTensorSource -from .table import FpsResampler, TableAligner, TableBuilder, TableConcater +from .dataframe import ( + DataFrameAligner, + DataFrameConcater, + DataFrameFilter, + DataFrameFpsResampler, + DataFrameIndexer, +) +from .path import PathDataFrameBuilder, PathTensorSource __all__: list[str] = [ - "FpsResampler", - "JsonTableReader", + "DataFrameAligner", + "DataFrameConcater", + "DataFrameFilter", + "DataFrameFpsResampler", + "DataFrameIndexer", + "JsonDataFrameBuilder", "NumpyTensorSource", - "PathTableReader", + "PathDataFrameBuilder", "PathTensorSource", - "TableAligner", - "TableBuilder", - "TableConcater", ] try: - from .hdf5 import Hdf5TableReader, Hdf5TensorSource + from .hdf5 import Hdf5DataFrameBuilder, Hdf5TensorSource except ImportError: pass else: - __all__ += ["Hdf5TableReader", "Hdf5TensorSource"] + __all__ += ["Hdf5DataFrameBuilder", "Hdf5TensorSource"] try: - from ._mcap import McapTableReader, McapTensorSource + from ._mcap import McapDataFrameBuilder, McapTensorSource except ImportError: pass else: - __all__ += ["McapTableReader", "McapTensorSource"] + __all__ += ["McapDataFrameBuilder", "McapTensorSource"] try: - from .rrd import RrdFrameSource, RrdTableReader + from .rrd import RrdDataFrameBuilder, RrdFrameSource except ImportError: pass else: - __all__ += ["RrdFrameSource", "RrdTableReader"] + __all__ += ["RrdDataFrameBuilder", "RrdFrameSource"] try: from .video.ffmpeg_source import 
FfmpegFrameSource @@ -43,8 +50,8 @@ __all__ += ["FfmpegFrameSource"] try: - from .yaak import YaakMetadataTableReader + from .yaak import YaakMetadataDataFrameBuilder except ImportError: pass else: - __all__ += ["YaakMetadataTableReader"] + __all__ += ["YaakMetadataDataFrameBuilder"] diff --git a/src/rbyte/io/_json/__init__.py b/src/rbyte/io/_json/__init__.py index 2c19de4..6477553 100644 --- a/src/rbyte/io/_json/__init__.py +++ b/src/rbyte/io/_json/__init__.py @@ -1,3 +1,3 @@ -from .table_reader import JsonTableReader +from .dataframe_builder import JsonDataFrameBuilder -__all__ = ["JsonTableReader"] +__all__ = ["JsonDataFrameBuilder"] diff --git a/src/rbyte/io/_json/dataframe_builder.py b/src/rbyte/io/_json/dataframe_builder.py new file mode 100644 index 0000000..79fe202 --- /dev/null +++ b/src/rbyte/io/_json/dataframe_builder.py @@ -0,0 +1,47 @@ +from collections.abc import Mapping +from os import PathLike +from pathlib import Path +from typing import final + +import polars as pl +from optree import PyTree +from polars._typing import PolarsDataType # noqa: PLC2701 +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) +from pydantic import ConfigDict, validate_call + +from rbyte.utils.dataframe import unnest_all + +type Fields = Mapping[str, Mapping[str, PolarsDataType | None]] + + +@final +class JsonDataFrameBuilder: + __name__ = __qualname__ + + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__(self, fields: Fields) -> None: + self._fields = fields + + def __call__(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + dfs: Mapping[str, pl.DataFrame] = {} + + for k, series in ( + pl.read_json(Path(path)).select(self._fields).to_dict().items() + ): + df_schema = { + name: dtype + for name, dtype in self._fields[k].items() + if dtype is not None + } + df = pl.DataFrame(series).lazy().explode(k).unnest(k) + dfs[k] = ( + df.select(unnest_all(df.collect_schema())) + .select(self._fields[k].keys()) + .cast(df_schema) # pyright: ignore[reportArgumentType] + .collect() + ) + + return dfs # pyright: ignore[reportReturnType] diff --git a/src/rbyte/io/_json/table_reader.py b/src/rbyte/io/_json/table_reader.py deleted file mode 100644 index 3ecd60a..0000000 --- a/src/rbyte/io/_json/table_reader.py +++ /dev/null @@ -1,85 +0,0 @@ -import json -from collections.abc import Hashable, Mapping, Sequence -from enum import StrEnum, unique -from functools import cached_property -from os import PathLike -from pathlib import Path -from typing import override - -import polars as pl -from optree import PyTree, tree_map -from polars._typing import PolarsDataType -from polars.datatypes import ( - DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 - DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 -) -from pydantic import Field -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableReader -from rbyte.io.table.transforms.base import TableTransform -from rbyte.utils.dataframe.misc import unnest_all - - -@unique -class SpecialField(StrEnum): - idx = "_idx_" - - -class Config(BaseModel): - fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] - transforms: Sequence[HydraConfig[TableTransform]] = Field(default=()) - - -class JsonTableReader(TableReader, Hashable): - def __init__(self, **kwargs: object) -> None: - self._config: Config = 
Config.model_validate(kwargs) - - @override - def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: - dfs: Mapping[str, pl.DataFrame] = {} - - for k, series in ( - pl.read_json(Path(path)).select(self._fields).to_dict().items() - ): - df_schema = { - name: dtype - for name, dtype in self._fields[k].items() - if dtype is not None - } - df = pl.DataFrame(series).lazy().explode(k).unnest(k) - df = ( - df.select(unnest_all(df.collect_schema())) - .select(self._config.fields[k].keys() - set(SpecialField)) - .cast(df_schema) # pyright: ignore[reportArgumentType] - .collect() - ) - - for transform in self._transforms: - df = transform(df) - - if (idx_name := SpecialField.idx) in (df_schema := self._fields[k]): - df = df.with_row_index(idx_name).cast({ - idx_name: df_schema[idx_name] or pl.UInt32 - }) - - dfs[k] = df - - return dfs # pyright: ignore[reportReturnType] - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) - - @cached_property - def _fields(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: - return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType, reportReturnType] - - @cached_property - def _transforms(self) -> Sequence[TableTransform]: - return tuple(transform.instantiate() for transform in self._config.transforms) diff --git a/src/rbyte/io/_mcap/__init__.py b/src/rbyte/io/_mcap/__init__.py index aba5317..7bc6fd8 100644 --- a/src/rbyte/io/_mcap/__init__.py +++ b/src/rbyte/io/_mcap/__init__.py @@ -1,4 +1,4 @@ -from .table_reader import McapTableReader +from .dataframe_builder import McapDataFrameBuilder from .tensor_source import McapTensorSource -__all__ = ["McapTableReader", "McapTensorSource"] +__all__ = ["McapDataFrameBuilder", "McapTensorSource"] diff --git a/src/rbyte/io/_mcap/table_reader.py b/src/rbyte/io/_mcap/dataframe_builder.py similarity index 59% rename from src/rbyte/io/_mcap/table_reader.py rename to src/rbyte/io/_mcap/dataframe_builder.py index f7a32f6..2dbc8df 100644 --- a/src/rbyte/io/_mcap/table_reader.py +++ b/src/rbyte/io/_mcap/dataframe_builder.py @@ -1,75 +1,65 @@ -import json from collections import defaultdict -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Iterable, Mapping, Sequence from enum import StrEnum, unique -from functools import cached_property from mmap import ACCESS_READ, mmap from operator import attrgetter from os import PathLike from pathlib import Path -from typing import Any, NamedTuple, override +from typing import NamedTuple, final import more_itertools as mit import polars as pl from mcap.decoder import DecoderFactory from mcap.reader import SeekingReader -from optree import PyTree, tree_map -from polars._typing import PolarsDataType +from polars._typing import PolarsDataType # noqa: PLC2701 from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) -from pydantic import ( - ImportString, - SerializationInfo, - SerializerFunctionWrapHandler, - field_serializer, -) +from pydantic import ConfigDict, ImportString, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars from tqdm import tqdm -from xxhash 
import xxh3_64_intdigest as digest -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableReader from rbyte.utils.dataframe import unnest_all logger = get_logger(__name__) -class Config(BaseModel): - decoder_factories: frozenset[ImportString[type[DecoderFactory]]] - fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] - validate_crcs: bool = False - - @field_serializer("decoder_factories", when_used="json", mode="wrap") - @staticmethod - def serialize_decoder_factories( - value: frozenset[ImportString[type[DecoderFactory]]], - nxt: SerializerFunctionWrapHandler, - _info: SerializationInfo, - ) -> Sequence[str]: - return sorted(nxt(value)) +type Fields = Mapping[str, Mapping[str, PolarsDataType | None]] +type DecoderFactories = Sequence[ + ImportString[type[DecoderFactory]] | type[DecoderFactory] +] class RowValues(NamedTuple): topic: str - values: Iterable[Any] + values: Iterable[object] @unique class SpecialField(StrEnum): log_time = "log_time" publish_time = "publish_time" - idx = "_idx_" -class McapTableReader(TableReader, Hashable): - def __init__(self, **kwargs: object) -> None: - self._config: Config = Config.model_validate(kwargs) +@final +class McapDataFrameBuilder: + __name__ = __qualname__ - @override - def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + *, + decoder_factories: DecoderFactories, + fields: Fields, + validate_crcs: bool = True, + ) -> None: + self._decoder_factories = decoder_factories + self._fields = fields + self._validate_crcs = validate_crcs + + def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: with ( bound_contextvars(path=str(path)), Path(path).open("rb") as _f, @@ -77,8 +67,8 @@ def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: ): reader = SeekingReader( f, # pyright: ignore[reportArgumentType] - validate_crcs=self._config.validate_crcs, - decoder_factories=[f() for f in self._config.decoder_factories], + validate_crcs=self._validate_crcs, + decoder_factories=[f() for f in self._decoder_factories], ) summary = reader.get_summary() if summary is None: @@ -119,10 +109,6 @@ def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: mit.partition(lambda kv: kv[0] in SpecialField, schema.items()), ) - special_fields = { - k: v for k, v in special_fields.items() if k != SpecialField.idx - } - row_df = pl.DataFrame( [getattr(dmt.message, field) for field in special_fields], schema=special_fields, @@ -137,29 +123,10 @@ def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: rows[dmt.channel.topic].append(row_df) - dfs: Mapping[str, pl.DataFrame] = {} - for topic, row_dfs in rows.items(): - df = pl.concat(row_dfs, how="vertical") - if (idx_name := SpecialField.idx) in (schema := self._fields[topic]): - df = df.with_row_index(idx_name).cast({ - idx_name: schema[idx_name] or pl.UInt32 - }) - - dfs[topic] = df.rechunk() - - return dfs # pyright: ignore[reportReturnType] - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) - - @cached_property - def _fields(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: - return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, 
reportUnknownVariableType, reportReturnType] + return { + topic: pl.concat(row_dfs, how="vertical", rechunk=True) + for topic, row_dfs in rows.items() + } @staticmethod def _build_message_df( diff --git a/src/rbyte/io/_mcap/tensor_source.py b/src/rbyte/io/_mcap/tensor_source.py index f815fb8..d0c75a2 100644 --- a/src/rbyte/io/_mcap/tensor_source.py +++ b/src/rbyte/io/_mcap/tensor_source.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import cached_property from mmap import ACCESS_READ, mmap -from typing import IO, Any, override +from typing import IO, override import more_itertools as mit import numpy.typing as npt @@ -73,7 +73,7 @@ def __init__( logger.error(msg := "missing message decoder") raise RuntimeError(msg) - self._message_decoder: Callable[[bytes], Any] = message_decoder + self._message_decoder: Callable[[bytes], object] = message_decoder self._chunk_indexes: tuple[ChunkIndex, ...] = tuple( chunk_index for chunk_index in summary.chunk_indexes @@ -122,7 +122,7 @@ def __getitem__(self, indexes: Iterable[int]) -> Shaped[Tensor, "b h w c"]: stream.read(message_index.message_start_offset - stream.count) # pyright: ignore[reportUnusedCallResult] message = Message.read(stream, message_index.message_length) decoded_message = self._message_decoder(message.data) - frames[frame_index] = self._decoder(decoded_message.data) + frames[frame_index] = self._decoder(decoded_message.data) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] return torch.stack([torch.from_numpy(frames[idx]) for idx in indexes]) # pyright: ignore[reportUnknownMemberType] diff --git a/src/rbyte/io/_numpy/tensor_source.py b/src/rbyte/io/_numpy/tensor_source.py index a952996..2984eaa 100644 --- a/src/rbyte/io/_numpy/tensor_source.py +++ b/src/rbyte/io/_numpy/tensor_source.py @@ -2,7 +2,7 @@ from functools import cached_property from os import PathLike from pathlib import Path -from typing import TYPE_CHECKING, Any, override +from typing import TYPE_CHECKING, override import numpy as np import torch @@ -14,7 +14,7 @@ from rbyte.config.base import BaseModel from rbyte.io.base import TensorSource -from rbyte.utils.functional import pad_sequence +from rbyte.utils.tensor import pad_sequence if TYPE_CHECKING: from types import EllipsisType @@ -35,7 +35,7 @@ def _path_posix(self) -> str: return self._path.resolve().as_posix() @override - def __getitem__(self, indexes: Iterable[Any]) -> Tensor: + def __getitem__(self, indexes: Iterable[object]) -> Tensor: tensors: list[Tensor] = [] for index in indexes: path = self._path_posix.format(index) diff --git a/src/rbyte/io/dataframe/__init__.py b/src/rbyte/io/dataframe/__init__.py new file mode 100644 index 0000000..59e7b9f --- /dev/null +++ b/src/rbyte/io/dataframe/__init__.py @@ -0,0 +1,13 @@ +from .aligner import DataFrameAligner +from .concater import DataFrameConcater +from .filter import DataFrameFilter +from .fps_resampler import DataFrameFpsResampler +from .indexer import DataFrameIndexer + +__all__ = [ + "DataFrameAligner", + "DataFrameConcater", + "DataFrameFilter", + "DataFrameFpsResampler", + "DataFrameIndexer", +] diff --git a/src/rbyte/io/table/aligner.py b/src/rbyte/io/dataframe/aligner.py similarity index 62% rename from src/rbyte/io/table/aligner.py rename to src/rbyte/io/dataframe/aligner.py index 0a880c7..7c10dcd 100644 --- a/src/rbyte/io/table/aligner.py +++ b/src/rbyte/io/dataframe/aligner.py @@ -1,9 +1,7 @@ -import json from collections import OrderedDict -from collections.abc import Hashable from 
datetime import timedelta from functools import cached_property -from typing import Annotated, Literal, override +from typing import Literal, final from uuid import uuid4 import polars as pl @@ -15,15 +13,12 @@ tree_map_with_path, ) from polars._typing import AsofJoinStrategy -from pydantic import Field, StringConstraints +from pydantic import Field, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars -from xxhash import xxh3_64_intdigest as digest from rbyte.config.base import BaseModel -from .base import TableMerger - logger = get_logger(__name__) @@ -37,58 +32,56 @@ class AsofColumnMergeConfig(BaseModel): tolerance: str | int | float | timedelta | None = None -ColumnMergeConfig = InterpColumnMergeConfig | AsofColumnMergeConfig +type ColumnMergeConfig = InterpColumnMergeConfig | AsofColumnMergeConfig -class TableMergeConfig(BaseModel): +class MergeConfig(BaseModel): key: str columns: OrderedDict[str, ColumnMergeConfig] = Field(default_factory=OrderedDict) -type MergeConfig = TableMergeConfig | OrderedDict[str, "MergeConfig"] +type Fields = MergeConfig | OrderedDict[str, "Fields"] + +@final +class DataFrameAligner: + __name__ = __qualname__ -class Config(BaseModel): - merge: MergeConfig - separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "/" + @validate_call + def __init__(self, *, fields: Fields, separator: str = "/") -> None: + self._fields = fields + self._separator = separator @cached_property - def merge_fqn(self) -> PyTree[TableMergeConfig]: - # fully qualified key/column names - def fqn(path: tuple[str, ...], cfg: TableMergeConfig) -> TableMergeConfig: - key = self.separator.join((*path, cfg.key)) + def _fully_qualified_fields(self) -> PyTree[MergeConfig]: + def fqn(path: tuple[str, ...], cfg: MergeConfig) -> MergeConfig: + key = self._separator.join((*path, cfg.key)) columns = OrderedDict({ - self.separator.join((*path, k)): v for k, v in cfg.columns.items() + self._separator.join((*path, k)): v for k, v in cfg.columns.items() }) - return TableMergeConfig(key=key, columns=columns) - - return tree_map_with_path(fqn, self.merge) # pyright: ignore[reportArgumentType] + return MergeConfig(key=key, columns=columns) + return tree_map_with_path(fqn, self._fields) # pyright: ignore[reportArgumentType] -class TableAligner(TableMerger, Hashable): - def __init__(self, **kwargs: object) -> None: - self._config: Config = Config.model_validate(kwargs) + def __call__(self, input: PyTree[pl.DataFrame]) -> pl.DataFrame: + fields = self._fully_qualified_fields - @override - def merge(self, src: PyTree[pl.DataFrame]) -> pl.DataFrame: - merge_configs = self._config.merge_fqn - - def get_df(accessor: PyTreeAccessor, cfg: TableMergeConfig) -> pl.DataFrame: + def get_df(accessor: PyTreeAccessor, cfg: MergeConfig) -> pl.DataFrame: return ( - accessor(src) - .rename(lambda col: self._config.separator.join((*accessor.path, col))) # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType] + accessor(input) + .rename(lambda col: self._separator.join((*accessor.path, col))) # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType] .sort(cfg.key) ) - dfs = tree_map_with_accessor(get_df, merge_configs) - accessor, *accessors_rest = tree_accessors(merge_configs) + dfs = tree_map_with_accessor(get_df, fields) + accessor, *accessors_rest = tree_accessors(fields) df: pl.DataFrame = accessor(dfs) - left_on = accessor(merge_configs).key + left_on = accessor(fields).key for accessor in accessors_rest: other: pl.DataFrame = 
accessor(dfs) - merge_config: TableMergeConfig = accessor(merge_configs) + merge_config: MergeConfig = accessor(fields) key = merge_config.key for column, config in merge_config.columns.items(): @@ -144,11 +137,3 @@ def get_df(accessor: PyTreeAccessor, cfg: TableMergeConfig) -> pl.DataFrame: ) return df - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) diff --git a/src/rbyte/io/dataframe/concater.py b/src/rbyte/io/dataframe/concater.py new file mode 100644 index 0000000..709040d --- /dev/null +++ b/src/rbyte/io/dataframe/concater.py @@ -0,0 +1,29 @@ +from typing import final + +import polars as pl +from optree import PyTree, tree_leaves, tree_map_with_path +from polars._typing import ConcatMethod +from pydantic import validate_call + + +@final +class DataFrameConcater: + __name__ = __qualname__ + + @validate_call + def __init__( + self, method: ConcatMethod = "horizontal", separator: str | None = None + ) -> None: + self._method: ConcatMethod = method + self._separator = separator + + def __call__(self, input: PyTree[pl.DataFrame]) -> pl.DataFrame: + if (sep := self._separator) is not None: + input = tree_map_with_path( + lambda path, df: df.rename( # pyright: ignore[reportUnknownArgumentType,reportUnknownLambdaType, reportUnknownMemberType] + lambda col: f"{sep.join([*path, col])}" # pyright: ignore[reportUnknownArgumentType,reportUnknownLambdaType] + ), + input, + ) + + return pl.concat(tree_leaves(input), how=self._method, rechunk=True) diff --git a/src/rbyte/io/dataframe/filter.py b/src/rbyte/io/dataframe/filter.py new file mode 100644 index 0000000..258b29d --- /dev/null +++ b/src/rbyte/io/dataframe/filter.py @@ -0,0 +1,14 @@ +from typing import final + +import polars as pl + + +@final +class DataFrameFilter: + __name__ = __qualname__ + + def __init__(self, predicate: str) -> None: + self._query = f"select * from self where {predicate}" # noqa: S608 + + def __call__(self, input: pl.DataFrame) -> pl.DataFrame: + return input.sql(self._query) diff --git a/src/rbyte/io/table/transforms/fps_resampler.py b/src/rbyte/io/dataframe/fps_resampler.py similarity index 50% rename from src/rbyte/io/table/transforms/fps_resampler.py rename to src/rbyte/io/dataframe/fps_resampler.py index 62878bd..93df618 100644 --- a/src/rbyte/io/table/transforms/fps_resampler.py +++ b/src/rbyte/io/dataframe/fps_resampler.py @@ -1,31 +1,30 @@ from math import lcm -from typing import final, override +from typing import final from uuid import uuid4 import polars as pl from pydantic import PositiveInt, validate_call -from .base import TableTransform - @final -class FpsResampler(TableTransform): +class DataFrameFpsResampler: + __name__ = __qualname__ + IDX_COL = uuid4().hex @validate_call - def __init__(self, source_fps: PositiveInt, target_fps: PositiveInt) -> None: + def __init__(self, fps_in: PositiveInt, fps_out: PositiveInt) -> None: super().__init__() - self._source_fps = source_fps - self._target_fps = target_fps - self._fps_lcm = lcm(source_fps, target_fps) + self._fps_in = fps_in + self._fps_out = fps_out + self._fps_lcm = lcm(fps_in, fps_out) - @override - def __call__(self, src: pl.DataFrame) -> pl.DataFrame: + def __call__(self, input: pl.DataFrame) -> pl.DataFrame: return ( - src.with_row_index(self.IDX_COL) - .with_columns(pl.col(self.IDX_COL) * (self._fps_lcm // self._source_fps)) - 
.upsample(self.IDX_COL, every=f"{self._fps_lcm // self._target_fps}i") + input.with_row_index(self.IDX_COL) + .with_columns(pl.col(self.IDX_COL) * (self._fps_lcm // self._fps_in)) + .upsample(self.IDX_COL, every=f"{self._fps_lcm // self._fps_out}i") .interpolate() .fill_null(strategy="backward") .drop(self.IDX_COL) diff --git a/src/rbyte/io/dataframe/indexer.py b/src/rbyte/io/dataframe/indexer.py new file mode 100644 index 0000000..1797a73 --- /dev/null +++ b/src/rbyte/io/dataframe/indexer.py @@ -0,0 +1,18 @@ +from functools import partial +from typing import final + +import polars as pl +from optree import PyTree, tree_map +from pydantic import validate_call + + +@final +class DataFrameIndexer: + __name__ = __qualname__ + + @validate_call + def __init__(self, name: str) -> None: + self._fn = partial(pl.DataFrame.with_row_index, name=name) + + def __call__(self, input: PyTree[pl.DataFrame]) -> PyTree[pl.DataFrame]: + return tree_map(self._fn, input) diff --git a/src/rbyte/io/hdf5/__init__.py b/src/rbyte/io/hdf5/__init__.py index aa57ae7..14925f1 100644 --- a/src/rbyte/io/hdf5/__init__.py +++ b/src/rbyte/io/hdf5/__init__.py @@ -1,4 +1,4 @@ -from .table_reader import Hdf5TableReader +from .dataframe_builder import Hdf5DataFrameBuilder from .tensor_source import Hdf5TensorSource -__all__ = ["Hdf5TableReader", "Hdf5TensorSource"] +__all__ = ["Hdf5DataFrameBuilder", "Hdf5TensorSource"] diff --git a/src/rbyte/io/hdf5/dataframe_builder.py b/src/rbyte/io/hdf5/dataframe_builder.py new file mode 100644 index 0000000..7d9e44f --- /dev/null +++ b/src/rbyte/io/hdf5/dataframe_builder.py @@ -0,0 +1,52 @@ +from collections.abc import Mapping, Sequence +from os import PathLike +from typing import cast, final + +import numpy.typing as npt +import polars as pl +from h5py import Dataset, File +from optree import PyTree, tree_map, tree_map_with_path +from polars._typing import PolarsDataType # noqa: PLC2701 +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) +from pydantic import ConfigDict, validate_call + +type Fields = Mapping[str, PolarsDataType | None] | Mapping[str, "Fields"] + + +@final +class Hdf5DataFrameBuilder: + __name__ = __qualname__ + + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__(self, fields: Fields) -> None: + self._fields = fields + + def __call__(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + with File(path) as f: + + def build_series( + path: Sequence[str], dtype: PolarsDataType | None + ) -> pl.Series | None: + key = "/".join(path) + match obj := f.get(key): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + case Dataset(): + values = cast(npt.ArrayLike, obj[:]) + return pl.Series(values=values, dtype=dtype) + + case None: + return None + + case _: # pyright: ignore[reportUnknownVariableType] + raise NotImplementedError + + series = tree_map_with_path(build_series, self._fields, none_is_leaf=True) # pyright: ignore[reportArgumentType] + + return tree_map( + pl.DataFrame, + series, + is_leaf=lambda obj: isinstance(obj, dict) + and all(isinstance(v, pl.Series) or v is None for v in obj.values()), # pyright: ignore[reportUnknownVariableType] + ) diff --git a/src/rbyte/io/hdf5/table_reader.py b/src/rbyte/io/hdf5/table_reader.py deleted file mode 100644 index f50da8e..0000000 --- a/src/rbyte/io/hdf5/table_reader.py +++ /dev/null @@ -1,97 +0,0 @@ -import json -from collections.abc import Hashable, Mapping, Sequence -from 
enum import StrEnum, unique -from functools import cached_property -from os import PathLike -from typing import cast, override - -import numpy.typing as npt -import polars as pl -from h5py import Dataset, File -from optree import PyTree, tree_map, tree_map_with_path -from polars._typing import PolarsDataType # noqa: PLC2701 -from polars.datatypes import ( - DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 - DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 -) -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config import BaseModel -from rbyte.config.base import HydraConfig -from rbyte.io.table.base import TableReader - -type Fields = Mapping[str, HydraConfig[PolarsDataType] | None] | Mapping[str, "Fields"] - - -class Config(BaseModel): - fields: Fields - - -Config.model_rebuild() # pyright: ignore[reportUnusedCallResult] - - -@unique -class SpecialField(StrEnum): - idx = "_idx_" - - -class Hdf5TableReader(TableReader, Hashable): - def __init__(self, **kwargs: object) -> None: - self._config: Config = Config.model_validate(kwargs) - - @override - def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: - with File(path) as f: - - def build_series( - path: Sequence[str], dtype: PolarsDataType | None - ) -> pl.Series | None: - key = "/".join(path) - match obj := f.get(key): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - case Dataset(): - values = cast(npt.ArrayLike, obj[:]) - return pl.Series(values=values, dtype=dtype) - - case None: - return None - - case _: # pyright: ignore[reportUnknownVariableType] - raise NotImplementedError - - series = tree_map_with_path(build_series, self._fields, none_is_leaf=True) - - dfs = tree_map( - pl.DataFrame, - series, - is_leaf=lambda obj: isinstance(obj, dict) - and all(isinstance(v, pl.Series) or v is None for v in obj.values()), # pyright: ignore[reportUnknownVariableType] - ) - - def maybe_add_index( - df: pl.DataFrame, schema: Mapping[str, PolarsDataType | None] - ) -> pl.DataFrame: - match schema: - case {SpecialField.idx: dtype}: - return df.select( - pl.int_range(pl.len(), dtype=dtype or pl.UInt32).alias( # pyright: ignore[reportArgumentType] - SpecialField.idx - ), - pl.exclude(SpecialField.idx), - ) - - case _: - return df - - return tree_map(maybe_add_index, dfs, self._fields) - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) - - @cached_property - def _fields(self) -> PyTree[PolarsDataType | None]: - return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType] diff --git a/src/rbyte/io/path/__init__.py b/src/rbyte/io/path/__init__.py index c316858..fe600cf 100644 --- a/src/rbyte/io/path/__init__.py +++ b/src/rbyte/io/path/__init__.py @@ -1,4 +1,4 @@ -from .table_reader import PathTableReader +from .dataframe_builder import PathDataFrameBuilder from .tensor_source import PathTensorSource -__all__ = ["PathTableReader", "PathTensorSource"] +__all__ = ["PathDataFrameBuilder", "PathTensorSource"] diff --git a/src/rbyte/io/path/dataframe_builder.py b/src/rbyte/io/path/dataframe_builder.py new file mode 100644 index 0000000..a1396c7 --- /dev/null +++ b/src/rbyte/io/path/dataframe_builder.py @@ -0,0 +1,51 @@ +import os +from collections.abc import Mapping 
+from os import PathLike +from pathlib import Path +from typing import final + +import parse +import polars as pl +from polars._typing import PolarsDataType # noqa: PLC2701 +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) +from pydantic import ConfigDict, validate_call +from structlog import get_logger + +logger = get_logger(__name__) + + +type Fields = Mapping[str, PolarsDataType | None] + + +@final +class PathDataFrameBuilder: + __name__ = __qualname__ + + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__(self, fields: Fields) -> None: + self._fields = fields + + def __call__(self, path: PathLike[str]) -> pl.DataFrame: + parser = parse.compile(Path(path).resolve().as_posix()) # pyright: ignore[reportUnknownMemberType] + match parser.named_fields, parser.fixed_fields: # pyright: ignore[reportUnknownMemberType] + case ([_, *_], []): # pyright: ignore[reportUnknownVariableType] + pass + + case (named_fields, fixed_fields): # pyright: ignore[reportUnknownVariableType] + logger.error( + msg := "parser not supported", + named_fields=named_fields, + fixed_fields=fixed_fields, + ) + raise RuntimeError(msg) + + parent = Path(os.path.commonpath([path, parser._expression])) # pyright: ignore[reportPrivateUsage] # noqa: SLF001 + results = (parser.parse(p.as_posix()) for p in parent.rglob("*") if p.is_file()) # pyright: ignore[reportUnknownMemberType] + + return pl.DataFrame( + data=(r.named for r in results if isinstance(r, parse.Result)), # pyright: ignore[reportUnknownMemberType] + schema=self._fields, + ) diff --git a/src/rbyte/io/path/table_reader.py b/src/rbyte/io/path/table_reader.py deleted file mode 100644 index b2cdb36..0000000 --- a/src/rbyte/io/path/table_reader.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -from collections.abc import Mapping -from enum import StrEnum, unique -from functools import cached_property -from os import PathLike -from pathlib import Path -from typing import override - -import parse -import polars as pl -from optree import PyTree, tree_map -from polars._typing import PolarsDataType -from polars.datatypes import ( - DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 - DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 -) -from structlog import get_logger - -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableReader - -logger = get_logger(__name__) - - -class Config(BaseModel): - fields: Mapping[str, HydraConfig[PolarsDataType] | None] = {} - - -@unique -class SpecialField(StrEnum): - idx = "_idx_" - - -class PathTableReader(TableReader): - def __init__(self, **kwargs: object) -> None: - self._config: Config = Config.model_validate(kwargs) - - @override - def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: - parser = parse.compile(Path(path).resolve().as_posix()) # pyright: ignore[reportUnknownMemberType] - match parser.named_fields, parser.fixed_fields: # pyright: ignore[reportUnknownMemberType] - case ([_, *_], []): # pyright: ignore[reportUnknownVariableType] - pass - - case (named_fields, fixed_fields): # pyright: ignore[reportUnknownVariableType] - logger.error( - msg := "parser not supported", - named_fields=named_fields, - fixed_fields=fixed_fields, - ) - raise RuntimeError(msg) - - parent = Path(os.path.commonpath([path, parser._expression])) # pyright: ignore[reportPrivateUsage] # noqa: SLF001 - results = (parser.parse(p.as_posix()) for p in 
parent.rglob("*") if p.is_file()) # pyright: ignore[reportUnknownMemberType] - - df = pl.DataFrame( - result.named # pyright: ignore[reportUnknownMemberType] - for result in results - if isinstance(result, parse.Result) - ) - - if (idx_name := SpecialField.idx) in self._fields: - df = df.with_row_index(idx_name).cast({ - idx_name: self._fields[idx_name] or pl.UInt32 - }) - - df_schema = { - name: dtype for name, dtype in self._fields.items() if dtype is not None - } - - return df.cast(df_schema) # pyright: ignore[reportArgumentType, reportReturnType] - - @cached_property - def _fields(self) -> Mapping[str, PolarsDataType | None]: - return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType, reportReturnType] diff --git a/src/rbyte/io/rrd/__init__.py b/src/rbyte/io/rrd/__init__.py index bacca6e..eb1c8b3 100644 --- a/src/rbyte/io/rrd/__init__.py +++ b/src/rbyte/io/rrd/__init__.py @@ -1,4 +1,4 @@ +from .dataframe_builder import RrdDataFrameBuilder from .frame_source import RrdFrameSource -from .table_reader import RrdTableReader -__all__ = ["RrdFrameSource", "RrdTableReader"] +__all__ = ["RrdDataFrameBuilder", "RrdFrameSource"] diff --git a/src/rbyte/io/rrd/dataframe_builder.py b/src/rbyte/io/rrd/dataframe_builder.py new file mode 100644 index 0000000..0b93c47 --- /dev/null +++ b/src/rbyte/io/rrd/dataframe_builder.py @@ -0,0 +1,79 @@ +from collections.abc import Mapping, Sequence +from enum import StrEnum, auto, unique +from os import PathLike +from typing import cast, final + +import more_itertools as mit +import polars as pl +import rerun.dataframe as rrd +from pydantic import validate_call + + +@unique +class Column(StrEnum): + log_tick = auto() + log_time = auto() + + +@final +class RrdDataFrameBuilder: + __name__ = __qualname__ + + @validate_call + def __init__( + self, index: str, contents: Mapping[str, Sequence[str] | None] + ) -> None: + self._index = index + self._contents = contents + + def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + recording = rrd.load_recording(path) # pyright: ignore[reportUnknownMemberType] + schema = recording.schema() + + # Entity contents must include a non-static component to get index values. 
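+        # If every requested component of an entity is static (or none were
+        # given), borrow one non-static component -- preferring an Indicator --
+        # so the view still yields index values; it is dropped after reading.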
+ extra_contents: Mapping[str, Sequence[str]] = {} + for entity_path, components in self._contents.items(): + if components is None or all( + (col := schema.column_for(entity_path, component)) is not None + and col.is_static + for component in components + ): + non_static_components = [ + col.component_name.removeprefix("rerun.components.") + for col in schema.component_columns() + if not col.is_static and col.entity_path == entity_path + ] + + extra_component = mit.first( + (comp for comp in non_static_components if "Indicator" in comp), + default=mit.first(non_static_components), + ) + + extra_contents[entity_path] = [extra_component] + + view = recording.view( + index=self._index, + contents={ + entity_path: [*(components or []), *extra_contents.get(entity_path, [])] + for entity_path, components in self._contents.items() + }, + include_indicator_columns=True, + ) + + recording_df = cast(pl.DataFrame, pl.from_arrow(view.select().read_all())).drop( # pyright: ignore[reportUnknownMemberType] + Column.log_tick, Column.log_time + ) + + entity_columns = mit.map_reduce( + recording_df.select(pl.exclude(self._index)).columns, + keyfunc=lambda x: x.split(":")[0], + ) + + return { + entity: recording_df.select( + self._index, pl.col(*columns).name.map(lambda x: x.split(":")[1]) + ) + .drop_nulls() + .drop(extra_contents.get(entity, [])) + for entity, columns in entity_columns.items() + } diff --git a/src/rbyte/io/rrd/table_reader.py b/src/rbyte/io/rrd/table_reader.py deleted file mode 100644 index 97b6c8a..0000000 --- a/src/rbyte/io/rrd/table_reader.py +++ /dev/null @@ -1,106 +0,0 @@ -import json -from collections.abc import Hashable, Mapping, Sequence -from enum import StrEnum, auto, unique -from os import PathLike -from typing import cast, override - -import more_itertools as mit -import polars as pl -import rerun.dataframe as rrd -from optree import PyTree -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config.base import BaseModel -from rbyte.io.table.base import TableReader - - -class Config(BaseModel): - index: str - contents: Mapping[str, Sequence[str]] - - -@unique -class Column(StrEnum): - log_tick = auto() - log_time = auto() - idx = "_idx_" - - -class RrdTableReader(TableReader, Hashable): - def __init__(self, **kwargs: object) -> None: - self._config: Config = Config.model_validate(kwargs) - - @override - def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: - recording = rrd.load_recording(path) # pyright: ignore[reportUnknownMemberType] - schema = recording.schema() - - # Entity contents must include a non-static component to get index values. 
- extra_contents: Mapping[str, Sequence[str]] = {} - for entity_path, components in self._config.contents.items(): - match components: - case [Column.idx, *rest] if all( - (col := schema.column_for(entity_path, component)) is not None - and col.is_static - for component in rest - ): - non_static_components = [ - col.component_name.removeprefix("rerun.components.") - for col in schema.component_columns() - if not col.is_static and col.entity_path == entity_path - ] - - extra_component = mit.first( - (comp for comp in non_static_components if "Indicator" in comp), - default=mit.first(non_static_components), - ) - - extra_contents[entity_path] = [extra_component] - - case _: - pass - - view = recording.view( - index=self._config.index, - contents={ - entity_path: [*components, *extra_contents.get(entity_path, [])] - for entity_path, components in self._config.contents.items() - }, - include_indicator_columns=True, - ) - - recording_df = cast(pl.DataFrame, pl.from_arrow(view.select().read_all())).drop( # pyright: ignore[reportUnknownMemberType] - Column.log_tick, Column.log_time - ) - - entity_columns = mit.map_reduce( - recording_df.select(pl.exclude(self._config.index)).columns, - keyfunc=lambda x: x.split(":")[0], - ) - - dfs: Mapping[str, pl.DataFrame] = {} - - for entity, columns in entity_columns.items(): - entity_df = ( - recording_df.select( - self._config.index, - pl.col(*columns).name.map(lambda x: x.split(":")[1]), - ) - .drop_nulls() - .drop(extra_contents.get(entity, [])) - ) - - if Column.idx in self._config.contents[entity]: - entity_df = entity_df.with_row_index(Column.idx) - - dfs[entity] = entity_df - - return dfs # pyright: ignore[reportReturnType] - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) diff --git a/src/rbyte/io/table/__init__.py b/src/rbyte/io/table/__init__.py deleted file mode 100644 index 20c3101..0000000 --- a/src/rbyte/io/table/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .aligner import TableAligner -from .builder import TableBuilder -from .concater import TableConcater -from .transforms import FpsResampler - -__all__ = ["FpsResampler", "TableAligner", "TableBuilder", "TableConcater"] diff --git a/src/rbyte/io/table/base.py b/src/rbyte/io/table/base.py deleted file mode 100644 index 0ba3f5a..0000000 --- a/src/rbyte/io/table/base.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Hashable -from os import PathLike -from typing import Protocol, runtime_checkable - -from optree import PyTree -from polars import DataFrame - - -@runtime_checkable -class TableBuilder(Protocol): - def build(self) -> DataFrame: ... - - -@runtime_checkable -class TableReader(Protocol): - def read(self, path: PathLike[str]) -> PyTree[DataFrame]: ... - - -@runtime_checkable -class TableMerger(Protocol): - def merge(self, src: PyTree[DataFrame]) -> DataFrame: ... - - -@runtime_checkable -class TableCache(Protocol): - def __contains__(self, key: Hashable) -> bool: ... - def get(self, key: Hashable) -> DataFrame | None: ... - def set(self, key: Hashable, value: DataFrame) -> bool: ... 
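With the reader/merger/cache protocols above gone, every component in this patch is a plain final callable, presumably wired together with pipefunc (see the "need signature for pipefunc" note in src/rbyte/utils/_pipefunc.py further down). A minimal hand-composed sketch -- toy frames and invented column names, not the project's real wiring:

    import polars as pl

    from rbyte.io.dataframe.concater import DataFrameConcater
    from rbyte.io.dataframe.filter import DataFrameFilter

    # stand-ins for per-source dataframe-builder outputs
    frames = {
        "cam": pl.DataFrame({"ts": [0, 1, 2], "frame_idx": [0, 1, 2]}),
        "imu": pl.DataFrame({"ts": [0, 1, 2], "accel": [0.0, 0.1, 0.2]}),
    }

    concat = DataFrameConcater(method="horizontal", separator="/")
    keep = DataFrameFilter('"cam/frame_idx" > 0')  # predicate is raw SQL

    df = keep(concat(frames))  # columns: cam/ts, cam/frame_idx, imu/ts, imu/accel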
diff --git a/src/rbyte/io/table/builder.py b/src/rbyte/io/table/builder.py deleted file mode 100644 index 16e87a6..0000000 --- a/src/rbyte/io/table/builder.py +++ /dev/null @@ -1,90 +0,0 @@ -from collections.abc import Hashable, Mapping -from mmap import ACCESS_READ, mmap -from os import PathLike -from pathlib import Path -from typing import Annotated, Any, override - -import polars as pl -from optree import PyTree, tree_map -from pydantic import Field, StringConstraints, validate_call -from structlog import get_logger -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config import BaseModel - -from .base import TableBuilder as _TableBuilder -from .base import TableCache, TableMerger, TableReader - -logger = get_logger(__name__) - - -class TableReaderConfig(BaseModel): - path: PathLike[str] - reader: TableReader - - -class TableBuilder(_TableBuilder): - @validate_call(config=BaseModel.model_config) - def __init__( - self, - readers: Annotated[Mapping[str, TableReaderConfig], Field(min_length=1)], - merger: TableMerger, - filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None, # noqa: A002 - cache: TableCache | None = None, - ) -> None: - super().__init__() - - self._readers: Mapping[str, TableReaderConfig] = readers - self._merger: TableMerger = merger - self._filter: str | None = filter - self._cache: TableCache | None = cache - - def _build_cache_key(self) -> Hashable: - from rbyte import __version__ # noqa: PLC0415 - - key: list[Any] = [__version__, hash(self._merger)] - - if self._filter is not None: - key.append(digest(self._filter)) - - for reader_name, reader_config in sorted(self._readers.items()): - with ( - Path(reader_config.path).open("rb") as _f, - mmap(_f.fileno(), 0, access=ACCESS_READ) as f, - ): - file_hash = digest(f) # pyright: ignore[reportArgumentType] - - key.append((file_hash, digest(reader_name), hash(reader_config.reader))) - - return tuple(key) - - @override - def build(self) -> pl.DataFrame: - match self._cache: - case TableCache(): - key = self._build_cache_key() - if key in self._cache: - logger.debug("reading table from cache") - df = self._cache.get(key) - if df is None: - raise RuntimeError - - return df - - df = self._build() - if not self._cache.set(key, df): - logger.warning("failed to cache table") - - return df - - case None: - return self._build() - - def _build(self) -> pl.DataFrame: - dfs: PyTree[pl.DataFrame] = tree_map( - lambda cfg: cfg.reader.read(cfg.path), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportUnknownLambdaType] - self._readers, # pyright: ignore[reportArgumentType] - ) - df = self._merger.merge(dfs) - - return df.sql(f"select * from self where ({self._filter or True})") # noqa: S608 diff --git a/src/rbyte/io/table/concater.py b/src/rbyte/io/table/concater.py deleted file mode 100644 index 9d3afb0..0000000 --- a/src/rbyte/io/table/concater.py +++ /dev/null @@ -1,42 +0,0 @@ -import json -from collections.abc import Hashable -from typing import override - -import polars as pl -from optree import PyTree, tree_leaves, tree_map_with_path -from polars._typing import ConcatMethod -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config import BaseModel - -from .base import TableMerger - - -class Config(BaseModel): - separator: str | None = None - method: ConcatMethod = "horizontal" - - -class TableConcater(TableMerger, Hashable): - def __init__(self, **kwargs: object) -> None: - self._config: Config = Config.model_validate(kwargs) - - @override - def merge(self, src: 
PyTree[pl.DataFrame]) -> pl.DataFrame: - if (sep := self._config.separator) is not None: - src = tree_map_with_path( - lambda path, df: df.rename( # pyright: ignore[reportUnknownArgumentType,reportUnknownLambdaType, reportUnknownMemberType] - lambda col: f"{sep.join([*path, col])}" # pyright: ignore[reportUnknownArgumentType,reportUnknownLambdaType] - ), - src, - ) - - return pl.concat(tree_leaves(src), how=self._config.method, rechunk=True) - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) diff --git a/src/rbyte/io/table/transforms/__init__.py b/src/rbyte/io/table/transforms/__init__.py deleted file mode 100644 index 734fa11..0000000 --- a/src/rbyte/io/table/transforms/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import TableTransform -from .fps_resampler import FpsResampler - -__all__ = ["FpsResampler", "TableTransform"] diff --git a/src/rbyte/io/table/transforms/base.py b/src/rbyte/io/table/transforms/base.py deleted file mode 100644 index a287bb2..0000000 --- a/src/rbyte/io/table/transforms/base.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Protocol, runtime_checkable - -import polars as pl - - -@runtime_checkable -class TableTransform(Protocol): - def __call__(self, src: pl.DataFrame) -> pl.DataFrame: ... diff --git a/src/rbyte/io/yaak/__init__.py b/src/rbyte/io/yaak/__init__.py index 8cfd72d..845e6a0 100644 --- a/src/rbyte/io/yaak/__init__.py +++ b/src/rbyte/io/yaak/__init__.py @@ -1,3 +1,3 @@ -from .table_reader import YaakMetadataTableReader +from .dataframe_builder import YaakMetadataDataFrameBuilder -__all__ = ["YaakMetadataTableReader"] +__all__ = ["YaakMetadataDataFrameBuilder"] diff --git a/src/rbyte/io/yaak/table_reader.py b/src/rbyte/io/yaak/dataframe_builder.py similarity index 61% rename from src/rbyte/io/yaak/table_reader.py rename to src/rbyte/io/yaak/dataframe_builder.py index 2208e57..606b16b 100644 --- a/src/rbyte/io/yaak/table_reader.py +++ b/src/rbyte/io/yaak/dataframe_builder.py @@ -1,29 +1,22 @@ -import json from collections.abc import Mapping -from functools import cached_property from mmap import ACCESS_READ, mmap from operator import itemgetter from os import PathLike from pathlib import Path -from typing import cast, override +from typing import cast, final import more_itertools as mit import polars as pl from google.protobuf.message import Message -from optree import PyTree, tree_map -from polars._typing import PolarsDataType +from polars._typing import PolarsDataType # noqa: PLC2701 from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) from ptars import HandlerPool -from pydantic import ImportString +from pydantic import ConfigDict, ImportString, validate_call from structlog import get_logger from tqdm import tqdm -from xxhash import xxh3_64_intdigest as digest - -from rbyte.config.base import BaseModel, HydraConfig -from rbyte.io.table.base import TableReader from .message_iterator import YaakMetadataMessageIterator from .proto import sensor_pb2 @@ -31,20 +24,22 @@ logger = get_logger(__name__) -class Config(BaseModel): - fields: Mapping[ - ImportString[type[Message]], Mapping[str, HydraConfig[PolarsDataType] | None] - ] +type Fields = Mapping[ + type[Message] | ImportString[type[Message]], Mapping[str, PolarsDataType | None] 
+] + +@final +class YaakMetadataDataFrameBuilder: + __name__ = __qualname__ -class YaakMetadataTableReader(TableReader): - def __init__(self, **kwargs: object) -> None: + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__(self, *, fields: Fields) -> None: super().__init__() - self._config: Config = Config.model_validate(kwargs) + self._fields = fields - @override - def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: + def __call__(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f: handler_pool = HandlerPool() @@ -81,16 +76,4 @@ def read(self, path: PathLike[str]) -> PyTree[pl.DataFrame]: ).items() } - return dfs # pyright: ignore[reportReturnType] - - @override - def __hash__(self) -> int: - config = self._config.model_dump_json() - # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 - config_str = json.dumps(json.loads(config), sort_keys=True) - - return digest(config_str) - - @cached_property - def _fields(self) -> Mapping[type[Message], Mapping[str, PolarsDataType | None]]: - return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownVariableType, reportReturnType, reportUnknownMemberType, reportUnknownArgumentType] + return dfs diff --git a/src/rbyte/io/yaak/message_iterator.py b/src/rbyte/io/yaak/message_iterator.py index 6efda9a..1dc5a4c 100644 --- a/src/rbyte/io/yaak/message_iterator.py +++ b/src/rbyte/io/yaak/message_iterator.py @@ -85,11 +85,11 @@ def __next__(self) -> tuple[type[Message], bytes]: return msg_type, msg_buf # pyright: ignore[reportPossiblyUnboundVariable] def _read_message(self) -> tuple[int, bytes] | None: - msg__type_idx_buf = self._file.read(4) - if not msg__type_idx_buf: + msg_type_idx_buf = self._file.read(4) + if not msg_type_idx_buf: return None - msg_type_idx = to_uint32(msg__type_idx_buf) + msg_type_idx = to_uint32(msg_type_idx_buf) msg_len = to_uint32(self._file.read(4)) msg_buf = self._file.read(msg_len) diff --git a/src/rbyte/sample/base.py b/src/rbyte/sample/base.py deleted file mode 100644 index eb44648..0000000 --- a/src/rbyte/sample/base.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Protocol, runtime_checkable - -import polars as pl - - -@runtime_checkable -class SampleBuilder(Protocol): - def build(self, source: pl.DataFrame) -> pl.DataFrame: ... diff --git a/src/rbyte/sample/fixed_window.py b/src/rbyte/sample/fixed_window.py index f82c3fc..ff26276 100644 --- a/src/rbyte/sample/fixed_window.py +++ b/src/rbyte/sample/fixed_window.py @@ -1,15 +1,14 @@ from datetime import timedelta -from typing import Literal, override +from typing import final from uuid import uuid4 import polars as pl from polars._typing import ClosedInterval from pydantic import validate_call -from .base import SampleBuilder - -class FixedWindowSampleBuilder(SampleBuilder): +@final +class FixedWindowSampleBuilder: """ Build samples using fixed (potentially overlapping) windows based on a temporal or integer column. 
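As the hunks below show, the builder is now a plain callable on a DataFrame rather than a SampleBuilder with a build() method. A minimal sketch with a toy integer index (window size invented):

    import polars as pl

    from rbyte.sample.fixed_window import FixedWindowSampleBuilder

    df = pl.DataFrame({"ts": [0, 1, 2, 3], "x": [0.0, 0.1, 0.2, 0.3]})

    # "2i" = a window spanning two values of the integer index column
    builder = FixedWindowSampleBuilder(index_column="ts", every="2i")
    samples = builder(df)
    # each output row aggregates one window into lists,
    # e.g. ts=[0, 1], x=[0.0, 0.1]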
@@ -17,6 +16,8 @@ class FixedWindowSampleBuilder(SampleBuilder): https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic """ + __name__ = __qualname__ + @validate_call def __init__( self, @@ -25,18 +26,15 @@ def __init__( every: str | timedelta, period: str | timedelta | None = None, closed: ClosedInterval = "left", - filter: str | None = None, # noqa: A002 ) -> None: - self._index_column: pl.Expr = pl.col(index_column) - self._every: str | timedelta = every - self._period: str | timedelta | None = period + self._index_column = pl.col(index_column) + self._every = every + self._period = period self._closed: ClosedInterval = closed - self._filter: str | Literal[True] = filter if filter is not None else True - @override - def build(self, source: pl.DataFrame) -> pl.DataFrame: + def __call__(self, input: pl.DataFrame) -> pl.DataFrame: return ( - source.sort(self._index_column) + input.sort(self._index_column) .with_columns(self._index_column.alias(_index_column := uuid4().hex)) .group_by_dynamic( index_column=_index_column, @@ -47,7 +45,6 @@ def build(self, source: pl.DataFrame) -> pl.DataFrame: start_by="datapoint", ) .agg(pl.all()) - .sql(f"select * from self where ({self._filter})") # noqa: S608 .filter(self._index_column.list.len() > 0) .sort(_index_column) .drop(_index_column) diff --git a/src/rbyte/sample/rolling_window.py b/src/rbyte/sample/rolling_window.py index cdfee09..c8e78c2 100644 --- a/src/rbyte/sample/rolling_window.py +++ b/src/rbyte/sample/rolling_window.py @@ -1,21 +1,22 @@ from datetime import timedelta -from typing import Literal, override +from typing import final from uuid import uuid4 import polars as pl from polars._typing import ClosedInterval from pydantic import validate_call -from .base import SampleBuilder - -class RollingWindowSampleBuilder(SampleBuilder): +@final +class RollingWindowSampleBuilder: """ Build samples using rolling windows based on a temporal or integer column. 
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rolling """ + __name__ = __qualname__ + @validate_call def __init__( self, @@ -24,18 +25,15 @@ def __init__( period: str | timedelta, offset: str | timedelta | None = None, closed: ClosedInterval = "right", - filter: str | None = None, # noqa: A002 ) -> None: - self._index_column: pl.Expr = pl.col(index_column) - self._period: str | timedelta = period - self._offset: str | timedelta | None = offset + self._index_column = pl.col(index_column) + self._period = period + self._offset = offset self._closed: ClosedInterval = closed - self._filter: str | Literal[True] = filter if filter is not None else True - @override - def build(self, source: pl.DataFrame) -> pl.DataFrame: + def __call__(self, input: pl.DataFrame) -> pl.DataFrame: return ( - source.sort(self._index_column) + input.sort(self._index_column) .with_columns(self._index_column.alias(_index_column := uuid4().hex)) .rolling( index_column=_index_column, @@ -44,7 +42,6 @@ def build(self, source: pl.DataFrame) -> pl.DataFrame: closed=self._closed, ) .agg(pl.all()) - .sql(f"select * from self where ({self._filter})") # noqa: S608 .filter(self._index_column.list.len() > 0) .sort(_index_column) .drop(_index_column) diff --git a/src/rbyte/utils/__init__.py b/src/rbyte/utils/__init__.py index e69de29..8b4d903 100644 --- a/src/rbyte/utils/__init__.py +++ b/src/rbyte/utils/__init__.py @@ -0,0 +1,3 @@ +from ._pipefunc import make_dict + +__all__ = ["make_dict"] diff --git a/src/rbyte/utils/mcap/__init__.py b/src/rbyte/utils/_mcap/__init__.py similarity index 100% rename from src/rbyte/utils/mcap/__init__.py rename to src/rbyte/utils/_mcap/__init__.py diff --git a/src/rbyte/utils/mcap/decoders/__init__.py b/src/rbyte/utils/_mcap/decoders/__init__.py similarity index 100% rename from src/rbyte/utils/mcap/decoders/__init__.py rename to src/rbyte/utils/_mcap/decoders/__init__.py diff --git a/src/rbyte/utils/mcap/decoders/json_decoder_factory.py b/src/rbyte/utils/_mcap/decoders/json_decoder_factory.py similarity index 100% rename from src/rbyte/utils/mcap/decoders/json_decoder_factory.py rename to src/rbyte/utils/_mcap/decoders/json_decoder_factory.py diff --git a/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py b/src/rbyte/utils/_mcap/decoders/protobuf_decoder_factory.py similarity index 100% rename from src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py rename to src/rbyte/utils/_mcap/decoders/protobuf_decoder_factory.py diff --git a/src/rbyte/utils/_pipefunc.py b/src/rbyte/utils/_pipefunc.py new file mode 100644 index 0000000..dc7afeb --- /dev/null +++ b/src/rbyte/utils/_pipefunc.py @@ -0,0 +1,33 @@ +from collections.abc import Hashable, Iterable + +__unspecified = object() + + +# need signature for pipefunc +def make_dict( # noqa: PLR0913 + *, + k0: Hashable, + v0: object, + k1: Hashable = __unspecified, + v1: object = __unspecified, + k2: Hashable = __unspecified, + v2: object = __unspecified, + k3: Hashable = __unspecified, + v3: object = __unspecified, + k4: Hashable = __unspecified, + v4: object = __unspecified, +) -> dict[Hashable, object]: + keys = (k0, k1, k2, k3, k4) + values = (v0, v1, v2, v3, v4) + + def items() -> Iterable[tuple[Hashable, object]]: + for key, value in zip(keys, values, strict=True): + if (key is __unspecified) and (value is __unspecified): + continue + + elif (key is __unspecified) or (value is __unspecified): + raise ValueError + + yield (key, value) + + return dict(items()) diff --git a/src/rbyte/utils/dataframe/misc.py 
b/src/rbyte/utils/dataframe.py similarity index 100% rename from src/rbyte/utils/dataframe/misc.py rename to src/rbyte/utils/dataframe.py diff --git a/src/rbyte/utils/dataframe/__init__.py b/src/rbyte/utils/dataframe/__init__.py deleted file mode 100644 index b12763f..0000000 --- a/src/rbyte/utils/dataframe/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .cache import DataframeDiskCache -from .misc import unnest_all - -__all__ = ["DataframeDiskCache", "unnest_all"] diff --git a/src/rbyte/utils/dataframe/cache.py b/src/rbyte/utils/dataframe/cache.py deleted file mode 100644 index 93fa76d..0000000 --- a/src/rbyte/utils/dataframe/cache.py +++ /dev/null @@ -1,42 +0,0 @@ -from collections.abc import Hashable -from io import BufferedReader -from tempfile import TemporaryFile -from typing import Literal, override - -import polars as pl -from diskcache import Cache -from pydantic import ByteSize, DirectoryPath, NewPath, validate_call - -from rbyte.io.table.base import TableCache - - -class DataframeDiskCache(TableCache): - @validate_call - def __init__( - self, directory: DirectoryPath | NewPath, size_limit: ByteSize | None = None - ) -> None: - super().__init__() - self._cache: Cache = Cache(directory=directory, size_limit=size_limit) - - @override - def __contains__(self, key: Hashable) -> bool: - return key in self._cache - - @override - def get(self, key: Hashable) -> pl.DataFrame | None: - match val := self._cache.get(key, default=None, read=True): # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] - case BufferedReader(): - return pl.read_ipc(val) - - case None: - return None - - case _: # pyright: ignore[reportUnknownVariableType] - raise NotImplementedError - - @override - def set(self, key: Hashable, value: pl.DataFrame) -> Literal[True]: - with TemporaryFile() as f: - value.write_ipc(f, compression="uncompressed") - f.seek(0) # pyright: ignore[reportUnusedCallResult] - return self._cache.set(key, f, read=True) # pyright: ignore[reportUnknownMemberType] diff --git a/src/rbyte/utils/functional.py b/src/rbyte/utils/tensor.py similarity index 95% rename from src/rbyte/utils/functional.py rename to src/rbyte/utils/tensor.py index 343d24f..fa58a40 100644 --- a/src/rbyte/utils/functional.py +++ b/src/rbyte/utils/tensor.py @@ -8,7 +8,7 @@ def pad_dim( - input: Float[Tensor, "..."], # noqa: A002 + input: Float[Tensor, "..."], *, pad: tuple[int, int], dim: int, diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 545ce48..fcb3a9a 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -80,17 +80,11 @@ def test_nuscenes_mcap() -> None: "CAM_FRONT": Tensor(shape=[c.B, _, *_]), "CAM_FRONT_LEFT": Tensor(shape=[c.B, _, *_]), "CAM_FRONT_RIGHT": Tensor(shape=[c.B, _, *_]), - "mcap//CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), - "mcap//CAM_FRONT/image_rect_compressed/log_time": Tensor( - shape=[c.B, _] - ), - "mcap//CAM_FRONT_LEFT/image_rect_compressed/_idx_": Tensor( - shape=[c.B, _] - ), - "mcap//CAM_FRONT_RIGHT/image_rect_compressed/_idx_": Tensor( - shape=[c.B, _] - ), - "mcap//odom/vel.x": Tensor(shape=[c.B, _]), + "/CAM_FRONT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), + "/CAM_FRONT/image_rect_compressed/log_time": Tensor(shape=[c.B, _]), + "/CAM_FRONT_LEFT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), + "/CAM_FRONT_RIGHT/image_rect_compressed/_idx_": Tensor(shape=[c.B, _]), + "/odom/vel.x": Tensor(shape=[c.B, _]), **data_rest, }, "meta": { @@ -137,19 +131,11 @@ def test_nuscenes_rrd() -> None: "CAM_FRONT": 
Tensor(shape=[c.B, _, *_]), "CAM_FRONT_LEFT": Tensor(shape=[c.B, _, *_]), "CAM_FRONT_RIGHT": Tensor(shape=[c.B, _, *_]), - "rrd//world/ego_vehicle/CAM_FRONT/timestamp": Tensor( - shape=[c.B, _, *_] - ), - "rrd//world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, _, *_]), - "rrd//world/ego_vehicle/CAM_FRONT_LEFT/_idx_": Tensor( - shape=[c.B, _, *_] - ), - "rrd//world/ego_vehicle/CAM_FRONT_RIGHT/_idx_": Tensor( - shape=[c.B, _, *_] - ), - "rrd//world/ego_vehicle/LIDAR_TOP/Position3D": Tensor( - shape=[c.B, _, *_] - ), + "/world/ego_vehicle/CAM_FRONT/timestamp": Tensor(shape=[c.B, _, *_]), + "/world/ego_vehicle/CAM_FRONT/_idx_": Tensor(shape=[c.B, _, *_]), + "/world/ego_vehicle/CAM_FRONT_LEFT/_idx_": Tensor(shape=[c.B, _, *_]), + "/world/ego_vehicle/CAM_FRONT_RIGHT/_idx_": Tensor(shape=[c.B, _, *_]), + "/world/ego_vehicle/LIDAR_TOP/Position3D": Tensor(shape=[c.B, _, *_]), **data_rest, }, "meta": {
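The dropped "mcap/" and "rrd/" key prefixes in these tests appear to follow from the new naming scheme: fully qualified column names now come only from each frame's position in the PyTree, joined by the separator. A minimal sketch (toy values) of how an "/odom" frame ends up with the column names the test expects:

    import polars as pl

    from rbyte.io.dataframe.concater import DataFrameConcater
    from rbyte.io.dataframe.indexer import DataFrameIndexer

    dfs = {"/odom": pl.DataFrame({"vel.x": [0.1, 0.2, 0.3]})}
    dfs = DataFrameIndexer(name="_idx_")(dfs)   # prepends an _idx_ row index
    df = DataFrameConcater(separator="/")(dfs)  # path-qualifies column names

    assert df.columns == ["/odom/_idx_", "/odom/vel.x"]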