Skip to content

Commit

Permalink
fix(datasets): verify file exists if on Polars 1.0 (#957)
Browse files Browse the repository at this point in the history
* ci(datasets): unbound Polars for test requirements

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* test(datasets): use a more version-agnostic assert

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* revert(datasets): undo `assert_frame_equal` change

Refs: 10af4db
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* chore(datasets): use the Polars 1.0 equality check

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* chore(datasets): use calamine engine in Polars 1.0

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* revert(datasets): undo swap to the calamine engine

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* fix(datasets): raise error manually for Polars 1.0

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* ci(datasets): skip a failing doctest in Windows CI

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

* test(datasets): skip failing save tests on Windows

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>

---------

Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
  • Loading branch information
deepyaman authored Dec 6, 2024
1 parent 1d768ad commit 7baa826
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 5 deletions.
10 changes: 8 additions & 2 deletions kedro-datasets/kedro_datasets/polars/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,21 @@ class CSVDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
.. code-block:: pycon
>>> from kedro_datasets.polars import CSVDataset
>>> import sys
>>>
>>> import polars as pl
>>> import pytest
>>> from kedro_datasets.polars import CSVDataset
>>>
>>> if sys.platform.startswith("win"):
... pytest.skip("this doctest fails on Windows CI runner")
...
>>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = CSVDataset(filepath=tmp_path / "test.csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
>>> assert data.equals(reloaded)
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class EagerPolarsDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
>>> dataset = EagerPolarsDataset(filepath=tmp_path / "test.parquet", file_format="parquet")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
>>> assert data.equals(reloaded)
"""

Expand Down
6 changes: 5 additions & 1 deletion kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"""
from __future__ import annotations

import errno
import logging
import os
from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any, ClassVar
Expand Down Expand Up @@ -69,7 +71,7 @@ class LazyPolarsDataset(
>>> dataset = LazyPolarsDataset(filepath=tmp_path / "test.csv", file_format="csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded.collect())
>>> assert data.equals(reloaded.collect())
"""

Expand Down Expand Up @@ -199,6 +201,8 @@ def _describe(self) -> dict[str, Any]:

def load(self) -> pl.LazyFrame:
load_path = str(self._get_load_path())
if not self._exists():
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), load_path)

if self._protocol == "file":
# With local filesystems, we can use Polar's build-in I/O method:
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ test = [
"pandas>=2.0",
"Pillow~=10.0",
"plotly>=4.8.0, <6.0",
"polars[xlsx2csv, deltalake]~=0.18.0",
"polars[deltalake,xlsx2csv]>=1.0",
"pyarrow>=1.0; python_version < '3.11'",
"pyarrow>=7.0; python_version >= '3.11'", # Adding to avoid numpy build errors
"pyodbc~=5.0",
Expand Down
10 changes: 10 additions & 0 deletions kedro-datasets/tests/polars/test_csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,14 @@ def mocked_csv_in_s3(mocked_s3_bucket, mocked_dataframe: pl.DataFrame):


class TestCSVDataset:
@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_save_and_load(self, csv_dataset, dummy_dataframe):
"""Test saving and reloading the dataset."""
csv_dataset.save(dummy_dataframe)
reloaded = csv_dataset.load()
assert_frame_equal(dummy_dataframe, reloaded)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_exists(self, csv_dataset, dummy_dataframe):
"""Test `exists` method invocation for both existing and
nonexistent dataset."""
Expand Down Expand Up @@ -202,13 +204,15 @@ def test_version_str_repr(self, load_version, save_version):
assert "load_args={'rechunk': True}" in str(ds)
assert "load_args={'rechunk': True}" in str(ds_versioned)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_save_and_load(self, versioned_csv_dataset, dummy_dataframe):
"""Test that saved and reloaded data matches the original one for
the versioned dataset."""
versioned_csv_dataset.save(dummy_dataframe)
reloaded_df = versioned_csv_dataset.load()
assert_frame_equal(dummy_dataframe, reloaded_df)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv):
"""Test that if a new version is created mid-run, by an
external system, it won't be loaded in the current run."""
Expand All @@ -232,6 +236,7 @@ def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_c
ds_new.resolve_load_version() == v_new
) # new version is discoverable by a new instance

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_multiple_saves(self, dummy_dataframe, filepath_csv):
"""Test multiple cycles of save followed by load for the same dataset"""
ds_versioned = CSVDataset(filepath=filepath_csv, version=Version(None, None))
Expand All @@ -254,6 +259,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv):
ds_new = CSVDataset(filepath=filepath_csv, version=Version(None, None))
assert ds_new.resolve_load_version() == second_load_version

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_release_instance_cache(self, dummy_dataframe, filepath_csv):
"""Test that cache invalidation does not affect other instances"""
ds_a = CSVDataset(filepath=filepath_csv, version=Version(None, None))
Expand Down Expand Up @@ -282,12 +288,14 @@ def test_no_versions(self, versioned_csv_dataset):
with pytest.raises(DatasetError, match=pattern):
versioned_csv_dataset.load()

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_exists(self, versioned_csv_dataset, dummy_dataframe):
"""Test `exists` method invocation for versioned dataset."""
assert not versioned_csv_dataset.exists()
versioned_csv_dataset.save(dummy_dataframe)
assert versioned_csv_dataset.exists()

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe):
"""Check the error when attempting to override the dataset if the
corresponding CSV file for a given save version already exists."""
Expand All @@ -299,6 +307,7 @@ def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe):
with pytest.raises(DatasetError, match=pattern):
versioned_csv_dataset.save(dummy_dataframe)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
@pytest.mark.parametrize(
"load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
)
Expand All @@ -325,6 +334,7 @@ def test_http_filesystem_no_versioning(self):
filepath="https://example.com/file.csv", version=Version(None, None)
)

@pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
def test_versioning_existing_dataset(
self, csv_dataset, versioned_csv_dataset, dummy_dataframe
):
Expand Down
1 change: 1 addition & 0 deletions kedro-datasets/tests/polars/test_eager_polars_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def excel_dataset(dummy_dataframe: pl.DataFrame, filepath_excel):
return EagerPolarsDataset(
filepath=filepath_excel.as_posix(),
file_format="excel",
load_args={"engine": "xlsx2csv"},
)


Expand Down

0 comments on commit 7baa826

Please sign in to comment.