[SPARK-50051][PYTHON][CONNECT] Make lit works with empty numpy ndarray
### What changes were proposed in this pull request?
Make `lit` work with an empty numpy ndarray.

### Why are the changes needed?
Bug fix: the schema inferred for an empty numpy ndarray is incorrect in Spark Connect.

PySpark Classic (the behavior Connect should match):
```
In [3]: spark.range(1).select(sf.lit(np.array([1,2,3], np.int32))).schema
Out[3]: StructType([StructField('ARRAY(1, 2, 3)', ArrayType(IntegerType(), True), False)])

In [4]: spark.range(1).select(sf.lit(np.array([], np.int32))).schema
Out[4]: StructType([StructField('ARRAY()', ArrayType(IntegerType(), True), False)])
```

### Does this PR introduce _any_ user-facing change?
Yes, the schema of a literal created from an empty ndarray changes in Spark Connect.

Before:
```
In [7]: spark.range(1).select(sf.lit(np.array([], np.int32))).schema
Out[7]: StructType([StructField('array()', ArrayType(NullType(), False), False)])
```

After:
```
In [3]: spark.range(1).select(sf.lit(np.array([], np.int32))).schema
Out[3]: StructType([StructField('array()', ArrayType(IntegerType(), True), False)])
```
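
For a quick check of the new behavior beyond `int32`, the snippet below is a minimal sketch (it assumes an active SparkSession bound to `spark` and `pyspark.sql.functions` imported as `sf`) mirroring two rows of the table exercised by the new test: `int8` is widened to `smallint`, and `float64` maps to `double`.

```python
# Minimal sketch; assumes an active SparkSession `spark` and
# `import pyspark.sql.functions as sf` (same setup as the shell output above).
import numpy as np

for np_dtype, expected in [("int8", "array<smallint>"), ("float64", "array<double>")]:
    dtypes = spark.range(1).select(sf.lit(np.array([], dtype=np_dtype)).alias("a")).dtypes
    assert dtypes == [("a", expected)], dtypes
```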

### How was this patch tested?
Added unit tests covering empty ndarrays of several numpy dtypes.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48589 from zhengruifeng/connect_empty_ndarray.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Oct 22, 2024
1 parent e7cdb5a commit dc35ba8
Showing 3 changed files with 20 additions and 12 deletions.
python/pyspark/sql/connect/functions/builtin.py (7 additions, 4 deletions)
```diff
@@ -68,6 +68,8 @@
     StructType,
     ArrayType,
     StringType,
+    ByteType,
+    ShortType,
 )
 from pyspark.sql.utils import enum_to_value as _enum_to_value
@@ -267,7 +269,8 @@ def lit(col: Any) -> Column:
             )
         return array(*[lit(c) for c in col])
     elif isinstance(col, np.ndarray) and col.ndim == 1:
-        if _from_numpy_type(col.dtype) is None:
+        dt = _from_numpy_type(col.dtype)
+        if dt is None:
             raise PySparkTypeError(
                 errorClass="UNSUPPORTED_NUMPY_ARRAY_SCALAR",
                 messageParameters={"dtype": col.dtype.name},
@@ -276,10 +279,10 @@
         # NumpyArrayConverter for Py4J can not support ndarray with int8 values.
         # Actually this is not a problem for Connect, but here still convert it
         # to int16 for compatibility.
-        if col.dtype == np.int8:
-            col = col.astype(np.int16)
+        if dt == ByteType():
+            dt = ShortType()
 
-        return array(*[lit(c) for c in col])
+        return array(*[lit(c) for c in col]).cast(ArrayType(dt))
     else:
         return ConnectColumn(LiteralExpression._from_value(col))
```
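
The core of the fix: resolve the Spark element type from the numpy dtype up front and cast the array literal to it, so an empty ndarray no longer degrades to `ArrayType(NullType())`. The standalone sketch below illustrates that idea only; it is not the Connect code path (the dtype mapping is spelled out as a plain dict instead of the internal helper, and building the Column assumes an active SparkSession).

```python
# Illustrative sketch of the idea behind the fix, not the actual implementation.
# The element type is derived from the numpy dtype and applied via a cast,
# so it survives even when the ndarray is empty.
import numpy as np
from pyspark.sql import functions as sf
from pyspark.sql.types import (
    ArrayType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
)

# Minimal dtype mapping for illustration (the real code uses an internal helper).
_NUMPY_TO_SPARK = {
    np.dtype("int8"): ByteType(),
    np.dtype("int16"): ShortType(),
    np.dtype("int32"): IntegerType(),
    np.dtype("int64"): LongType(),
    np.dtype("float32"): FloatType(),
    np.dtype("float64"): DoubleType(),
}

def lit_ndarray(arr: np.ndarray):
    dt = _NUMPY_TO_SPARK.get(arr.dtype)
    if dt is None:
        raise TypeError(f"unsupported numpy dtype: {arr.dtype.name}")
    if dt == ByteType():
        # Py4J cannot ship int8 arrays in PySpark Classic; the patch keeps the
        # same widening to smallint in Connect for compatibility.
        dt = ShortType()
    # Casting pins the element type even when `arr` has no elements.
    return sf.array(*[sf.lit(c) for c in arr.tolist()]).cast(ArrayType(dt))
```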
python/pyspark/sql/tests/connect/test_parity_functions.py (0 additions, 4 deletions)
```diff
@@ -38,10 +38,6 @@ def test_input_file_name_reset_for_rdd(self):
     def test_str_ndarray(self):
         super().test_str_ndarray()
 
-    @unittest.skip("SPARK-50051: Spark Connect should empty ndarray.")
-    def test_empty_ndarray(self):
-        super().test_empty_ndarray()
-
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_functions import *  # noqa: F401
```
python/pyspark/sql/tests/test_functions.py (13 additions, 4 deletions)
```diff
@@ -1281,10 +1281,19 @@ def test_str_ndarray(self):
     def test_empty_ndarray(self):
         import numpy as np
 
-        self.assertEqual(
-            [("a", "array<int>")],
-            self.spark.range(1).select(F.lit(np.array([], np.int32)).alias("a")).dtypes,
-        )
+        arr_dtype_to_spark_dtypes = [
+            ("int8", [("b", "array<smallint>")]),
+            ("int16", [("b", "array<smallint>")]),
+            ("int32", [("b", "array<int>")]),
+            ("int64", [("b", "array<bigint>")]),
+            ("float32", [("b", "array<float>")]),
+            ("float64", [("b", "array<double>")]),
+        ]
+        for t, expected_spark_dtypes in arr_dtype_to_spark_dtypes:
+            arr = np.array([]).astype(t)
+            self.assertEqual(
+                expected_spark_dtypes, self.spark.range(1).select(F.lit(arr).alias("b")).dtypes
+            )
 
     def test_binary_math_function(self):
         funcs, expected = zip(
```
