From 9d4abaf19fa4d508aca35ac00528b2ce9f4e8805 Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Wed, 26 Jun 2024 08:35:32 +0900
Subject: [PATCH] [SPARK-48638][CONNECT] Add ExecutionInfo support for DataFrame

### What changes were proposed in this pull request?

One of the notable shortcomings of Spark Connect is that query execution metrics are not directly accessible. In Spark Classic, the query execution is only reachable through the private `_jdf` variable, which is not available in Spark Connect. However, since the first release of Spark Connect, the response messages have already carried the metrics of the executed plan. This patch makes them directly accessible and provides a way to visualize them.

```python
df = spark.range(100)
df.collect()
metrics = df.executionInfo.metrics
metrics.toDot()
```

The `toDot()` method returns an instance of the `graphviz.Digraph` object that can either be displayed directly in a notebook or manipulated further.

The purpose of the `executionInfo` property and the associated `ExecutionInfo` class is not to provide an equivalent of the `QueryExecution` class used internally by Spark (with, for example, access to the analyzed, optimized, and executed plans), but rather to provide a convenient way of accessing execution-related information.

### Why are the changes needed?
User experience.

### Does this PR introduce _any_ user-facing change?
Yes, it adds a new API for accessing the query execution information of a Spark SQL execution.

### How was this patch tested?
Added new unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46996 from grundprinzip/SPARK-48638.

Lead-authored-by: Martin Grund
Co-authored-by: Martin Grund
Signed-off-by: Hyukjin Kwon
---
 .../sql/connect/utils/MetricGenerator.scala   |   1 +
 dev/requirements.txt                          |   3 +
 dev/sparktestsupport/modules.py               |   1 +
 .../docs/source/getting_started/install.rst   |   1 +
 .../reference/pyspark.sql/dataframe.rst       |   1 +
 python/pyspark/errors/error-conditions.json   |   5 +
 python/pyspark/sql/classic/dataframe.py       |   8 +
 python/pyspark/sql/connect/client/core.py     |  91 ++----
 python/pyspark/sql/connect/dataframe.py       |  44 ++-
 python/pyspark/sql/connect/readwriter.py      |  48 ++-
 python/pyspark/sql/connect/session.py         |   7 +-
 python/pyspark/sql/connect/streaming/query.py |   4 +-
 .../sql/connect/streaming/readwriter.py       |   2 +-
 python/pyspark/sql/dataframe.py               |  26 ++
 python/pyspark/sql/metrics.py                 | 287 ++++++++++++++++++
 .../sql/tests/connect/test_df_debug.py        |  86 ++++++
 python/pyspark/sql/tests/test_dataframe.py    |  11 +-
 python/pyspark/testing/connectutils.py        |   7 +
 18 files changed, 544 insertions(+), 89 deletions(-)
 create mode 100644 python/pyspark/sql/metrics.py
 create mode 100644 python/pyspark/sql/tests/connect/test_df_debug.py

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
index e2e4128311871..d76bec5454abb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala
@@ -70,6 +70,7 @@ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper {
       .newBuilder()
       .setName(p.nodeName)
       .setPlanId(p.id)
+      .setParent(parentId)
       .putAllExecutionMetrics(mv.asJava)
       .build()
     Seq(mo) ++ transformChildren(p)
diff --git a/dev/requirements.txt b/dev/requirements.txt
index
d6530d8ce2821..88883a963950e 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -60,6 +60,9 @@ mypy-protobuf==3.3.0 googleapis-common-protos-stubs==2.2.0 grpc-stubs==1.24.11 +# Debug for Spark and Spark Connect +graphviz==0.20.3 + # TorchDistributor dependencies torch torchvision diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 8c17af559c250..66927066faa7a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -1063,6 +1063,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_pandas_udf_window", "pyspark.sql.tests.connect.test_resources", "pyspark.sql.tests.connect.shell.test_progress", + "pyspark.sql.tests.connect.test_df_debug", ], excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 21926ae295bfd..6cc68cd46b117 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -210,6 +210,7 @@ Package Supported version Note `grpcio` >=1.62.0 Required for Spark Connect `grpcio-status` >=1.62.0 Required for Spark Connect `googleapis-common-protos` >=1.56.4 Required for Spark Connect +`graphviz` >=0.20 Optional for Spark Connect ========================== ================= ========================== Spark SQL diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst index ec39b645b1403..d0196baa7a05b 100644 --- a/python/docs/source/reference/pyspark.sql/dataframe.rst +++ b/python/docs/source/reference/pyspark.sql/dataframe.rst @@ -55,6 +55,7 @@ DataFrame DataFrame.dropna DataFrame.dtypes DataFrame.exceptAll + DataFrame.executionInfo DataFrame.explain DataFrame.fillna DataFrame.filter diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 30db373872491..dd70e814b1ea8 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -149,6 +149,11 @@ "Cannot without ." ] }, + "CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF": { + "message": [ + "Calling property or member '' is not supported in PySpark Classic, please use Spark Connect instead." + ] + }, "COLLATION_INVALID_PROVIDER" : { "message" : [ "The value does not represent a correct collation provider. Supported providers are: []." 
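As a usage sketch of the API this patch introduces (assuming a Spark Connect session; the `sc://localhost` URL and the `query_plan` output name are illustrative, not part of the change):

```python
from pyspark.sql import SparkSession

# executionInfo is only available on Spark Connect DataFrames; Spark Classic raises
# CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF (see the classic/dataframe.py change below).
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.range(100).repartition(10).groupBy("id").count()
df.collect()                       # executionInfo is populated only after an action runs

info = df.executionInfo
print(info.metrics.toText())       # plain-text rendering of the per-node metrics tree
dot = info.metrics.toDot()         # graphviz.Digraph; requires the optional graphviz package
dot.render("query_plan", format="png")  # or pass filename/out_format to toDot() directly
```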
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index a03467aff1945..1bedd624603e1 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -94,6 +94,7 @@ from pyspark.sql.session import SparkSession from pyspark.sql.group import GroupedData from pyspark.sql.observation import Observation + from pyspark.sql.metrics import ExecutionInfo class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin): @@ -1835,6 +1836,13 @@ def toArrow(self) -> "pa.Table": def toPandas(self) -> "PandasDataFrameLike": return PandasConversionMixin.toPandas(self) + @property + def executionInfo(self) -> Optional["ExecutionInfo"]: + raise PySparkValueError( + error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF", + message_parameters={"member": "queryExecution"}, + ) + def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject": """ diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index f3bbab69f2719..e91324150cbd8 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -61,6 +61,7 @@ from pyspark.loose_version import LooseVersion from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation +from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics from pyspark.sql.connect.client.artifact import ArtifactManager from pyspark.sql.connect.client.logging import logger from pyspark.sql.connect.profiler import ConnectProfilerCollector @@ -447,56 +448,7 @@ def toChannel(self) -> grpc.Channel: return self._secure_channel(self.endpoint, creds) -class MetricValue: - def __init__(self, name: str, value: Union[int, float], type: str): - self._name = name - self._type = type - self._value = value - - def __repr__(self) -> str: - return f"<{self._name}={self._value} ({self._type})>" - - @property - def name(self) -> str: - return self._name - - @property - def value(self) -> Union[int, float]: - return self._value - - @property - def metric_type(self) -> str: - return self._type - - -class PlanMetrics: - def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]): - self._name = name - self._id = id - self._parent_id = parent - self._metrics = metrics - - def __repr__(self) -> str: - return f"Plan({self._name})={self._metrics}" - - @property - def name(self) -> str: - return self._name - - @property - def plan_id(self) -> int: - return self._id - - @property - def parent_plan_id(self) -> int: - return self._parent_id - - @property - def metrics(self) -> List[MetricValue]: - return self._metrics - - -class PlanObservedMetrics: +class PlanObservedMetrics(ObservedMetrics): def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: List[str]): self._name = name self._metrics = metrics @@ -513,6 +465,13 @@ def name(self) -> str: def metrics(self) -> List[pb2.Expression.Literal]: return self._metrics + @property + def pairs(self) -> dict[str, Any]: + result = {} + for x in range(len(self._metrics)): + result[self.keys[x]] = LiteralExpression._to_value(self.metrics[x]) + return result + @property def keys(self) -> List[str]: return self._keys @@ -888,7 +847,7 @@ def _resources(self) -> Dict[str, ResourceInformation]: logger.info("Fetching the resources") cmd = pb2.Command() cmd.get_resources_command.SetInParent() - (_, properties) = self.execute_command(cmd) + (_, properties, _) = self.execute_command(cmd) resources = 
properties["get_resources_command_result"] return resources @@ -915,18 +874,23 @@ def to_table_as_iterator( def to_table( self, plan: pb2.Plan, observations: Dict[str, Observation] - ) -> Tuple["pa.Table", Optional[StructType]]: + ) -> Tuple["pa.Table", Optional[StructType], ExecutionInfo]: """ Return given plan as a PyArrow Table. """ logger.info(f"Executing plan {self._proto_to_string(plan)}") req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) - table, schema, _, _, _ = self._execute_and_fetch(req, observations) + table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations) + + # Create a query execution object. + ei = ExecutionInfo(metrics, observed_metrics) assert table is not None - return table, schema + return table, schema, ei - def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd.DataFrame": + def to_pandas( + self, plan: pb2.Plan, observations: Dict[str, Observation] + ) -> Tuple["pd.DataFrame", "ExecutionInfo"]: """ Return given plan as a pandas DataFrame. """ @@ -941,6 +905,7 @@ def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd req, observations, self_destruct=self_destruct ) assert table is not None + ei = ExecutionInfo(metrics, observed_metrics) schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True) assert schema is not None and isinstance(schema, StructType) @@ -1007,7 +972,7 @@ def to_pandas(self, plan: pb2.Plan, observations: Dict[str, Observation]) -> "pd pdf.attrs["metrics"] = metrics if len(observed_metrics) > 0: pdf.attrs["observed_metrics"] = observed_metrics - return pdf + return pdf, ei def _proto_to_string(self, p: google.protobuf.message.Message) -> str: """ @@ -1051,7 +1016,7 @@ def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str: def execute_command( self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None - ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]: + ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], ExecutionInfo]: """ Execute given command. """ @@ -1060,11 +1025,15 @@ def execute_command( if self._user_id: req.user_context.user_id = self._user_id req.plan.command.CopyFrom(command) - data, _, _, _, properties = self._execute_and_fetch(req, observations or {}) + data, _, metrics, observed_metrics, properties = self._execute_and_fetch( + req, observations or {} + ) + # Create a query execution object. 
+ ei = ExecutionInfo(metrics, observed_metrics) if data is not None: - return (data.to_pandas(), properties) + return (data.to_pandas(), properties, ei) else: - return (None, properties) + return (None, properties, ei) def execute_command_as_iterator( self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None @@ -1849,6 +1818,6 @@ def _create_profile(self, profile: pb2.ResourceProfile) -> int: logger.info("Creating the ResourceProfile") cmd = pb2.Command() cmd.create_resource_profile_command.profile.CopyFrom(profile) - (_, properties) = self.execute_command(cmd) + (_, properties, _) = self.execute_command(cmd) profile_id = properties["create_resource_profile_command_result"] return profile_id diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 678e66ee2b7b0..1aa8fc00cfcc9 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -101,6 +101,7 @@ from pyspark.sql.connect.observation import Observation from pyspark.sql.connect.session import SparkSession from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame + from pyspark.sql.metrics import ExecutionInfo class DataFrame(ParentDataFrame): @@ -137,6 +138,7 @@ def __init__( # by __repr__ and _repr_html_ while eager evaluation opens. self._support_repr_html = False self._cached_schema: Optional[StructType] = None + self._execution_info: Optional["ExecutionInfo"] = None def __reduce__(self) -> Tuple: """ @@ -206,7 +208,10 @@ def _repr_html_(self) -> Optional[str]: @property def write(self) -> "DataFrameWriter": - return DataFrameWriter(self._plan, self._session) + def cb(qe: "ExecutionInfo") -> None: + self._execution_info = qe + + return DataFrameWriter(self._plan, self._session, cb) @functools.cache def isEmpty(self) -> bool: @@ -1839,7 +1844,9 @@ def collect(self) -> List[Row]: def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: query = self._plan.to_proto(self._session.client) - table, schema = self._session.client.to_table(query, self._plan.observations) + table, schema, self._execution_info = self._session.client.to_table( + query, self._plan.observations + ) assert table is not None return (table, schema) @@ -1850,7 +1857,9 @@ def toArrow(self) -> "pa.Table": def toPandas(self) -> "PandasDataFrameLike": query = self._plan.to_proto(self._session.client) - return self._session.client.to_pandas(query, self._plan.observations) + pdf, ei = self._session.client.to_pandas(query, self._plan.observations) + self._execution_info = ei + return pdf @property def schema(self) -> StructType: @@ -1976,25 +1985,29 @@ def createTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=False, replace=False ).command(session=self._session.client) - self._session.client.execute_command(command, self._plan.observations) + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def createOrReplaceTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=False, replace=True ).command(session=self._session.client) - self._session.client.execute_command(command, self._plan.observations) + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def createGlobalTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=False ).command(session=self._session.client) - 
self._session.client.execute_command(command, self._plan.observations) + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def createOrReplaceGlobalTempView(self, name: str) -> None: command = plan.CreateView( child=self._plan, name=name, is_global=True, replace=True ).command(session=self._session.client) - self._session.client.execute_command(command, self._plan.observations) + _, _, ei = self._session.client.execute_command(command, self._plan.observations) + self._execution_info = ei def cache(self) -> ParentDataFrame: return self.persist() @@ -2169,14 +2182,19 @@ def semanticHash(self) -> int: ) def writeTo(self, table: str) -> "DataFrameWriterV2": - return DataFrameWriterV2(self._plan, self._session, table) + def cb(ei: "ExecutionInfo") -> None: + self._execution_info = ei + + return DataFrameWriterV2(self._plan, self._session, table, cb) def offset(self, n: int) -> ParentDataFrame: return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session) def checkpoint(self, eager: bool = True) -> "DataFrame": cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager) - _, properties = self._session.client.execute_command(cmd.command(self._session.client)) + _, properties, self._execution_info = self._session.client.execute_command( + cmd.command(self._session.client) + ) assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) @@ -2184,7 +2202,9 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": def localCheckpoint(self, eager: bool = True) -> "DataFrame": cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager) - _, properties = self._session.client.execute_command(cmd.command(self._session.client)) + _, properties, self._execution_info = self._session.client.execute_command( + cmd.command(self._session.client) + ) assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) @@ -2205,6 +2225,10 @@ def rdd(self) -> "RDD[Row]": message_parameters={"feature": "rdd"}, ) + @property + def executionInfo(self) -> Optional["ExecutionInfo"]: + return self._execution_info + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index bf7dc4d369057..de62cf65b01ed 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -19,7 +19,7 @@ check_dependencies(__name__) from typing import Dict -from typing import Optional, Union, List, overload, Tuple, cast +from typing import Optional, Union, List, overload, Tuple, cast, Callable from typing import TYPE_CHECKING from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation, WriteOperationV2 @@ -37,6 +37,7 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType from pyspark.sql.connect.session import SparkSession + from pyspark.sql.metrics import ExecutionInfo __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -486,11 +487,18 @@ def _jreader(self) -> None: class DataFrameWriter(OptionUtils): - def __init__(self, plan: "LogicalPlan", session: "SparkSession"): + def __init__( + self, + plan: "LogicalPlan", + session: "SparkSession", + callback: 
Optional[Callable[["ExecutionInfo"], None]] = None, + ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session self._write: "WriteOperation" = WriteOperation(self._df) + self._callback = callback if callback is not None else lambda _: None + def mode(self, saveMode: Optional[str]) -> "DataFrameWriter": # At the JVM side, the default value of mode is already set to "error". # So, if the given saveMode is None, we will not call JVM-side's mode method. @@ -649,9 +657,10 @@ def save( if format is not None: self.format(format) self._write.path = path - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) save.__doc__ = PySparkDataFrameWriter.save.__doc__ @@ -660,9 +669,10 @@ def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None: self.mode("overwrite" if overwrite else "append") self._write.table_name = tableName self._write.table_save_method = "insert_into" - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__ @@ -681,9 +691,10 @@ def saveAsTable( self.format(format) self._write.table_name = name self._write.table_save_method = "save_as_table" - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__ @@ -845,11 +856,18 @@ def jdbc( class DataFrameWriterV2(OptionUtils): - def __init__(self, plan: "LogicalPlan", session: "SparkSession", table: str): + def __init__( + self, + plan: "LogicalPlan", + session: "SparkSession", + table: str, + callback: Optional[Callable[["ExecutionInfo"], None]] = None, + ): self._df: "LogicalPlan" = plan self._spark: "SparkSession" = session self._table_name: str = table self._write: "WriteOperationV2" = WriteOperationV2(self._df, self._table_name) + self._callback = callback if callback is not None else lambda _: None def using(self, provider: str) -> "DataFrameWriterV2": self._write.provider = provider @@ -884,50 +902,56 @@ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFram def create(self) -> None: self._write.mode = "create" - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) create.__doc__ = PySparkDataFrameWriterV2.create.__doc__ def replace(self) -> None: self._write.mode = "replace" - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__ def createOrReplace(self) -> None: self._write.mode = "create_or_replace" - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__ def append(self) -> None: self._write.mode = "append" - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) append.__doc__ = 
PySparkDataFrameWriterV2.append.__doc__ def overwrite(self, condition: "ColumnOrName") -> None: self._write.mode = "overwrite" self._write.overwrite_condition = F._to_col(condition) - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__ def overwritePartitions(self) -> None: self._write.mode = "overwrite_partitions" - self._spark.client.execute_command( + _, _, ei = self._spark.client.execute_command( self._write.command(self._spark.client), self._write.observations ) + self._callback(ei) overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__ diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index f359ab829483a..8e277b3fc63aa 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -720,9 +720,12 @@ def sql( _views.append(SubqueryAlias(df._plan, name)) cmd = SQL(sqlQuery, _args, _named_args, _views) - data, properties = self.client.execute_command(cmd.command(self._client)) + data, properties, ei = self.client.execute_command(cmd.command(self._client)) if "sql_command_result" in properties: - return DataFrame(CachedRelation(properties["sql_command_result"]), self) + df = DataFrame(CachedRelation(properties["sql_command_result"]), self) + # A command result contains the execution. + df._execution_info = ei + return df else: return DataFrame(cmd, self) diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 98ecdc4966c75..13458d650fa9f 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -181,7 +181,7 @@ def _execute_streaming_query_cmd( cmd.query_id.run_id = self._run_id exec_cmd = pb2.Command() exec_cmd.streaming_query_command.CopyFrom(cmd) - (_, properties) = self._session.client.execute_command(exec_cmd) + (_, properties, _) = self._session.client.execute_command(exec_cmd) return cast(pb2.StreamingQueryCommandResult, properties["streaming_query_command_result"]) @@ -260,7 +260,7 @@ def _execute_streaming_query_manager_cmd( ) -> pb2.StreamingQueryManagerCommandResult: exec_cmd = pb2.Command() exec_cmd.streaming_query_manager_command.CopyFrom(cmd) - (_, properties) = self._session.client.execute_command(exec_cmd) + (_, properties, _) = self._session.client.execute_command(exec_cmd) return cast( pb2.StreamingQueryManagerCommandResult, properties["streaming_query_manager_command_result"], diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index b5bb7f2a09128..9b11bf328b853 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -601,7 +601,7 @@ def _start_internal( self._write_proto.table_name = tableName cmd = self._write_stream.command(self._session.client) - (_, properties) = self._session.client.execute_command(cmd) + (_, properties, _) = self._session.client.execute_command(cmd) start_result = cast( pb2.WriteStreamOperationStartResult, properties["write_stream_operation_start_result"] diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 62c46cfec93cd..625678588bf9e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -64,6 +64,7 @@ ArrowMapIterFunction, DataFrameLike as 
PandasDataFrameLike, ) + from pyspark.sql.metrics import ExecutionInfo __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -6281,6 +6282,31 @@ def toPandas(self) -> "PandasDataFrameLike": """ ... + @property + def executionInfo(self) -> Optional["ExecutionInfo"]: + """ + Returns a QueryExecution object after the query was executed. + + The queryExecution method allows to introspect information about the actual + query execution after the successful execution. Accessing this member before + the query execution will return None. + + If the same DataFrame is executed multiple times, the execution info will be + overwritten by the latest operation. + + .. versionadded:: 4.0.0 + + Returns + ------- + An instance of QueryExecution or None when the value is not set yet. + + Notes + ----- + This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws + an exception. + """ + ... + class DataFrameNaFunctions: """Functionality for working with missing data in :class:`DataFrame`. diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py new file mode 100644 index 0000000000000..6664582952014 --- /dev/null +++ b/python/pyspark/sql/metrics.py @@ -0,0 +1,287 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +import dataclasses +from typing import Optional, List, Tuple, Dict, Any, Union, TYPE_CHECKING, Sequence + +from pyspark.errors import PySparkValueError + +if TYPE_CHECKING: + from pyspark.testing.connectutils import have_graphviz + + if have_graphviz: + import graphviz # type: ignore + + +class ObservedMetrics(abc.ABC): + @property + @abc.abstractmethod + def name(self) -> str: + ... + + @property + @abc.abstractmethod + def pairs(self) -> Dict[str, Any]: + ... + + @property + @abc.abstractmethod + def keys(self) -> List[str]: + ... + + +class MetricValue: + """The metric values is the Python representation of a plan metric value from the JVM. 
+ However, it does not have any reference to the original value.""" + + def __init__(self, name: str, value: Union[int, float], type: str): + self._name = name + self._type = type + self._value = value + + def __repr__(self) -> str: + return f"<{self._name}={self._value} ({self._type})>" + + @property + def name(self) -> str: + return self._name + + @property + def value(self) -> Union[int, float]: + return self._value + + @property + def metric_type(self) -> str: + return self._type + + +class PlanMetrics: + """Represents a particular plan node and the associated metrics of this node.""" + + def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]): + self._name = name + self._id = id + self._parent_id = parent + self._metrics = metrics + + def __repr__(self) -> str: + return f"Plan({self._name}: {self._id}->{self._parent_id})={self._metrics}" + + @property + def name(self) -> str: + return self._name + + @property + def plan_id(self) -> int: + return self._id + + @property + def parent_plan_id(self) -> int: + return self._parent_id + + @property + def metrics(self) -> List[MetricValue]: + return self._metrics + + +class CollectedMetrics: + @dataclasses.dataclass + class Node: + id: int + name: str = dataclasses.field(default="") + metrics: List[MetricValue] = dataclasses.field(default_factory=list) + children: List[int] = dataclasses.field(default_factory=list) + + def text(self, current: "Node", graph: Dict[int, "Node"], prefix: str = "") -> str: + """ + Converts the current node and its children into a textual representation. This is used + to provide a usable output for the command line or other text-based interfaces. However, + it is recommended to use the Graphviz representation for a more visual representation. + + Parameters + ---------- + current: Node + Current node in the graph. + graph: dict + A dictionary representing the full graph mapping from node ID (int) to the node itself. + The node is an instance of :class:`CollectedMetrics:Node`. + prefix: str + String prefix used for generating the output buffer. + + Returns + ------- + The full string representation of the current node as root. + """ + base_metrics = set(["numPartitions", "peakMemory", "numOutputRows", "spillSize"]) + + # Format the metrics of this node: + metric_buffer = [] + for m in current.metrics: + if m.name in base_metrics: + metric_buffer.append(f"{m.name}: {m.value} ({m.metric_type})") + + buffer = f"{prefix}+- {current.name}({','.join(metric_buffer)})\n" + for i, child in enumerate(current.children): + c = graph[child] + new_prefix = prefix + " " if i == len(c.children) - 1 else prefix + if current.id != c.id: + buffer += self.text(c, graph, new_prefix) + return buffer + + def __init__(self, metrics: List[PlanMetrics]): + # Sort the input list + self._metrics = sorted(metrics, key=lambda x: x._parent_id, reverse=False) + + def extract_graph(self) -> Tuple[int, Dict[int, "CollectedMetrics.Node"]]: + """ + Builds the graph of the query plan. The graph is represented as a dictionary where the key + is the node ID and the value is the node itself. The root node is the node that has no + parent. + + Returns + ------- + The root node ID and the graph of all nodes. + """ + all_nodes: Dict[int, CollectedMetrics.Node] = {} + + for m in self._metrics: + # Add yourself to the list if you have to. 
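+            # Each PlanMetrics entry carries (plan_id, parent_plan_id); nodes are created lazily,
+            # so a parent that has not been seen yet gets a placeholder Node that is filled in
+            # once its own metrics arrive. The root is the only id never listed as a child.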
+ if m.plan_id not in all_nodes: + all_nodes[m.plan_id] = CollectedMetrics.Node(m.plan_id, m.name, m.metrics) + else: + all_nodes[m.plan_id].name = m.name + all_nodes[m.plan_id].metrics = m.metrics + + # Now check for the parent of this node if it's in + if m.parent_plan_id not in all_nodes: + all_nodes[m.parent_plan_id] = CollectedMetrics.Node(m.parent_plan_id) + + all_nodes[m.parent_plan_id].children.append(m.plan_id) + + # Next step is to find all the root nodes. Root nodes are never used in children. + # So we start with all node ids as candidates. + candidates = set(all_nodes.keys()) + for k, v in all_nodes.items(): + for c in v.children: + if c in candidates and c != k: + candidates.remove(c) + + assert len(candidates) == 1, f"Expected 1 root node, found {len(candidates)}" + return candidates.pop(), all_nodes + + def toText(self) -> str: + """ + Converts the execution graph from a graph into a textual representation + that can be read at the command line for example. + + Returns + ------- + A string representation of the collected metrics. + """ + root, graph = self.extract_graph() + return self.text(graph[root], graph) + + def toDot(self, filename: Optional[str] = None, out_format: str = "png") -> "graphviz.Digraph": + """ + Converts the collected metrics into a dot representation. Since the graphviz Digraph + implementation provides the ability to render the result graph directory in a + notebook, we return the graph object directly. + + If the graphviz package is not available, a PACKAGE_NOT_INSTALLED error is raised. + + Parameters + ---------- + filename : str, optional + The filename to save the graph to given an output format. The path can be + relative or absolute. + + out_format : str + The output format of the graph. The default is 'png'. + + Returns + ------- + An instance of the graphviz.Digraph object. + """ + try: + import graphviz + + dot = graphviz.Digraph( + comment="Query Plan", + node_attr={ + "shape": "box", + "font-size": "10pt", + }, + ) + + root, graph = self.extract_graph() + for k, v in graph.items(): + # Build table rows for the metrics + rows = "\n".join( + [ + ( + f'{x.name}' + f'{x.value} ({x.metric_type})' + ) + for x in v.metrics + ] + ) + + dot.node( + str(k), + """< + + + + + {} +
+ {} +
Metrics
>""".format( + v.name, rows + ), + ) + for c in v.children: + dot.edge(str(k), str(c)) + + if filename: + dot.render(filename, format=out_format, cleanup=True) + return dot + + except ImportError: + raise PySparkValueError( + error_class="PACKAGE_NOT_INSTALLED", + message_parameters={"package_name": "graphviz", "minimum_version": "0.20"}, + ) + + +class ExecutionInfo: + """The query execution class allows users to inspect the query execution of this particular + data frame. This value is only set in the data frame if it was executed.""" + + def __init__( + self, metrics: Optional[list[PlanMetrics]], obs: Optional[Sequence[ObservedMetrics]] + ): + self._metrics = CollectedMetrics(metrics) if metrics else None + self._observations = obs if obs else [] + + @property + def metrics(self) -> Optional[CollectedMetrics]: + return self._metrics + + @property + def flows(self) -> List[Tuple[str, Dict[str, Any]]]: + return [(f.name, f.pairs) for f in self._observations] diff --git a/python/pyspark/sql/tests/connect/test_df_debug.py b/python/pyspark/sql/tests/connect/test_df_debug.py new file mode 100644 index 0000000000000..8a4ec68fda844 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_df_debug.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.testing.connectutils import ( + should_test_connect, + have_graphviz, + graphviz_requirement_message, +) +from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase + +if should_test_connect: + from pyspark.sql.connect.dataframe import DataFrame + + +class SparkConnectDataFrameDebug(SparkConnectSQLTestCase): + def test_df_debug_basics(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + x = df.collect() # noqa: F841 + ei = df.executionInfo + + root, graph = ei.metrics.extract_graph() + self.assertIn(root, graph, "The root must be rooted in the graph") + + def test_df_quey_execution_empty_before_execution(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + ei = df.executionInfo + self.assertIsNone(ei, "The query execution must be None before the action is executed") + + def test_df_query_execution_with_writes(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df.write.save("/tmp/test_df_query_execution_with_writes", format="json", mode="overwrite") + ei = df.executionInfo + self.assertIsNotNone( + ei, "The query execution must be None after the write action is executed" + ) + + def test_query_execution_text_format(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df.collect() + self.assertIn("HashAggregate", df.executionInfo.metrics.toText()) + + # Different execution mode. 
+ df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + df.toPandas() + self.assertIn("HashAggregate", df.executionInfo.metrics.toText()) + + @unittest.skipIf(not have_graphviz, graphviz_requirement_message) + def test_df_query_execution_metrics_to_dot(self): + df: DataFrame = self.connect.range(100).repartition(10).groupBy("id").count() + x = df.collect() # noqa: F841 + ei = df.executionInfo + + dot = ei.metrics.toDot() + source = dot.source + self.assertIsNotNone(dot, "The dot representation must not be None") + self.assertGreater(len(source), 0, "The dot representation must not be empty") + self.assertIn("digraph", source, "The dot representation must contain the digraph keyword") + self.assertIn("Metrics", source, "The dot representation must contain the Metrics keyword") + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_df_debug import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index d7b31bbc5215b..c7cf43a334541 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -37,6 +37,7 @@ AnalysisException, IllegalArgumentException, PySparkTypeError, + PySparkValueError, ) from pyspark.testing.sqlutils import ( ReusedSQLTestCase, @@ -849,7 +850,15 @@ def test_checkpoint_dataframe(self): class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase): - pass + def test_query_execution_unsupported_in_classic(self): + with self.assertRaises(PySparkValueError) as pe: + self.spark.range(1).executionInfo + + self.check_error( + exception=pe.exception, + error_class="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF", + message_parameters={"member": "queryExecution"}, + ) if __name__ == "__main__": diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 191505741eb40..b3004693724bf 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -45,6 +45,13 @@ googleapis_common_protos_requirement_message = str(e) have_googleapis_common_protos = googleapis_common_protos_requirement_message is None +graphviz_requirement_message = None +try: + import graphviz +except ImportError as e: + graphviz_requirement_message = str(e) +have_graphviz: bool = graphviz_requirement_message is None + from pyspark import Row, SparkConf from pyspark.util import is_remote_only from pyspark.testing.utils import PySparkErrorTestUtils
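As a closing illustration (a sketch, not part of the patch), the observed-metrics side of the new `ExecutionInfo` can be consumed through the `flows` property added in `pyspark/sql/metrics.py`. This assumes a Spark Connect session bound to the name `spark`; the observation name, column aliases, and printed values are examples only:

```python
from pyspark.sql import functions as F

# Register named observations, run an action, then read them back via ExecutionInfo.flows,
# which yields (observation_name, {metric_name: value}) pairs.
df = spark.range(100).observe(
    "stats", F.count(F.lit(1)).alias("rows"), F.max("id").alias("max_id")
)
df.collect()

info = df.executionInfo
for name, pairs in info.flows:      # e.g. ("stats", {"rows": 100, "max_id": 99})
    print(name, pairs)

# Per-plan metrics remain available alongside the observations.
print(info.metrics.toText())
```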