diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 34c7d955fedd8..50d0358004d40 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -142,7 +142,7 @@ class SparkEnv (
       workerModule: String,
       daemonModule: String,
       envVars: Map[String, String],
-      useDaemon: Boolean): (PythonWorker, Option[Int]) = {
+      useDaemon: Boolean): (PythonWorker, Option[Long]) = {
     synchronized {
       val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
       val workerFactory = pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(
@@ -161,7 +161,7 @@ class SparkEnv (
       pythonExec: String,
       workerModule: String,
       envVars: Map[String, String],
-      useDaemon: Boolean): (PythonWorker, Option[Int]) = {
+      useDaemon: Boolean): (PythonWorker, Option[Long]) = {
     createPythonWorker(
       pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, useDaemon)
   }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 7ff782db210d3..17cb0c5a55ddf 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -88,7 +88,7 @@ private object BasePythonRunner {
 
   private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
 
-  private def faultHandlerLogPath(pid: Int): Path = {
+  private def faultHandlerLogPath(pid: Long): Path = {
     new File(faultHandlerLogDir, pid.toString).toPath
   }
 }
@@ -204,7 +204,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
 
-    val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker(
+    val (worker: PythonWorker, pid: Option[Long]) = env.createPythonWorker(
       pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
     // Whether is the worker released into idle pool or closed. When any codes try to release or
     // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
@@ -257,7 +257,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT]
 
@@ -465,7 +465,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext)
     extends Iterator[OUT] {
@@ -842,7 +842,7 @@ private[spark] class PythonRunner(
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
     new ReaderIterator(
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 875cf6bc27709..eb740b72987c8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -92,7 +92,7 @@ private[spark] class PythonWorkerFactory(
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))
 
-  def create(): (PythonWorker, Option[Int]) = {
+  def create(): (PythonWorker, Option[Long]) = {
     if (useDaemon) {
       self.synchronized {
         // Pull from idle workers until we one that is alive, otherwise create a new one.
@@ -122,9 +122,9 @@ private[spark] class PythonWorkerFactory(
    * processes itself to avoid the high cost of forking from Java. This currently only works
    * on UNIX-based systems.
    */
-  private def createThroughDaemon(): (PythonWorker, Option[Int]) = {
+  private def createThroughDaemon(): (PythonWorker, Option[Long]) = {
 
-    def createWorker(): (PythonWorker, Option[Int]) = {
+    def createWorker(): (PythonWorker, Option[Long]) = {
       val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
       // These calls are blocking.
       val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
@@ -165,7 +165,7 @@ private[spark] class PythonWorkerFactory(
   /**
   * Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
   */
-  private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = {
+  private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Long]) = {
     var serverSocketChannel: ServerSocketChannel = null
     try {
       serverSocketChannel = ServerSocketChannel.open()
@@ -209,8 +209,7 @@ private[spark] class PythonWorkerFactory(
           "Timed out while waiting for the Python worker to connect back")
       }
       authHelper.authClient(socketChannel.socket())
-      // TODO: When we drop JDK 8, we can just use workerProcess.pid()
-      val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
+      val pid = workerProcess.toHandle.pid()
       if (pid < 0) {
         throw new IllegalStateException("Python failed to launch worker with code " + pid)
       }
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index bbbc495d053ed..b0e06d13beda7 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -28,7 +28,7 @@
 from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
 
-from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
+from pyspark.serializers import read_long, write_int, write_with_length, UTF8Deserializer
 
 if len(sys.argv) > 1:
     import importlib
@@ -139,7 +139,7 @@ def handle_sigterm(*args):
 
             if 0 in ready_fds:
                 try:
-                    worker_pid = read_int(stdin_bin)
+                    worker_pid = read_long(stdin_bin)
                 except EOFError:
                     # Spark told us to exit by closing stdin
                     shutdown(0)
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index 0c92de6372b6f..80cc691269163 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -96,6 +96,4 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
     (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
     # There could be a long time between each micro batch.
     sock.settimeout(None)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index a7a5066ca0d77..3709e50ba0261 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -110,6 +110,4 @@ def process(listener_event_str, listener_event_type):  # type: ignore[no-untyped
     (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
     # There could be a long time between each listener event.
     sock.settimeout(None)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 7dafb87c42211..d0a24363c0c1e 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -264,7 +264,4 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
     (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
-    # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e9c259e68a27a..41f6c35bc4452 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1868,7 +1868,4 @@ def process():
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
     (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
-    # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index e7d4aa9f04607..90922d89ad10b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -49,7 +49,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT] = {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 87ff5a0ec4333..bbe9fbfc748db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -80,7 +80,7 @@ abstract class BasePythonUDFRunner(
       startTime: Long,
       env: SparkEnv,
       worker: PythonWorker,
-      pid: Option[Int],
+      pid: Option[Long],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
     new ReaderIterator(