[SPARK-43620][CONNECT][PS] Fix Pandas APIs that depend on unsupported features

### What changes were proposed in this pull request?

This PR proposes to fix the Pandas APIs that depend on PySpark features not yet supported by Spark Connect, most notably map-column lookups via `Column.__getitem__`.
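
The pattern applied throughout the patch is to replace such a map lookup with an equivalent `CASE WHEN` chain built from plain column comparisons. A minimal standalone sketch of the idea (illustrative only, not code from the patch):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(0,), (1,), (2,)], ["code"])
categories = ["a", "b"]

# Before: a literal map plus Column.__getitem__, which Spark Connect's
# Column did not support at the time:
#   map_scol = F.create_map(F.lit(0), F.lit("a"), F.lit(1), F.lit("b"))
#   scol = map_scol[df["code"]]

# After: the same lookup as a chain of WHEN/OTHERWISE expressions.
scol = F.lit(None)
for code, category in reversed(list(enumerate(categories))):
    scol = F.when(df["code"] == F.lit(code), F.lit(category)).otherwise(scol)

df.select(scol.alias("category")).show()  # unmatched codes (here, 2) yield NULL
```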

### Why are the changes needed?

To increase the API coverage for Pandas API on Spark with Spark Connect.

### Does this PR introduce _any_ user-facing change?

Pandas data type APIs such as `astype` and `factorize` are now supported on Spark Connect.
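
For example, the following now works against a Spark Connect session (a sketch; the connect URL is a placeholder):

```python
import pyspark.pandas as ps
from pyspark.sql import SparkSession

# Placeholder endpoint; any reachable Spark Connect server works.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

psser = ps.Series(["a", "b", "a", None])
print(psser.astype("category"))  # previously failed on Spark Connect
codes, uniques = psser.factorize()
print(codes)
print(uniques)
```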

### How was this patch tested?

By enabling the existing unit tests that were previously skipped for Spark Connect.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43120 from itholic/SPARK-43620.

Authored-by: Haejoon Lee <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
itholic authored and HyukjinKwon committed Oct 4, 2023
1 parent 9830901 commit d86208d
Showing 16 changed files with 39 additions and 131 deletions.
25 changes: 13 additions & 12 deletions python/pyspark/pandas/base.py
```diff
@@ -1704,16 +1704,10 @@ def factorize(
         if len(categories) == 0:
             scol = F.lit(None)
         else:
-            kvs = list(
-                chain(
-                    *[
-                        (F.lit(code), F.lit(category))
-                        for code, category in enumerate(categories)
-                    ]
-                )
-            )
-            map_scol = F.create_map(*kvs)
-            scol = map_scol[self.spark.column]
+            scol = F.lit(None)
+            for code, category in reversed(list(enumerate(categories))):
+                scol = F.when(self.spark.column == F.lit(code), F.lit(category)).otherwise(scol)
 
         codes, uniques = self._with_new_scol(
             scol.alias(self._internal.data_spark_column_names[0])
         ).factorize(use_na_sentinel=use_na_sentinel)
```
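
Iterating in reverse builds the conditional from the innermost `otherwise` outward, so the final expression tests code 0 first. The same fold can be written with `functools.reduce`; `chain_when` below is a hypothetical helper for illustration, not part of the patch:

```python
from functools import reduce
from pyspark.sql import Column, functions as F

def chain_when(col: Column, categories: list) -> Column:
    # Nested WHEN/OTHERWISE equivalent to the loop added above:
    # WHEN col == 0 THEN categories[0] WHEN col == 1 THEN ... ELSE NULL.
    return reduce(
        lambda acc, pair: F.when(col == F.lit(pair[0]), F.lit(pair[1])).otherwise(acc),
        reversed(list(enumerate(categories))),
        F.lit(None),
    )
```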
```diff
@@ -1761,9 +1755,16 @@ def factorize(
         if len(kvs) == 0:  # uniques are all missing values
             new_scol = F.lit(na_sentinel_code)
         else:
-            map_scol = F.create_map(*kvs)
             null_scol = F.when(self.isnull().spark.column, F.lit(na_sentinel_code))
-            new_scol = null_scol.otherwise(map_scol[self.spark.column])
+            mapped_scol = None
+            for i in range(0, len(kvs), 2):
+                key = kvs[i]
+                value = kvs[i + 1]
+                if mapped_scol is None:
+                    mapped_scol = F.when(self.spark.column == key, value)
+                else:
+                    mapped_scol = mapped_scol.when(self.spark.column == key, value)
+            new_scol = null_scol.otherwise(mapped_scol)
 
         codes = self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0]))
```
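Here `kvs` is the flat `[key0, value0, key1, value1, ...]` list that previously fed `F.create_map`, hence the stride-2 walk. A compact sketch of the same walk, with made-up keys and an assumed input column:

```python
from pyspark.sql import functions as F

kvs = [F.lit("x"), F.lit(0), F.lit("y"), F.lit(1)]  # flat key/value pairs
col = F.col("value")  # assumed input column

mapped = None
for key, val in zip(kvs[0::2], kvs[1::2]):
    # Column.when chains another branch onto an existing CASE expression.
    mapped = F.when(col == key, val) if mapped is None else mapped.when(col == key, val)
# No .otherwise here: the caller wraps the chain in null_scol.otherwise(...),
# and values matching no branch fall through to NULL.
```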
26 changes: 20 additions & 6 deletions python/pyspark/pandas/data_type_ops/base.py
```diff
@@ -17,7 +17,6 @@
 
 import numbers
 from abc import ABCMeta
-from itertools import chain
 from typing import Any, Optional, Union
 
 import numpy as np
```
```diff
@@ -130,12 +129,27 @@ def _as_categorical_type(
         if len(categories) == 0:
             scol = F.lit(-1)
         else:
-            kvs = chain(
-                *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
-            )
-            map_scol = F.create_map(*kvs)
+            scol = F.lit(-1)
+            if isinstance(
+                index_ops._internal.spark_type_for(index_ops._internal.column_labels[0]), BinaryType
+            ):
+                from pyspark.sql.functions import base64
+
+                stringified_column = base64(index_ops.spark.column)
+                for code, category in enumerate(categories):
+                    # Convert each category to base64 before comparison
+                    base64_category = F.base64(F.lit(category))
+                    scol = F.when(stringified_column == base64_category, F.lit(code)).otherwise(
+                        scol
+                    )
+            else:
+                stringified_column = F.format_string("%s", index_ops.spark.column)
+
+                for code, category in enumerate(categories):
+                    scol = F.when(stringified_column == F.lit(category), F.lit(code)).otherwise(
+                        scol
+                    )
 
-            scol = F.coalesce(map_scol[index_ops.spark.column], F.lit(-1))
         return index_ops._with_new_scol(
             scol.cast(spark_type),
             field=index_ops._internal.data_fields[0].copy(
```
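Binary columns get special handling because the stringification used for other types is not reliable for raw bytes: both sides of the comparison are run through `base64` so the equality check happens on strings. A minimal sketch of that trick, with made-up data:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(bytearray(b"ab"),), (bytearray(b"cd"),)], ["bin"])

# Encode both sides to base64 strings before comparing.
df.select(
    (F.base64(df["bin"]) == F.base64(F.lit(bytearray(b"ab")))).alias("is_ab")
).show()  # true, false
```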
7 changes: 3 additions & 4 deletions python/pyspark/pandas/data_type_ops/categorical_ops.py
```diff
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-from itertools import chain
 from typing import cast, Any, Union
 
 import pandas as pd
```
```diff
@@ -135,7 +134,7 @@ def _to_cat(index_ops: IndexOpsLike) -> IndexOpsLike:
     if len(categories) == 0:
         scol = F.lit(None)
     else:
-        kvs = chain(*[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)])
-        map_scol = F.create_map(*kvs)
-        scol = map_scol[index_ops.spark.column]
+        scol = F.lit(None)
+        for code, category in reversed(list(enumerate(categories))):
+            scol = F.when(index_ops.spark.column == F.lit(code), F.lit(category)).otherwise(scol)
     return index_ops._with_new_scol(scol)
```
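
`_to_cat` maps the stored integer codes back to category labels. For reference, the pandas semantics it mirrors (illustrative data):

```python
import pandas as pd

cat = pd.Categorical(["b", "a", None], categories=["a", "b"])
print(cat.codes)  # [ 1  0 -1]; missing values are coded as -1
print(list(cat))  # ['b', 'a', nan]
# A code of -1 matches no WHEN branch above, so it falls through to the
# initial F.lit(None), i.e. NULL, matching pandas' NaN.
```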
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_binary_ops.py

```diff
@@ -25,9 +25,7 @@
 class BinaryOpsParityTests(
     BinaryOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
+    pass
 
 
 if __name__ == "__main__":
```
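The remaining test changes all follow one pattern: each parity test class inherits its test methods from a shared mixin and had overridden some with `@unittest.skip`; deleting the override re-enables the inherited test on the Connect session. Schematically (simplified class names, not the real harness):

```python
import unittest

class OpsTestsMixin:
    def test_astype(self):
        ...  # real assertions live in the shared, non-Connect test module

class OpsParityTests(OpsTestsMixin, unittest.TestCase):
    # No skipping override anymore: the inherited test_astype runs as-is.
    pass
```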
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_boolean_ops.py

```diff
@@ -30,10 +30,6 @@ class BooleanOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.data_type_ops.test_parity_boolean_ops import *  # noqa: F401
```
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py

```diff
@@ -30,18 +30,6 @@ class CategoricalOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_eq(self):
-        super().test_eq()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_ne(self):
-        super().test_ne()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.data_type_ops.test_parity_categorical_ops import *
```
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_date_ops.py

```diff
@@ -30,10 +30,6 @@ class DateOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.data_type_ops.test_parity_date_ops import *  # noqa: F401
```
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_datetime_ops.py

```diff
@@ -30,10 +30,6 @@ class DatetimeOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.data_type_ops.test_parity_datetime_ops import *  # noqa: F401
```
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py

```diff
@@ -25,9 +25,7 @@
 class NullOpsParityTests(
     NullOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
+    pass
 
 
 if __name__ == "__main__":
```
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py

```diff
@@ -30,10 +30,6 @@ class NumOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.data_type_ops.test_parity_num_ops import *  # noqa: F401
```
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py

```diff
@@ -30,10 +30,6 @@ class StringOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.data_type_ops.test_parity_string_ops import *  # noqa: F401
```
python/pyspark/pandas/tests/connect/data_type_ops/test_parity_timedelta_ops.py

```diff
@@ -30,10 +30,6 @@ class TimedeltaOpsParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.data_type_ops.test_parity_timedelta_ops import *  # noqa: F401
```
```diff
@@ -29,10 +29,6 @@ class IndexesParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
     @unittest.skip("TODO(SPARK-43704): Enable IndexesParityTests.test_to_series.")
     def test_to_series(self):
         super().test_to_series()
```
```diff
@@ -24,41 +24,7 @@
 class CategoricalIndexParityTests(
     CategoricalIndexTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
 ):
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_append(self):
-        super().test_append()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_intersection(self):
-        super().test_intersection()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_categories(self):
-        super().test_remove_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_unused_categories(self):
-        super().test_remove_unused_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_reorder_categories(self):
-        super().test_reorder_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_set_categories(self):
-        super().test_set_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_union(self):
-        super().test_union()
+    pass
 
 
 if __name__ == "__main__":
```
python/pyspark/pandas/tests/connect/series/test_parity_compute.py

```diff
@@ -24,10 +24,6 @@
 class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
     pass
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.series.test_parity_compute import *  # noqa: F401
```
24 changes: 0 additions & 24 deletions python/pyspark/pandas/tests/connect/test_parity_categorical.py
```diff
@@ -29,30 +29,6 @@ class CategoricalParityTests(
     def psdf(self):
         return ps.from_pandas(self.pdf)
 
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_astype(self):
-        super().test_astype()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_factorize(self):
-        super().test_factorize()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_categories(self):
-        super().test_remove_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_remove_unused_categories(self):
-        super().test_remove_unused_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_reorder_categories(self):
-        super().test_reorder_categories()
-
-    @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.")
-    def test_set_categories(self):
-        super().test_set_categories()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.test_parity_categorical import *  # noqa: F401
```
