[SPARK-48638][CONNECT] Add ExecutionInfo support for DataFrame
### What changes were proposed in this pull request?

One of the notable shortcomings of Spark Connect is that query execution metrics are not easily accessible. In Spark Classic, the query execution is only reachable through the private `_jdf` variable, which does not exist in Spark Connect.

However, since the first release of Spark Connect, the response messages have already carried the metrics of the executed plan.

This patch makes these metrics directly accessible and provides a way to visualize them.

```python
df = spark.range(100)
df.collect()
metrics = df.executionInfo.metrics
metrics.toDot()
```

The `toDot()` method returns a `graphviz.Digraph` instance that can either be displayed directly in a notebook or manipulated further.
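For completeness, here is a minimal sketch of working with the returned `Digraph` outside a notebook; it assumes the standard `graphviz` Python package API (`source`, `render`) and a `df` that has already been executed:

```python
# Sketch only: relies on the graphviz package's Digraph API (source, render).
dot = df.executionInfo.metrics.toDot()
print(dot.source)                                         # inspect the generated DOT text
dot.render("query_metrics", format="png", cleanup=True)   # writes query_metrics.png
```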

<img width="909" alt="image" src="https://github.com/apache/spark/assets/3421/e1710350-54d2-4e6d-9b80-0aaf1b8583e3">

<img width="871" alt="image" src="https://github.com/apache/spark/assets/3421/6a972119-76b6-4e36-bc81-8d01110fa31c">

The purpose of the `executionInfo` property and the associated `ExecutionInfo` class is not to mirror the `QueryExecution` class used internally by Spark (which, for example, exposes the analyzed, optimized, and executed plans), but rather to provide a convenient way of accessing execution-related information.
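To make the lifecycle explicit, here is a hedged usage sketch based on the changes in this patch (assuming an active Spark Connect session named `spark`): the property starts out unset and is only populated once an action executes the plan.

```python
from pyspark.sql.functions import col

df = spark.range(100).groupBy((col("id") % 10).alias("bucket")).count()
assert df.executionInfo is None    # no action has run yet, so no execution info

df.collect()                       # executing the plan populates the metrics
info = df.executionInfo
info.metrics.toDot()               # visualization needs the optional graphviz package
```

On Spark Classic, accessing the same property raises an error pointing users to Spark Connect (see the `CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF` error condition added below).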

### Why are the changes needed?
User experience: query execution metrics become directly accessible from a DataFrame in Spark Connect.

### Does this PR introduce _any_ user-facing change?
Yes. This adds a new `DataFrame.executionInfo` property for accessing the execution information of a Spark SQL execution.

### How was this patch tested?
Added new unit tests (`pyspark.sql.tests.connect.test_df_debug`).

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46996 from grundprinzip/SPARK-48638.

Lead-authored-by: Martin Grund <martin.grund@databricks.com>
Co-authored-by: Martin Grund <grundprinzip@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
2 people authored and HyukjinKwon committed Jun 25, 2024
1 parent 5928908 commit 9d4abaf
Showing 18 changed files with 544 additions and 89 deletions.
@@ -70,6 +70,7 @@ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper {
.newBuilder()
.setName(p.nodeName)
.setPlanId(p.id)
.setParent(parentId)
.putAllExecutionMetrics(mv.asJava)
.build()
Seq(mo) ++ transformChildren(p)
3 changes: 3 additions & 0 deletions dev/requirements.txt
@@ -60,6 +60,9 @@ mypy-protobuf==3.3.0
googleapis-common-protos-stubs==2.2.0
grpc-stubs==1.24.11

# Debug for Spark and Spark Connect
graphviz==0.20.3

# TorchDistributor dependencies
torch
torchvision
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -1063,6 +1063,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_pandas_udf_window",
"pyspark.sql.tests.connect.test_resources",
"pyspark.sql.tests.connect.shell.test_progress",
"pyspark.sql.tests.connect.test_df_debug",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
1 change: 1 addition & 0 deletions python/docs/source/getting_started/install.rst
@@ -210,6 +210,7 @@ Package Supported version Note
`grpcio` >=1.62.0 Required for Spark Connect
`grpcio-status` >=1.62.0 Required for Spark Connect
`googleapis-common-protos` >=1.56.4 Required for Spark Connect
`graphviz` >=0.20 Optional for Spark Connect
========================== ================= ==========================

Spark SQL
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -55,6 +55,7 @@ DataFrame
DataFrame.dropna
DataFrame.dtypes
DataFrame.exceptAll
DataFrame.executionInfo
DataFrame.explain
DataFrame.fillna
DataFrame.filter
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -149,6 +149,11 @@
"Cannot <condition1> without <condition2>."
]
},
"CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF": {
"message": [
"Calling property or member '<member>' is not supported in PySpark Classic, please use Spark Connect instead."
]
},
"COLLATION_INVALID_PROVIDER" : {
"message" : [
"The value <provider> does not represent a correct collation provider. Supported providers are: [<supportedProviders>]."
8 changes: 8 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
@@ -94,6 +94,7 @@
from pyspark.sql.session import SparkSession
from pyspark.sql.group import GroupedData
from pyspark.sql.observation import Observation
from pyspark.sql.metrics import ExecutionInfo


class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
@@ -1835,6 +1836,13 @@ def toArrow(self) -> "pa.Table":
def toPandas(self) -> "PandasDataFrameLike":
return PandasConversionMixin.toPandas(self)

@property
def executionInfo(self) -> Optional["ExecutionInfo"]:
raise PySparkValueError(
error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF",
message_parameters={"member": "queryExecution"},
)


def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject":
"""
91 changes: 30 additions & 61 deletions python/pyspark/sql/connect/client/core.py
@@ -61,6 +61,7 @@
from pyspark.loose_version import LooseVersion
from pyspark.version import __version__
from pyspark.resource.information import ResourceInformation
from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.client.logging import logger
from pyspark.sql.connect.profiler import ConnectProfilerCollector
@@ -447,56 +448,7 @@ def toChannel(self) -> grpc.Channel:
return self._secure_channel(self.endpoint, creds)


class MetricValue:
def __init__(self, name: str, value: Union[int, float], type: str):
self._name = name
self._type = type
self._value = value

def __repr__(self) -> str:
return f"<{self._name}={self._value} ({self._type})>"

@property
def name(self) -> str:
return self._name

@property
def value(self) -> Union[int, float]:
return self._value

@property
def metric_type(self) -> str:
return self._type


class PlanMetrics:
def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
self._name = name
self._id = id
self._parent_id = parent
self._metrics = metrics

def __repr__(self) -> str:
return f"Plan({self._name})={self._metrics}"

@property
def name(self) -> str:
return self._name

@property
def plan_id(self) -> int:
return self._id

@property
def parent_plan_id(self) -> int:
return self._parent_id

@property
def metrics(self) -> List[MetricValue]:
return self._metrics


class PlanObservedMetrics:
class PlanObservedMetrics(ObservedMetrics):
def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: List[str]):
self._name = name
self._metrics = metrics
@@ -513,6 +465,13 @@ def name(self) -> str:
def metrics(self) -> List[pb2.Expression.Literal]:
return self._metrics

@property
def pairs(self) -> dict[str, Any]:
result = {}
for x in range(len(self._metrics)):
result[self.keys[x]] = LiteralExpression._to_value(self.metrics[x])
return result

@property
def keys(self) -> List[str]:
return self._keys
@@ -888,7 +847,7 @@ def _resources(self) -> Dict[str, ResourceInformation]:
logger.info("Fetching the resources")
cmd = pb2.Command()
cmd.get_resources_command.SetInParent()
(_, properties) = self.execute_command(cmd)
(_, properties, _) = self.execute_command(cmd)
resources = properties["get_resources_command_result"]
return resources

@@ -915,18 +874,23 @@ def to_table_as_iterator(

def to_table(
self, plan: pb2.Plan, observations: Dict[str, Observation]
) -> Tuple["pa.Table", Optional[StructType]]:
) -> Tuple["pa.Table", Optional[StructType], ExecutionInfo]:
"""
Return given plan as a PyArrow Table.
"""
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
table, schema, _, _, _ = self._execute_and_fetch(req, observations)
table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations)

# Create a query execution object.
ei = ExecutionInfo(metrics, observed_metrics)
assert table is not None
return table, schema
return table, schema, ei

def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd.DataFrame":
def to_pandas(
self, plan: pb2.Plan, observations: Dict[str, Observation]
) -> Tuple["pd.DataFrame", "ExecutionInfo"]:
"""
Return given plan as a pandas DataFrame.
"""
Expand All @@ -941,6 +905,7 @@ def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd
req, observations, self_destruct=self_destruct
)
assert table is not None
ei = ExecutionInfo(metrics, observed_metrics)

schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
assert schema is not None and isinstance(schema, StructType)
@@ -1007,7 +972,7 @@ def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd
pdf.attrs["metrics"] = metrics
if len(observed_metrics) > 0:
pdf.attrs["observed_metrics"] = observed_metrics
return pdf
return pdf, ei

def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
"""
@@ -1051,7 +1016,7 @@ def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:

def execute_command(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], ExecutionInfo]:
"""
Execute given command.
"""
@@ -1060,11 +1025,15 @@ def execute_command(
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
data, _, _, _, properties = self._execute_and_fetch(req, observations or {})
data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
req, observations or {}
)
# Create a query execution object.
ei = ExecutionInfo(metrics, observed_metrics)
if data is not None:
return (data.to_pandas(), properties)
return (data.to_pandas(), properties, ei)
else:
return (None, properties)
return (None, properties, ei)

def execute_command_as_iterator(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
@@ -1849,6 +1818,6 @@ def _create_profile(self, profile: pb2.ResourceProfile) -> int:
logger.info("Creating the ResourceProfile")
cmd = pb2.Command()
cmd.create_resource_profile_command.profile.CopyFrom(profile)
(_, properties) = self.execute_command(cmd)
(_, properties, _) = self.execute_command(cmd)
profile_id = properties["create_resource_profile_command_result"]
return profile_id
44 changes: 34 additions & 10 deletions python/pyspark/sql/connect/dataframe.py
@@ -101,6 +101,7 @@
from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.session import SparkSession
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
from pyspark.sql.metrics import ExecutionInfo


class DataFrame(ParentDataFrame):
@@ -137,6 +138,7 @@ def __init__(
# by __repr__ and _repr_html_ while eager evaluation opens.
self._support_repr_html = False
self._cached_schema: Optional[StructType] = None
self._execution_info: Optional["ExecutionInfo"] = None

def __reduce__(self) -> Tuple:
"""
@@ -206,7 +208,10 @@ def _repr_html_(self) -> Optional[str]:

@property
def write(self) -> "DataFrameWriter":
return DataFrameWriter(self._plan, self._session)
def cb(qe: "ExecutionInfo") -> None:
self._execution_info = qe

return DataFrameWriter(self._plan, self._session, cb)

@functools.cache
def isEmpty(self) -> bool:
@@ -1839,7 +1844,9 @@ def collect(self) -> List[Row]:

def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
query = self._plan.to_proto(self._session.client)
table, schema = self._session.client.to_table(query, self._plan.observations)
table, schema, self._execution_info = self._session.client.to_table(
query, self._plan.observations
)
assert table is not None
return (table, schema)

@@ -1850,7 +1857,9 @@ def toArrow(self) -> "pa.Table":

def toPandas(self) -> "PandasDataFrameLike":
query = self._plan.to_proto(self._session.client)
return self._session.client.to_pandas(query, self._plan.observations)
pdf, ei = self._session.client.to_pandas(query, self._plan.observations)
self._execution_info = ei
return pdf

@property
def schema(self) -> StructType:
@@ -1976,25 +1985,29 @@ def createTempView(self, name: str) -> None:
command = plan.CreateView(
child=self._plan, name=name, is_global=False, replace=False
).command(session=self._session.client)
self._session.client.execute_command(command, self._plan.observations)
_, _, ei = self._session.client.execute_command(command, self._plan.observations)
self._execution_info = ei

def createOrReplaceTempView(self, name: str) -> None:
command = plan.CreateView(
child=self._plan, name=name, is_global=False, replace=True
).command(session=self._session.client)
self._session.client.execute_command(command, self._plan.observations)
_, _, ei = self._session.client.execute_command(command, self._plan.observations)
self._execution_info = ei

def createGlobalTempView(self, name: str) -> None:
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=False
).command(session=self._session.client)
self._session.client.execute_command(command, self._plan.observations)
_, _, ei = self._session.client.execute_command(command, self._plan.observations)
self._execution_info = ei

def createOrReplaceGlobalTempView(self, name: str) -> None:
command = plan.CreateView(
child=self._plan, name=name, is_global=True, replace=True
).command(session=self._session.client)
self._session.client.execute_command(command, self._plan.observations)
_, _, ei = self._session.client.execute_command(command, self._plan.observations)
self._execution_info = ei

def cache(self) -> ParentDataFrame:
return self.persist()
@@ -2169,22 +2182,29 @@ def semanticHash(self) -> int:
)

def writeTo(self, table: str) -> "DataFrameWriterV2":
return DataFrameWriterV2(self._plan, self._session, table)
def cb(ei: "ExecutionInfo") -> None:
self._execution_info = ei

return DataFrameWriterV2(self._plan, self._session, table, cb)

def offset(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)

def checkpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
_, properties = self._session.client.execute_command(cmd.command(self._session.client))
_, properties, self._execution_info = self._session.client.execute_command(
cmd.command(self._session.client)
)
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
return checkpointed

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
_, properties = self._session.client.execute_command(cmd.command(self._session.client))
_, properties, self._execution_info = self._session.client.execute_command(
cmd.command(self._session.client)
)
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
Expand All @@ -2205,6 +2225,10 @@ def rdd(self) -> "RDD[Row]":
message_parameters={"feature": "rdd"},
)

@property
def executionInfo(self) -> Optional["ExecutionInfo"]:
return self._execution_info


class DataFrameNaFunctions(ParentDataFrameNaFunctions):
def __init__(self, df: ParentDataFrame):
