From 5e9982436e8de00d0e8ba0b10048c56c130d4910 Mon Sep 17 00:00:00 2001 From: Evgenii Gorchakov Date: Mon, 21 Oct 2024 11:49:16 +0200 Subject: [PATCH] feat: support multiple table readers (#15) - add support for multiple table readers - add a `protobuf -> pl.DataFrame` mcap decoder factory - replace `CarlaRecordsTableBuilder` with `JsonTableReader` --- .pre-commit-config.yaml | 6 +- README.md | 63 ++++----- examples/config_templates/build_table.yaml | 4 +- examples/config_templates/dataset/carla.yaml | 44 ++++--- examples/config_templates/dataset/hdf5.yaml | 21 +-- examples/config_templates/dataset/mcap.yaml | 41 +++--- examples/config_templates/dataset/yaak.yaml | 105 +++++++++------ .../config_templates/logger/rerun/carla.yaml | 2 +- .../config_templates/logger/rerun/yaak.yaml | 3 + .../config_templates/table_builder/carla.yaml | 36 ++++-- .../config_templates/table_builder/hdf5.yaml | 36 +++--- .../config_templates/table_builder/mcap.yaml | 43 ++++--- .../config_templates/table_builder/yaak.yaml | 82 +++++++----- .../table_writer/console.yaml | 4 + pyproject.toml | 17 ++- src/rbyte/config/base.py | 12 +- src/rbyte/dataset.py | 9 +- src/rbyte/io/table/__init__.py | 24 +++- src/rbyte/io/table/aligner.py | 3 + src/rbyte/io/table/base.py | 4 +- src/rbyte/io/table/builder.py | 86 ++++++++----- src/rbyte/io/table/carla/__init__.py | 3 - src/rbyte/io/table/carla/builder.py | 54 -------- src/rbyte/io/table/concater.py | 13 +- src/rbyte/io/table/hdf5/reader.py | 30 ++--- src/rbyte/io/table/json/__init__.py | 3 + src/rbyte/io/table/json/reader.py | 87 +++++++++++++ src/rbyte/io/table/mcap/reader.py | 120 ++++++++++-------- src/rbyte/io/table/transforms/base.py | 4 +- .../io/table/transforms/fps_resampler.py | 4 +- src/rbyte/io/table/yaak/reader.py | 9 +- src/rbyte/sample/builder.py | 2 +- src/rbyte/scripts/build_table.py | 2 +- src/rbyte/utils/mcap/__init__.py | 4 +- src/rbyte/utils/mcap/decoders/__init__.py | 4 + .../mcap/decoders/json_decoder_factory.py | 24 ++++ .../mcap/decoders/protobuf_decoder_factory.py | 72 +++++++++++ src/rbyte/utils/mcap/json_decoder_factory.py | 27 ---- 38 files changed, 684 insertions(+), 423 deletions(-) create mode 100644 examples/config_templates/table_writer/console.yaml delete mode 100644 src/rbyte/io/table/carla/__init__.py delete mode 100644 src/rbyte/io/table/carla/builder.py create mode 100644 src/rbyte/io/table/json/__init__.py create mode 100644 src/rbyte/io/table/json/reader.py create mode 100644 src/rbyte/utils/mcap/decoders/__init__.py create mode 100644 src/rbyte/utils/mcap/decoders/json_decoder_factory.py create mode 100644 src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py delete mode 100644 src/rbyte/utils/mcap/json_decoder_factory.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7cf807a..6a06afa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ fail_fast: true repos: - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.20.2 + rev: v0.21 hooks: - id: validate-pyproject @@ -13,7 +13,7 @@ repos: - id: typos - repo: https://github.com/asottile/pyupgrade - rev: v3.17.0 + rev: v3.18.0 hooks: - id: pyupgrade @@ -31,7 +31,7 @@ repos: - id: ruff-format - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror - rev: 1.18.3 + rev: 1.18.4 hooks: - id: basedpyright diff --git a/README.md b/README.md index fc773dc..5c40ff5 100644 --- a/README.md +++ b/README.md @@ -139,40 +139,41 @@ dataset: frame_decoder: ${frame_decoder} table: - path: data/NuScenes-v1.0-mini-scene-0103.mcap builder: _target_: rbyte.io.table.TableBuilder _convert_: all - reader: - _target_: rbyte.io.table.mcap.McapTableReader - _recursive_: false - decoder_factories: - - mcap_protobuf.decoder.DecoderFactory - - rbyte.utils.mcap.McapJsonDecoderFactory - fields: - /CAM_FRONT/image_rect_compressed: - _idx_: - log_time: - _target_: polars.Datetime - time_unit: ns - - /CAM_FRONT_LEFT/image_rect_compressed: - _idx_: - log_time: - _target_: polars.Datetime - time_unit: ns - - /CAM_FRONT_RIGHT/image_rect_compressed: - _idx_: - log_time: - _target_: polars.Datetime - time_unit: ns - - /odom: - log_time: - _target_: polars.Datetime - time_unit: ns - vel.x: null + readers: + - path: data/NuScenes-v1.0-mini-scene-0103.mcap + reader: + _target_: rbyte.io.table.mcap.McapTableReader + _recursive_: false + decoder_factories: + - mcap_protobuf.decoder.DecoderFactory + - rbyte.utils.mcap.McapJsonDecoderFactory + fields: + /CAM_FRONT/image_rect_compressed: + _idx_: + log_time: + _target_: polars.Datetime + time_unit: ns + + /CAM_FRONT_LEFT/image_rect_compressed: + _idx_: + log_time: + _target_: polars.Datetime + time_unit: ns + + /CAM_FRONT_RIGHT/image_rect_compressed: + _idx_: + log_time: + _target_: polars.Datetime + time_unit: ns + + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: null merger: _target_: rbyte.io.table.TableAligner diff --git a/examples/config_templates/build_table.yaml b/examples/config_templates/build_table.yaml index 70aaac3..755726d 100644 --- a/examples/config_templates/build_table.yaml +++ b/examples/config_templates/build_table.yaml @@ -1,11 +1,9 @@ --- defaults: - table_builder: !!null - - table_writer: !!null + - table_writer: console - _self_ -path: ??? - hydra: output_subdir: !!null run: diff --git a/examples/config_templates/dataset/carla.yaml b/examples/config_templates/dataset/carla.yaml index e40d3e1..4f308d5 100644 --- a/examples/config_templates/dataset/carla.yaml +++ b/examples/config_templates/dataset/carla.yaml @@ -62,7 +62,7 @@ inputs: frame: #@ for source_id in cameras: (@=source_id@): - index_column: frame_idx + index_column: _idx_ reader: _target_: rbyte.io.frame.DirectoryFrameReader path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).defish.mp4/576x324/{:09d}.jpg" @@ -76,28 +76,40 @@ inputs: #@ end table: - path: ${data_dir}/(@=input_id@)/ego_logs.json builder: - _target_: rbyte.io.table.carla.CarlaRecordsTableBuilder + _target_: rbyte.io.table.TableBuilder _convert_: all - index_column: frame_idx - select: - - control.brake - - control.throttle - - control.steer - - state.velocity.value - - state.acceleration.value + readers: + - path: ${data_dir}/(@=input_id@)/ego_logs.json + reader: + _target_: rbyte.io.table.JsonTableReader + _recursive_: false + fields: + records: + _idx_: + control.brake: + control.throttle: + control.steer: + state.velocity.value: + state.acceleration.value: + + transforms: + - _target_: rbyte.io.table.transforms.FpsResampler + source_fps: 20 + target_fps: 30 + + merger: + _target_: rbyte.io.table.TableConcater + method: vertical + + filter: | + `control.throttle` > 0.5 - filter: !!null - transforms: - - _target_: rbyte.io.table.transforms.FpsResampler - source_fps: 20 - target_fps: 30 #@ end sample_builder: _target_: rbyte.sample.builder.GreedySampleTableBuilder - index_column: frame_idx + index_column: _idx_ length: 1 stride: 1 min_step: 1 diff --git a/examples/config_templates/dataset/hdf5.yaml b/examples/config_templates/dataset/hdf5.yaml index 64054f6..f15e1c1 100644 --- a/examples/config_templates/dataset/hdf5.yaml +++ b/examples/config_templates/dataset/hdf5.yaml @@ -36,22 +36,23 @@ inputs: #@ end table: - path: "${data_dir}/(@=input_id@).hdf5" builder: _target_: rbyte.io.table.TableBuilder _convert_: all - reader: - _target_: rbyte.io.table.hdf5.Hdf5TableReader - _recursive_: false - fields: - /data/demo_0: - _idx_: - obs/object: - task_successes: + readers: + - path: "${data_dir}/(@=input_id@).hdf5" + reader: + _target_: rbyte.io.table.hdf5.Hdf5TableReader + _recursive_: false + fields: + (@=input_key@): + _idx_: + obs/object: + task_successes: merger: _target_: rbyte.io.table.TableConcater - separator: "/" + method: vertical #@ end #@ end diff --git a/examples/config_templates/dataset/mcap.yaml b/examples/config_templates/dataset/mcap.yaml index fe798b0..4e94a0b 100644 --- a/examples/config_templates/dataset/mcap.yaml +++ b/examples/config_templates/dataset/mcap.yaml @@ -36,32 +36,33 @@ inputs: #@ end table: - path: "${data_dir}/(@=input_id@).mcap" builder: _target_: rbyte.io.table.TableBuilder _convert_: all - reader: - _target_: rbyte.io.table.mcap.McapTableReader - _recursive_: false - decoder_factories: - - mcap_protobuf.decoder.DecoderFactory - - rbyte.utils.mcap.McapJsonDecoderFactory + readers: + - path: "${data_dir}/(@=input_id@).mcap" + reader: + _target_: rbyte.io.table.mcap.McapTableReader + _recursive_: false + decoder_factories: + - rbyte.utils.mcap.ProtobufDecoderFactory + - rbyte.utils.mcap.JsonDecoderFactory - fields: - #@ for topic in camera_topics: - (@=topic@): - log_time: - _target_: polars.Datetime - time_unit: ns + fields: + #@ for topic in camera_topics: + (@=topic@): + log_time: + _target_: polars.Datetime + time_unit: ns - _idx_: - #@ end + _idx_: + #@ end - /odom: - log_time: - _target_: polars.Datetime - time_unit: ns - vel.x: + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: merger: _target_: rbyte.io.table.TableAligner diff --git a/examples/config_templates/dataset/yaak.yaml b/examples/config_templates/dataset/yaak.yaml index e81168f..e8f3427 100644 --- a/examples/config_templates/dataset/yaak.yaml +++ b/examples/config_templates/dataset/yaak.yaml @@ -1,7 +1,7 @@ #@yaml/text-templated-strings #@ drives = [ -#@ 'Niro098-HQ/2024-08-26--06-06-03', +#@ 'Niro098-HQ/2024-01-22--09-03-16', #@ ] #@ cameras = [ @@ -21,47 +21,67 @@ inputs: (@=source_id@): index_column: "ImageMetadata.(@=source_id@).frame_idx" reader: - _target_: rbyte.io.frame.VideoFrameReader - path: "${data_dir}/(@=input_id@)/(@=source_id@).defish.mp4" - resize_shorter_side: 324 + _target_: rbyte.io.frame.DirectoryFrameReader + _recursive_: true + path: "${data_dir}/(@=input_id@)/frames/(@=source_id@).pii.mp4/576x324/{:09d}.jpg" + frame_decoder: + _target_: simplejpeg.decode_jpeg + _partial_: true + colorspace: rgb + fastdct: true + fastupsample: true #@ end table: - path: ${data_dir}/(@=input_id@)/metadata.log builder: _target_: rbyte.io.table.TableBuilder - reader: - _target_: rbyte.io.table.yaak.YaakMetadataTableReader - _recursive_: false - _convert_: all - fields: - rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: - time_stamp: - _target_: polars.Datetime - time_unit: ns - - frame_idx: polars.UInt32 - camera_name: - _target_: polars.Enum - categories: - - cam_front_center - - cam_front_left - - cam_front_right - - cam_left_forward - - cam_right_forward - - cam_left_backward - - cam_right_backward - - cam_rear - - rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: - time_stamp: - _target_: polars.Datetime - time_unit: ns + _convert_: all + readers: + - path: ${data_dir}/(@=input_id@)/metadata.log + reader: + _target_: rbyte.io.table.yaak.YaakMetadataTableReader + _recursive_: false + fields: + rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns - speed: polars.Float32 - gear: - _target_: polars.Enum - categories: ["0", "1", "2", "3"] + frame_idx: polars.UInt32 + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear + + rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns + + speed: polars.Float32 + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + - path: ${data_dir}/(@=input_id@)/ai.mcap + reader: + _target_: rbyte.io.table.mcap.McapTableReader + _recursive_: false + decoder_factories: [rbyte.utils.mcap.ProtobufDecoderFactory] + fields: + /ai/safety_score: + clip.end_timestamp: + _target_: polars.Datetime + time_unit: ns + + score: polars.Float32 merger: _target_: rbyte.io.table.TableAligner @@ -91,6 +111,15 @@ inputs: method: asof tolerance: 100ms + /ai/safety_score: + clip.end_timestamp: + method: ref + + score: + method: asof + tolerance: 100ms + strategy: nearest + filter: | `VehicleMotion.gear` == '3' @@ -104,7 +133,7 @@ sample_builder: _target_: rbyte.sample.builder.GreedySampleTableBuilder index_column: ImageMetadata.(@=cameras[0]@).frame_idx length: 6 - stride: 1 + stride: 10 min_step: 6 filter: | - array_lower(`VehicleMotion.speed`) > 80 + array_lower(`VehicleMotion.speed`) > 50 diff --git a/examples/config_templates/logger/rerun/carla.yaml b/examples/config_templates/logger/rerun/carla.yaml index d02d6b9..1204650 100644 --- a/examples/config_templates/logger/rerun/carla.yaml +++ b/examples/config_templates/logger/rerun/carla.yaml @@ -17,7 +17,7 @@ schema: #@ end table: - frame_idx: TimeSequenceColumn + _idx_: TimeSequenceColumn control.brake: Scalar control.steer: Scalar control.throttle: Scalar diff --git a/examples/config_templates/logger/rerun/yaak.yaml b/examples/config_templates/logger/rerun/yaak.yaml index 4bba368..28eb05b 100644 --- a/examples/config_templates/logger/rerun/yaak.yaml +++ b/examples/config_templates/logger/rerun/yaak.yaml @@ -23,3 +23,6 @@ schema: #@ end VehicleMotion.time_stamp: TimeNanosColumn VehicleMotion.speed: Scalar + + /ai/safety_score.clip.end_timestamp: TimeNanosColumn + /ai/safety_score.score: Scalar diff --git a/examples/config_templates/table_builder/carla.yaml b/examples/config_templates/table_builder/carla.yaml index 6ee4a92..9184245 100644 --- a/examples/config_templates/table_builder/carla.yaml +++ b/examples/config_templates/table_builder/carla.yaml @@ -1,18 +1,28 @@ --- -_target_: rbyte.io.table.carla.CarlaRecordsTableBuilder +_target_: rbyte.io.table.TableBuilder _convert_: all -index_column: frame_idx -select: - - control.brake - - control.throttle - - control.steer - - state.velocity.value - - state.acceleration.value +readers: + - path: ??? + reader: + _target_: rbyte.io.table.JsonTableReader + _recursive_: false + fields: + records: + _idx_: + control.brake: + control.throttle: + control.steer: + state.velocity.value: + state.acceleration.value: + + transforms: + - _target_: rbyte.io.table.transforms.FpsResampler + source_fps: 20 + target_fps: 30 + +merger: + _target_: rbyte.io.table.TableConcater + method: vertical filter: | `control.throttle` > 0.5 - -transforms: - - _target_: rbyte.io.table.transforms.FpsResampler - source_fps: 20 - target_fps: 30 diff --git a/examples/config_templates/table_builder/hdf5.yaml b/examples/config_templates/table_builder/hdf5.yaml index 17ad2e5..4a9562e 100644 --- a/examples/config_templates/table_builder/hdf5.yaml +++ b/examples/config_templates/table_builder/hdf5.yaml @@ -1,23 +1,25 @@ --- _target_: rbyte.io.table.TableBuilder _convert_: all -reader: - _target_: rbyte.io.table.hdf5.Hdf5TableReader - _recursive_: false - fields: - /data/demo_0: - _idx_: - actions: - dones: - obs/gt_nav: - obs/object: - obs/proprio: - obs/proprio_nav: - obs/scan: - rewards: - states: - task_successes: +readers: + - path: ??? + reader: + _target_: rbyte.io.table.hdf5.Hdf5TableReader + _recursive_: false + fields: + /data/demo_0: + _idx_: + actions: + dones: + obs/gt_nav: + obs/object: + obs/proprio: + obs/proprio_nav: + obs/scan: + rewards: + states: + task_successes: merger: _target_: rbyte.io.table.TableConcater - separator: "/" + method: vertical diff --git a/examples/config_templates/table_builder/mcap.yaml b/examples/config_templates/table_builder/mcap.yaml index 32c5cc3..bb2343c 100644 --- a/examples/config_templates/table_builder/mcap.yaml +++ b/examples/config_templates/table_builder/mcap.yaml @@ -7,28 +7,31 @@ --- _target_: rbyte.io.table.TableBuilder _convert_: all -reader: - _target_: rbyte.io.table.mcap.McapTableReader - _recursive_: false - decoder_factories: - - mcap_protobuf.decoder.DecoderFactory - - rbyte.utils.mcap.McapJsonDecoderFactory - - fields: - #@ for topic in camera_topics: - (@=topic@): - log_time: - _target_: polars.Datetime - time_unit: ns +readers: + - path: ??? + reader: + _target_: rbyte.io.table.mcap.McapTableReader + _recursive_: false + decoder_factories: + - rbyte.utils.mcap.ProtobufDecoderFactory + - rbyte.utils.mcap.JsonDecoderFactory + - mcap_ros2.decoder.DecoderFactory - _idx_: - #@ end + fields: + #@ for topic in camera_topics: + (@=topic@): + log_time: + _target_: polars.Datetime + time_unit: ns - /odom: - log_time: - _target_: polars.Datetime - time_unit: ns - vel.x: + _idx_: + #@ end + + /odom: + log_time: + _target_: polars.Datetime + time_unit: ns + vel.x: merger: _target_: rbyte.io.table.TableAligner diff --git a/examples/config_templates/table_builder/yaak.yaml b/examples/config_templates/table_builder/yaak.yaml index 4e8c4ca..c08e8a7 100644 --- a/examples/config_templates/table_builder/yaak.yaml +++ b/examples/config_templates/table_builder/yaak.yaml @@ -7,38 +7,53 @@ #@ ] --- _target_: rbyte.io.table.TableBuilder -reader: - _target_: rbyte.io.table.yaak.YaakMetadataTableReader - _recursive_: false - _convert_: all - fields: - rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: - time_stamp: - _target_: polars.Datetime - time_unit: ns +_convert_: all +readers: + - path: ??? + reader: + _target_: rbyte.io.table.yaak.YaakMetadataTableReader + _recursive_: false + fields: + rbyte.io.table.yaak.proto.sensor_pb2.ImageMetadata: + time_stamp: + _target_: polars.Datetime + time_unit: ns - frame_idx: polars.UInt32 - camera_name: - _target_: polars.Enum - categories: - - cam_front_center - - cam_front_left - - cam_front_right - - cam_left_forward - - cam_right_forward - - cam_left_backward - - cam_right_backward - - cam_rear + frame_idx: polars.UInt32 + camera_name: + _target_: polars.Enum + categories: + - cam_front_center + - cam_front_left + - cam_front_right + - cam_left_forward + - cam_right_forward + - cam_left_backward + - cam_right_backward + - cam_rear - rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: - time_stamp: - _target_: polars.Datetime - time_unit: ns + rbyte.io.table.yaak.proto.can_pb2.VehicleMotion: + time_stamp: + _target_: polars.Datetime + time_unit: ns - speed: polars.Float32 - gear: - _target_: polars.Enum - categories: ["0", "1", "2", "3"] + speed: polars.Float32 + gear: + _target_: polars.Enum + categories: ["0", "1", "2", "3"] + + - path: ??? + reader: + _target_: rbyte.io.table.mcap.McapTableReader + _recursive_: false + decoder_factories: [rbyte.utils.mcap.ProtobufDecoderFactory] + fields: + /ai/safety_score: + clip.end_timestamp: + _target_: polars.Datetime + time_unit: ns + + score: polars.Float32 merger: _target_: rbyte.io.table.TableAligner @@ -68,6 +83,15 @@ merger: method: asof tolerance: 100ms + /ai/safety_score: + clip.end_timestamp: + method: ref + + score: + method: asof + tolerance: 100ms + strategy: nearest + filter: | `VehicleMotion.gear` == '3' diff --git a/examples/config_templates/table_writer/console.yaml b/examples/config_templates/table_writer/console.yaml new file mode 100644 index 0000000..fba8b54 --- /dev/null +++ b/examples/config_templates/table_writer/console.yaml @@ -0,0 +1,4 @@ +--- +_target_: polars.DataFrame.glimpse +_partial_: true +max_items_per_column: 3 diff --git a/pyproject.toml b/pyproject.toml index 3054e04..21c6c36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rbyte" -version = "0.4.0" +version = "0.5.0" description = "Multimodal dataset library" authors = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] maintainers = [{ name = "Evgenii Gorchakov", email = "evgenii@yaak.ai" }] @@ -37,14 +37,17 @@ classifiers = [ build = ["hatchling>=1.25.0", "grpcio-tools>=1.62.0", "protoletariat==3.2.19"] visualize = ["rerun-sdk>=0.18.2"] mcap = [ - "mcap>=1.1.1", - "mcap-protobuf-support>=0.5.1", - "mcap-ros2-support>=0.5.3", - "python-box>=7.2.0", + "mcap>=1.2.1", + "mcap-ros2-support>=0.5.5", + "protobuf", + "mcap-protobuf-support>=0.5.2", ] -yaak = ["protobuf", "ptars>=0.0.2"] +yaak = ["protobuf", "ptars>=0.0.3"] jpeg = ["simplejpeg>=1.7.6"] -video = ["python-vali>=4.2.0.post0", "video-reader-rs>=0.1.5"] +video = [ + "python-vali>=4.2.0.post0; sys_platform == 'linux'", + "video-reader-rs>=0.1.5", +] hdf5 = ["h5py>=3.12.1"] [project.scripts] diff --git a/src/rbyte/config/base.py b/src/rbyte/config/base.py index 3b5f15a..30e21ef 100644 --- a/src/rbyte/config/base.py +++ b/src/rbyte/config/base.py @@ -3,7 +3,7 @@ from hydra.utils import instantiate from pydantic import BaseModel as _BaseModel -from pydantic import ConfigDict, Field, ImportString, field_serializer +from pydantic import ConfigDict, Field, ImportString, field_serializer, model_validator class BaseModel(_BaseModel): @@ -35,3 +35,13 @@ def instantiate(self, **kwargs: object) -> T: @staticmethod def serialize_target(v: object) -> str: return ImportString._serialize(v) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue] # noqa: SLF001 + + @model_validator(mode="before") + @classmethod + def validate_model(cls, data: object) -> object: + match data: + case str(): + return {"_target_": data} + + case _: + return data diff --git a/src/rbyte/dataset.py b/src/rbyte/dataset.py index 42b3bdf..8a4cd0b 100644 --- a/src/rbyte/dataset.py +++ b/src/rbyte/dataset.py @@ -6,7 +6,7 @@ import more_itertools as mit import polars as pl import torch -from pydantic import ConfigDict, Field, FilePath, StringConstraints, validate_call +from pydantic import ConfigDict, Field, StringConstraints, validate_call from structlog import get_logger from structlog.contextvars import bound_contextvars from tensordict import TensorDict @@ -33,7 +33,6 @@ class FrameSourceConfig(BaseModel): class TableSourceConfig(BaseModel): - path: FilePath builder: HydraConfig[TableBuilderBase] @@ -140,10 +139,10 @@ def _build_table(cls, sources: SourcesConfig) -> pl.LazyFrame: return pl.LazyFrame(frame_idxs) case SourcesConfig( - frame=frame_sources, table=TableSourceConfig(path=path, builder=builder) + frame=frame_sources, table=TableSourceConfig(builder=builder) ): table_builder = builder.instantiate() - table = table_builder.build(path).lazy() + table = table_builder.build().lazy() schema = table.collect_schema() for frame_source_id, frame_source in frame_sources.items(): @@ -162,6 +161,8 @@ def _build_table(cls, sources: SourcesConfig) -> pl.LazyFrame: return table case _: + logger.error("not implemented") + raise NotImplementedError @property diff --git a/src/rbyte/io/table/__init__.py b/src/rbyte/io/table/__init__.py index e5d6276..125536d 100644 --- a/src/rbyte/io/table/__init__.py +++ b/src/rbyte/io/table/__init__.py @@ -1,5 +1,27 @@ from .aligner import TableAligner from .builder import TableBuilder from .concater import TableConcater +from .json import JsonTableReader -__all__ = ["TableAligner", "TableBuilder", "TableConcater"] +__all__ = ["JsonTableReader", "TableAligner", "TableBuilder", "TableConcater"] + +try: + from .hdf5 import Hdf5TableReader +except ImportError: + pass +else: + __all__ += ["Hdf5TableReader"] + +try: + from .mcap import McapTableReader +except ImportError: + pass +else: + __all__ += ["McapTableReader"] + +try: + from .yaak import YaakMetadataTableReader +except ImportError: + pass +else: + __all__ += ["YaakMetadataTableReader"] diff --git a/src/rbyte/io/table/aligner.py b/src/rbyte/io/table/aligner.py index 3b1eb79..0144c0d 100644 --- a/src/rbyte/io/table/aligner.py +++ b/src/rbyte/io/table/aligner.py @@ -81,6 +81,9 @@ def _col_name(self, *args: str) -> str: @override def merge(self, src: Mapping[str, pl.DataFrame]) -> pl.DataFrame: + if unused_keys := src.keys() - self._config.merge.keys(): + logger.warning("unused", keys=sorted(unused_keys)) + dfs = { k: src[k] .sort(self._config.ref_columns[k]) diff --git a/src/rbyte/io/table/base.py b/src/rbyte/io/table/base.py index 7680822..0b1cfba 100644 --- a/src/rbyte/io/table/base.py +++ b/src/rbyte/io/table/base.py @@ -4,12 +4,12 @@ import polars as pl -type Table = pl.DataFrame +Table = pl.DataFrame @runtime_checkable class TableBuilderBase(Protocol): - def build(self, path: PathLike[str]) -> Table: ... + def build(self) -> Table: ... @runtime_checkable diff --git a/src/rbyte/io/table/builder.py b/src/rbyte/io/table/builder.py index 21da615..da72343 100644 --- a/src/rbyte/io/table/builder.py +++ b/src/rbyte/io/table/builder.py @@ -1,15 +1,18 @@ -from collections.abc import Hashable +import operator +from collections import Counter +from collections.abc import Hashable, Sequence +from functools import reduce from mmap import ACCESS_READ, mmap -from os import PathLike from pathlib import Path -from typing import Annotated, override +from typing import Annotated, Any, override +import more_itertools as mit import polars as pl -from pydantic import ConfigDict, StringConstraints, validate_call +from pydantic import ConfigDict, Field, FilePath, StringConstraints, validate_call from structlog import get_logger -from structlog.contextvars import bound_contextvars from xxhash import xxh3_64_intdigest as digest +from rbyte.config.base import BaseModel from rbyte.io.table.base import ( TableBuilderBase, TableCacheBase, @@ -20,59 +23,80 @@ logger = get_logger(__name__) +class TableReaderConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + path: FilePath + reader: TableReaderBase + + class TableBuilder(TableBuilderBase): @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, - reader: TableReaderBase, + readers: Annotated[Sequence[TableReaderConfig], Field(min_length=1)], merger: TableMergerBase, filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None, # noqa: A002 cache: TableCacheBase | None = None, ) -> None: super().__init__() - self._reader = reader + self._readers = readers self._merger = merger self._filter = filter self._cache = cache - def _build_cache_key(self, path: PathLike[str]) -> Hashable: + def _build_cache_key(self) -> Hashable: from rbyte import __version__ as rbyte_version # noqa: PLC0415 - key = [rbyte_version, hash(self._reader), hash(self._merger)] + key: list[Any] = [rbyte_version, hash(self._merger)] if self._filter is not None: key.append(digest(self._filter)) - with Path(path).open("rb") as _f, mmap(_f.fileno(), 0, access=ACCESS_READ) as f: - key.append(digest(f)) # pyright: ignore[reportArgumentType] + for reader_config in self._readers: + with ( + Path(reader_config.path).open("rb") as _f, + mmap(_f.fileno(), 0, access=ACCESS_READ) as f, + ): + file_hash = digest(f) # pyright: ignore[reportArgumentType] + + key.append((file_hash, hash(reader_config.reader))) return tuple(key) @override - def build(self, path: PathLike[str]) -> pl.DataFrame: - with bound_contextvars(path=str(path)): - match self._cache: - case TableCacheBase(): - key = self._build_cache_key(path) - if key in self._cache: - logger.debug("reading table from cache") - df = self._cache.get(key) - if df is None: - raise RuntimeError - - return df - - df = self._build(path) - if not self._cache.set(key, df): - logger.warning("failed to cache table") + def build(self) -> pl.DataFrame: + match self._cache: + case TableCacheBase(): + key = self._build_cache_key() + if key in self._cache: + logger.debug("reading table from cache") + df = self._cache.get(key) + if df is None: + raise RuntimeError return df - case None: - return self._build(path) + df = self._build() + if not self._cache.set(key, df): + logger.warning("failed to cache table") + + return df - def _build(self, path: PathLike[str]) -> pl.DataFrame: - dfs = self._reader.read(path) + case None: + return self._build() + + def _build(self) -> pl.DataFrame: + reader_dfs = [cfg.reader.read(cfg.path) for cfg in self._readers] + if duplicate_keys := { + k for k, count in Counter(mit.flatten(reader_dfs)).items() if count > 1 + }: + logger.error(msg := "readers produced duplicate keys", keys=duplicate_keys) + + raise RuntimeError(msg) + + dfs = reduce(operator.or_, reader_dfs) df = self._merger.merge(dfs) + return df.sql(f"select * from self where ({self._filter or True})") # noqa: S608 diff --git a/src/rbyte/io/table/carla/__init__.py b/src/rbyte/io/table/carla/__init__.py deleted file mode 100644 index bb42088..0000000 --- a/src/rbyte/io/table/carla/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .builder import CarlaRecordsTableBuilder - -__all__ = ["CarlaRecordsTableBuilder"] diff --git a/src/rbyte/io/table/carla/builder.py b/src/rbyte/io/table/carla/builder.py deleted file mode 100644 index 0b85fd9..0000000 --- a/src/rbyte/io/table/carla/builder.py +++ /dev/null @@ -1,54 +0,0 @@ -from collections.abc import Sequence -from os import PathLike -from pathlib import Path -from typing import Annotated, override - -import polars as pl -from polars import selectors as cs -from pydantic import ConfigDict, StringConstraints, validate_call - -from rbyte.io.table.base import TableBuilderBase -from rbyte.io.table.transforms.base import TableTransform -from rbyte.utils.dataframe.misc import unnest_all - - -class CarlaRecordsTableBuilder(TableBuilderBase): - RECORD_KEY = "records" - - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) - def __init__( - self, - *, - index_column: Annotated[ - str, StringConstraints(strip_whitespace=True) - ] = "frame_idx", - select: str | frozenset[str] = "*", - filter: Annotated[str, StringConstraints(strip_whitespace=True)] | None = None, # noqa: A002 - transforms: Sequence[TableTransform] = (), - ) -> None: - super().__init__() - - self._index_column = index_column - self._select = select - self._filter = filter - self._transforms = transforms - - @override - def build(self, path: PathLike[str]) -> pl.DataFrame: - df = pl.read_json(Path(path)).explode(self.RECORD_KEY).unnest(self.RECORD_KEY) - df = ( - df.select(unnest_all(df.collect_schema())) - .select(self._select) - # 32 bits ought to be enough for anybody - .cast({ - cs.by_dtype(pl.Int64): pl.Int32, - cs.by_dtype(pl.UInt64): pl.UInt32, - cs.by_dtype(pl.Float64): pl.Float32, - }) - .sql(f"select * from self where ({self._filter or True})") # noqa: S608 - ) - - for transform in self._transforms: - df = transform(df) - - return df.with_row_index(self._index_column) diff --git a/src/rbyte/io/table/concater.py b/src/rbyte/io/table/concater.py index 62c2ca2..c8bab57 100644 --- a/src/rbyte/io/table/concater.py +++ b/src/rbyte/io/table/concater.py @@ -1,10 +1,9 @@ import json from collections.abc import Hashable, Mapping -from typing import Annotated, override +from typing import override import polars as pl from polars._typing import ConcatMethod -from pydantic import StringConstraints from xxhash import xxh3_64_intdigest as digest from rbyte.config import BaseModel @@ -12,7 +11,7 @@ class Config(BaseModel): - separator: Annotated[str, StringConstraints(strip_whitespace=True)] = "/" + separator: str | None = None method: ConcatMethod = "horizontal" @@ -22,7 +21,13 @@ def __init__(self, **kwargs: object) -> None: @override def merge(self, src: Mapping[str, pl.DataFrame]) -> pl.DataFrame: - return pl.concat(src.values(), how=self._config.method) + if (separator := self._config.separator) is not None: + src = { + k: df.select(pl.all().name.prefix(f"{k}{separator}")) + for k, df in src.items() + } + + return pl.concat(src.values(), how=self._config.method, rechunk=True) @override def __hash__(self) -> int: diff --git a/src/rbyte/io/table/hdf5/reader.py b/src/rbyte/io/table/hdf5/reader.py index bf01ec0..7772634 100644 --- a/src/rbyte/io/table/hdf5/reader.py +++ b/src/rbyte/io/table/hdf5/reader.py @@ -8,12 +8,13 @@ import numpy.typing as npt import polars as pl from h5py import Dataset, File, Group +from optree import tree_map from polars._typing import PolarsDataType from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 ) -from pydantic import ConfigDict, ImportString +from pydantic import ConfigDict from xxhash import xxh3_64_intdigest as digest from rbyte.config import BaseModel @@ -24,14 +25,11 @@ class Config(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - fields: Mapping[ - str, - Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], - ] + fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] @unique -class SpecialFields(StrEnum): +class SpecialField(StrEnum): idx = "_idx_" @@ -52,13 +50,13 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: dfs: Mapping[str, pl.DataFrame] = {} with File(path) as f: - for group_key, schema in self.schemas.items(): + for group_key, schema in self._fields.items(): match group := f[group_key]: case Group(): series: list[pl.Series] = [] for name, dtype in schema.items(): match name: - case SpecialFields.idx: + case SpecialField.idx: pass case _: @@ -77,12 +75,12 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: raise NotImplementedError df = pl.DataFrame(data=series) # pyright: ignore[reportGeneralTypeIssues] - if (idx_name := SpecialFields.idx) in schema: + if (idx_name := SpecialField.idx) in schema: df = df.with_row_index(idx_name).cast({ idx_name: schema[idx_name] or pl.UInt32 }) - dfs[group_key] = df + dfs[group_key] = df.rechunk() case _: raise NotImplementedError @@ -90,13 +88,5 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: return dfs @cached_property - def schemas(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: - return { - group_key: { - dataset_key: leaf.instantiate() - if isinstance(leaf, HydraConfig) - else leaf - for dataset_key, leaf in fields.items() - } - for group_key, fields in self._config.fields.items() - } + def _fields(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: + return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType, reportReturnType] diff --git a/src/rbyte/io/table/json/__init__.py b/src/rbyte/io/table/json/__init__.py new file mode 100644 index 0000000..8296e67 --- /dev/null +++ b/src/rbyte/io/table/json/__init__.py @@ -0,0 +1,3 @@ +from .reader import JsonTableReader + +__all__ = ["JsonTableReader"] diff --git a/src/rbyte/io/table/json/reader.py b/src/rbyte/io/table/json/reader.py new file mode 100644 index 0000000..4d23573 --- /dev/null +++ b/src/rbyte/io/table/json/reader.py @@ -0,0 +1,87 @@ +import json +from collections.abc import Hashable, Mapping, Sequence +from enum import StrEnum, unique +from functools import cached_property +from os import PathLike +from pathlib import Path +from typing import override + +import polars as pl +from optree import tree_map +from polars._typing import PolarsDataType +from polars.datatypes import ( + DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 + DataTypeClass, # pyright: ignore[reportUnusedImport] # noqa: F401 +) +from pydantic import ConfigDict, Field +from xxhash import xxh3_64_intdigest as digest + +from rbyte.config.base import BaseModel, HydraConfig +from rbyte.io.table.base import TableReaderBase +from rbyte.io.table.transforms.base import TableTransform +from rbyte.utils.dataframe.misc import unnest_all + + +@unique +class SpecialField(StrEnum): + idx = "_idx_" + + +class Config(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] + transforms: Sequence[HydraConfig[TableTransform]] = Field(default=()) + + +class JsonTableReader(TableReaderBase, Hashable): + def __init__(self, **kwargs: object) -> None: + self._config = Config.model_validate(kwargs) + + @override + def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: + dfs: Mapping[str, pl.DataFrame] = {} + + for k, series in ( + pl.read_json(Path(path)).select(self._fields).to_dict().items() + ): + df_schema = { + name: dtype + for name, dtype in self._fields[k].items() + if dtype is not None + } + df = pl.DataFrame(series).lazy().explode(k).unnest(k) + df = ( + df.select(unnest_all(df.collect_schema())) + .select(self._config.fields[k].keys() - set(SpecialField)) + .cast(df_schema) # pyright: ignore[reportArgumentType] + .collect() + ) + + for transform in self._transforms: + df = transform(df) + + if (idx_name := SpecialField.idx) in (df_schema := self._fields[k]): + df = df.with_row_index(idx_name).cast({ + idx_name: df_schema[idx_name] or pl.UInt32 + }) + + dfs[k] = df + + return dfs + + @override + def __hash__(self) -> int: + config = self._config.model_dump_json() + # roundtripping json to work around https://github.com/pydantic/pydantic/issues/7424 + config_str = json.dumps(json.loads(config), sort_keys=True) + + return digest(config_str) + + @cached_property + def _fields(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: + return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType, reportReturnType] + + @cached_property + def _transforms(self) -> Sequence[TableTransform]: + return tuple(transform.instantiate() for transform in self._config.transforms) diff --git a/src/rbyte/io/table/mcap/reader.py b/src/rbyte/io/table/mcap/reader.py index 7d2acfc..89eb9ce 100644 --- a/src/rbyte/io/table/mcap/reader.py +++ b/src/rbyte/io/table/mcap/reader.py @@ -1,4 +1,5 @@ import json +from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, Sequence from enum import StrEnum, unique from functools import cached_property @@ -11,7 +12,8 @@ import more_itertools as mit import polars as pl from mcap.decoder import DecoderFactory -from mcap.reader import DecodedMessageTuple, SeekingReader +from mcap.reader import SeekingReader +from optree import tree_map from polars._typing import PolarsDataType from polars.datatypes import ( DataType, # pyright: ignore[reportUnusedImport] # noqa: F401 @@ -31,6 +33,7 @@ from rbyte.config.base import BaseModel, HydraConfig from rbyte.io.table.base import TableReaderBase +from rbyte.utils.dataframe import unnest_all logger = get_logger(__name__) @@ -38,13 +41,8 @@ class Config(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - decoder_factories: Sequence[ImportString[type[DecoderFactory]]] - - fields: Mapping[ - str, - Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], - ] - + decoder_factories: frozenset[ImportString[type[DecoderFactory]]] + fields: Mapping[str, Mapping[str, HydraConfig[PolarsDataType] | None]] validate_crcs: bool = False @field_serializer("decoder_factories", when_used="json", mode="wrap") @@ -63,7 +61,7 @@ class RowValues(NamedTuple): @unique -class SpecialFields(StrEnum): +class SpecialField(StrEnum): log_time = "log_time" publish_time = "publish_time" idx = "_idx_" @@ -90,7 +88,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: logger.error(msg := "missing summary") raise ValueError(msg) - topics = self.schemas.keys() + topics = self._fields.keys() if missing_topics := topics - ( available_topics := {ch.topic for ch in summary.channels.values()} ): @@ -111,35 +109,46 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: else None ) - row_values = ( - RowValues( - dmt.channel.topic, - self._get_values(dmt, self.schemas[dmt.channel.topic]), - ) - for dmt in tqdm( - reader.iter_decoded_messages(topics), - desc="messages", - total=message_count, + rows: Mapping[str, list[pl.DataFrame]] = defaultdict(list) + + for dmt in tqdm( + reader.iter_decoded_messages(topics), + desc="messages", + total=message_count, + ): + schema = self._fields[dmt.channel.topic] + message_fields, special_fields = map( + dict, + mit.partition(lambda kv: kv[0] in SpecialField, schema.items()), ) - ) - row_values_by_topic = mit.bucket(row_values, key=lambda rd: rd.topic) + special_fields = { + k: v for k, v in special_fields.items() if k != SpecialField.idx + } - dfs: Mapping[str, pl.DataFrame] = {} - for topic, schema in self.schemas.items(): - df_schema = {k: v for k, v in schema.items() if k != SpecialFields.idx} - df = pl.DataFrame( - data=(tuple(x.values) for x in row_values_by_topic[topic]), - schema=df_schema, # pyright: ignore[reportArgumentType] - orient="row", + row_df = pl.DataFrame( + [getattr(dmt.message, field) for field in special_fields], + schema=special_fields, # pyright: ignore[reportArgumentType] ) - if (idx_name := SpecialFields.idx) in schema: - df = df.with_row_index(idx_name).cast({ - idx_name: schema[idx_name] or pl.UInt32 - }) + if ( + message_df := self._build_message_df( + dmt.decoded_message, message_fields + ) + ) is not None: + row_df = row_df.hstack(message_df) - dfs[topic] = df + rows[dmt.channel.topic].append(row_df) + + dfs: Mapping[str, pl.DataFrame] = {} + for topic, row_dfs in rows.items(): + df = pl.concat(row_dfs, how="vertical") + if (idx_name := SpecialField.idx) in (schema := self._fields[topic]): + df = df.with_row_index(idx_name).cast({ + idx_name: schema[idx_name] or pl.UInt32 + }) + + dfs[topic] = df.rechunk() return dfs @@ -152,27 +161,28 @@ def __hash__(self) -> int: return digest(config_str) @cached_property - def schemas(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: - return { - topic: { - path: leaf.instantiate() if isinstance(leaf, HydraConfig) else leaf - for path, leaf in fields.items() - } - for topic, fields in self._config.fields.items() - } + def _fields(self) -> Mapping[str, Mapping[str, PolarsDataType | None]]: + return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownArgumentType, reportUnknownMemberType, reportUnknownVariableType, reportReturnType] @staticmethod - def _get_values(dmt: DecodedMessageTuple, fields: Iterable[str]) -> Iterable[Any]: - for field in fields: - match field: - case SpecialFields.log_time: - yield dmt.message.log_time - - case SpecialFields.publish_time: - yield dmt.message.publish_time - - case SpecialFields.idx: - pass # added later - - case _: - yield attrgetter(field)(dmt.decoded_message) + def _build_message_df( + message: object, fields: Mapping[str, PolarsDataType | None] + ) -> pl.DataFrame | None: + if not fields: + return None + + df_schema = {name: dtype for name, dtype in fields.items() if dtype is not None} + + match message: + case pl.DataFrame(): + return ( + message.lazy() + .select(unnest_all(message.collect_schema())) + .select(fields) + .cast(df_schema) # pyright: ignore[reportArgumentType] + ).collect() + + case _: + return pl.from_dict({ + field: attrgetter(field)(message) for field in fields + }).cast(df_schema) # pyright: ignore[reportArgumentType] diff --git a/src/rbyte/io/table/transforms/base.py b/src/rbyte/io/table/transforms/base.py index a287bb2..51180a4 100644 --- a/src/rbyte/io/table/transforms/base.py +++ b/src/rbyte/io/table/transforms/base.py @@ -1,8 +1,8 @@ from typing import Protocol, runtime_checkable -import polars as pl +from rbyte.io.table.base import Table @runtime_checkable class TableTransform(Protocol): - def __call__(self, src: pl.DataFrame) -> pl.DataFrame: ... + def __call__(self, src: Table) -> Table: ... diff --git a/src/rbyte/io/table/transforms/fps_resampler.py b/src/rbyte/io/table/transforms/fps_resampler.py index 5e1635b..1b2c128 100644 --- a/src/rbyte/io/table/transforms/fps_resampler.py +++ b/src/rbyte/io/table/transforms/fps_resampler.py @@ -4,6 +4,8 @@ import polars as pl from pydantic import PositiveInt, validate_call +from rbyte.io.table.base import Table + from .base import TableTransform @@ -19,7 +21,7 @@ def __init__(self, source_fps: PositiveInt, target_fps: PositiveInt) -> None: self._fps_lcm = lcm(source_fps, target_fps) @override - def __call__(self, src: pl.DataFrame) -> pl.DataFrame: + def __call__(self, src: Table) -> Table: return ( src.with_row_index(self.IDX_COL) .with_columns(pl.col(self.IDX_COL) * (self._fps_lcm // self._source_fps)) diff --git a/src/rbyte/io/table/yaak/reader.py b/src/rbyte/io/table/yaak/reader.py index 0605950..dd75054 100644 --- a/src/rbyte/io/table/yaak/reader.py +++ b/src/rbyte/io/table/yaak/reader.py @@ -35,8 +35,7 @@ class Config(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) fields: Mapping[ - ImportString[type[Message]], - Mapping[str, HydraConfig[PolarsDataType] | ImportString[PolarsDataType] | None], + ImportString[type[Message]], Mapping[str, HydraConfig[PolarsDataType] | None] ] @@ -70,6 +69,7 @@ def read(self, path: PathLike[str]) -> Mapping[str, pl.DataFrame]: ]) .select(schema), schema=schema, # pyright: ignore[reportArgumentType] + rechunk=True, ), ) for msg_type, schema in self._fields.items() @@ -95,7 +95,4 @@ def __hash__(self) -> int: @cached_property def _fields(self) -> Mapping[type[Message], Mapping[str, PolarsDataType | None]]: - return tree_map( # pyright: ignore[reportUnknownVariableType, reportReturnType] - lambda x: x.instantiate() if isinstance(x, HydraConfig) else x, # pyright: ignore[reportUnknownLambdaType, reportUnknownArgumentType] - self._config.fields, # pyright: ignore[reportArgumentType] - ) + return tree_map(HydraConfig.instantiate, self._config.fields) # pyright: ignore[reportArgumentType, reportUnknownVariableType, reportReturnType, reportUnknownMemberType, reportUnknownArgumentType] diff --git a/src/rbyte/sample/builder.py b/src/rbyte/sample/builder.py index 7178479..c8a8fc0 100644 --- a/src/rbyte/sample/builder.py +++ b/src/rbyte/sample/builder.py @@ -58,5 +58,5 @@ def build(self, source: pl.LazyFrame) -> pl.LazyFrame: .sort(sample_idx_col) .select(pl.exclude(sample_idx_col)) # TODO: https://github.com/pola-rs/polars/issues/18810 # noqa: FIX002 - # .select(pl.all().list.to_array(self.length)) + # .select(pl.all().list.to_array(self._length)) ) diff --git a/src/rbyte/scripts/build_table.py b/src/rbyte/scripts/build_table.py index df8b10a..0144279 100644 --- a/src/rbyte/scripts/build_table.py +++ b/src/rbyte/scripts/build_table.py @@ -15,7 +15,7 @@ def main(config: DictConfig) -> None: table_builder = cast(TableBuilderBase, instantiate(config.table_builder)) table_writer = cast(Callable[[Table], None], instantiate(config.table_writer)) - table = table_builder.build(config.path) + table = table_builder.build() return table_writer(table) diff --git a/src/rbyte/utils/mcap/__init__.py b/src/rbyte/utils/mcap/__init__.py index 8a476ed..9d039fa 100644 --- a/src/rbyte/utils/mcap/__init__.py +++ b/src/rbyte/utils/mcap/__init__.py @@ -1,3 +1,3 @@ -from .json_decoder_factory import McapJsonDecoderFactory +from .decoders import JsonDecoderFactory, ProtobufDecoderFactory -__all__ = ["McapJsonDecoderFactory"] +__all__ = ["JsonDecoderFactory", "ProtobufDecoderFactory"] diff --git a/src/rbyte/utils/mcap/decoders/__init__.py b/src/rbyte/utils/mcap/decoders/__init__.py new file mode 100644 index 0000000..2eadf40 --- /dev/null +++ b/src/rbyte/utils/mcap/decoders/__init__.py @@ -0,0 +1,4 @@ +from .json_decoder_factory import JsonDecoderFactory +from .protobuf_decoder_factory import ProtobufDecoderFactory + +__all__ = ["JsonDecoderFactory", "ProtobufDecoderFactory"] diff --git a/src/rbyte/utils/mcap/decoders/json_decoder_factory.py b/src/rbyte/utils/mcap/decoders/json_decoder_factory.py new file mode 100644 index 0000000..f5c2d13 --- /dev/null +++ b/src/rbyte/utils/mcap/decoders/json_decoder_factory.py @@ -0,0 +1,24 @@ +from collections.abc import Callable +from typing import override + +import polars as pl +from mcap.decoder import DecoderFactory as McapDecoderFactory +from mcap.records import Schema +from structlog import get_logger + +logger = get_logger(__name__) + + +class JsonDecoderFactory(McapDecoderFactory): + @override + def decoder_for( + self, message_encoding: str, schema: Schema | None + ) -> Callable[[bytes], pl.DataFrame] | None: + if ( + message_encoding == "json" + and schema is not None + and schema.encoding == "jsonschema" + ): + return pl.read_json + + return None diff --git a/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py b/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py new file mode 100644 index 0000000..151651d --- /dev/null +++ b/src/rbyte/utils/mcap/decoders/protobuf_decoder_factory.py @@ -0,0 +1,72 @@ +from collections.abc import Callable +from operator import attrgetter +from typing import override + +import more_itertools as mit +import polars as pl +from cachetools import cached +from google.protobuf.descriptor_pb2 import FileDescriptorProto, FileDescriptorSet +from google.protobuf.descriptor_pool import DescriptorPool +from google.protobuf.message import Message +from google.protobuf.message_factory import GetMessageClassesForFiles +from mcap.decoder import DecoderFactory as McapDecoderFactory +from mcap.exceptions import McapError +from mcap.records import Schema +from mcap.well_known import MessageEncoding, SchemaEncoding +from ptars import HandlerPool +from structlog import get_logger + +logger = get_logger(__name__) + + +class ProtobufDecoderFactory(McapDecoderFactory): + def __init__(self) -> None: + self._handler_pool = HandlerPool() + + @override + def decoder_for( + self, message_encoding: str, schema: Schema | None + ) -> Callable[[bytes], pl.DataFrame] | None: + if ( + message_encoding == MessageEncoding.Protobuf + and schema is not None + and schema.encoding == SchemaEncoding.Protobuf + ): + message_type = self._get_message_type(schema) + handler = self._handler_pool.get_for_message(message_type.DESCRIPTOR) # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + + def decoder(data: bytes) -> pl.DataFrame: + record_batch = handler.list_to_record_batch([data]) # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType] + return pl.from_arrow(record_batch, rechunk=False) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportReturnType] + + return decoder + + return None + + @staticmethod + @cached(cache={}, key=attrgetter("id")) + def _get_message_type(schema: Schema) -> type[Message]: + fds = FileDescriptorSet.FromString(schema.data) + pool = DescriptorPool() + descriptor_by_name = mit.map_reduce( + fds.file, keyfunc=attrgetter("name"), reducefunc=mit.one + ) + + def _add(fd: FileDescriptorProto) -> None: + for dependency in fd.dependency: + if dependency in descriptor_by_name: + _add(descriptor_by_name.pop(dependency)) + + pool.Add(fd) # pyright: ignore[reportUnknownMemberType] + + while descriptor_by_name: + _add(descriptor_by_name.popitem()[1]) + + message_types = GetMessageClassesForFiles([fd.name for fd in fds.file], pool) + + if (message_type := message_types.get(schema.name, None)) is None: + logger.error(msg := "FileDescriptorSet missing schema", schema=schema) + + raise McapError(msg) + + return message_type diff --git a/src/rbyte/utils/mcap/json_decoder_factory.py b/src/rbyte/utils/mcap/json_decoder_factory.py deleted file mode 100644 index 51f093b..0000000 --- a/src/rbyte/utils/mcap/json_decoder_factory.py +++ /dev/null @@ -1,27 +0,0 @@ -import json -from collections.abc import Callable -from typing import override - -from box import Box -from mcap.decoder import DecoderFactory as McapDecoderFactory -from mcap.records import Schema -from structlog import get_logger - -logger = get_logger(__name__) - - -class McapJsonDecoderFactory(McapDecoderFactory): - @override - def decoder_for( - self, message_encoding: str, schema: Schema | None - ) -> Callable[[bytes], Box] | None: - match message_encoding, getattr(schema, "encoding", None): - case "json", "jsonschema": - return self._decoder - - case _: - return None - - @staticmethod - def _decoder(data: bytes) -> Box: - return Box(json.loads(data))