
Commit

Revert "Revert "[SPARK-45302][PYTHON] Remove PID communication betwee…
Browse files Browse the repository at this point in the history
…n Python workers when no demon is used""

This reverts commit e8f529b.
HyukjinKwon committed Apr 24, 2024
1 parent e8f529b commit b3c11ef
Showing 10 changed files with 16 additions and 27 deletions.
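
The change being restored here is simple: when no daemon is used, the JVM launches the Python worker process itself, so it can take the worker's PID directly from the java.lang.Process handle (a long on JDK 9+) instead of having the worker write its own PID back over the socket. A minimal, self-contained Scala sketch of that parent-side lookup (not Spark code, and it assumes python3 is on PATH):

// Sketch only: the parent learns the child's PID from the process handle,
// so the child never has to report it over the wire.
object PidFromHandle {
  def main(args: Array[String]): Unit = {
    val child = new ProcessBuilder("python3", "-c", "print('worker up')").start()
    val pid: Long = child.toHandle.pid() // ProcessHandle.pid() returns long, hence Option[Long] below
    println(s"worker pid=$pid exit=${child.waitFor()}")
  }
}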
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -142,7 +142,7 @@ class SparkEnv (
workerModule: String,
daemonModule: String,
envVars: Map[String, String],
- useDaemon: Boolean): (PythonWorker, Option[Int]) = {
+ useDaemon: Boolean): (PythonWorker, Option[Long]) = {
synchronized {
val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
val workerFactory = pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(
@@ -161,7 +161,7 @@ class SparkEnv (
pythonExec: String,
workerModule: String,
envVars: Map[String, String],
- useDaemon: Boolean): (PythonWorker, Option[Int]) = {
+ useDaemon: Boolean): (PythonWorker, Option[Long]) = {
createPythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, useDaemon)
}
@@ -88,7 +88,7 @@ private object BasePythonRunner {

private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")

- private def faultHandlerLogPath(pid: Int): Path = {
+ private def faultHandlerLogPath(pid: Long): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}
}
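
For context on why the PID is threaded through BasePythonRunner at all: the faulthandler log file is named after the worker's PID, presumably so the JVM can recover a crashed worker's Python traceback by PID. A rough Scala sketch of that pattern (directory name and helpers are illustrative, not Spark's exact internals):

import java.io.File
import java.nio.file.{Files, Path}

object FaultHandlerLogs {
  // Illustrative only: one faulthandler log per worker, keyed by its (Long) PID.
  private val faultHandlerLogDir: File = Files.createTempDirectory("faulthandler").toFile

  def faultHandlerLogPath(pid: Long): Path =
    new File(faultHandlerLogDir, pid.toString).toPath

  // If the worker died and Python's faulthandler wrote a traceback, read it back.
  def crashReport(pid: Long): Option[String] = {
    val p = faultHandlerLogPath(pid)
    if (Files.exists(p)) Some(new String(Files.readAllBytes(p))) else None
  }
}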
@@ -204,7 +204,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](

envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))

- val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker(
+ val (worker: PythonWorker, pid: Option[Long]) = env.createPythonWorker(
pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
// Whether is the worker released into idle pool or closed. When any codes try to release or
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
@@ -257,7 +257,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Int],
+ pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT]

@@ -465,7 +465,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Int],
+ pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext)
extends Iterator[OUT] {
@@ -842,7 +842,7 @@ private[spark] class PythonRunner(
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Int],
+ pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(
@@ -92,7 +92,7 @@ private[spark] class PythonWorkerFactory(
envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))

- def create(): (PythonWorker, Option[Int]) = {
+ def create(): (PythonWorker, Option[Long]) = {
if (useDaemon) {
self.synchronized {
// Pull from idle workers until we one that is alive, otherwise create a new one.
@@ -122,9 +122,9 @@
* processes itself to avoid the high cost of forking from Java. This currently only works
* on UNIX-based systems.
*/
- private def createThroughDaemon(): (PythonWorker, Option[Int]) = {
+ private def createThroughDaemon(): (PythonWorker, Option[Long]) = {

- def createWorker(): (PythonWorker, Option[Int]) = {
+ def createWorker(): (PythonWorker, Option[Long]) = {
val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
// These calls are blocking.
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
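
Note that this daemon path keeps readInt: when the daemon forks a worker, the forked child still reports its os.getpid() over the socket as a 4-byte int (daemon.py's write_int import survives below), and the JVM widens it to fit the new Option[Long] signature. A hedged Scala sketch of that read side (object and method names are illustrative):

import java.io.DataInputStream
import java.nio.channels.{Channels, SocketChannel}

object DaemonHandshake {
  // Daemon path sketch: the forked worker's PID still arrives as a 4-byte int over
  // the socket and is widened to Long to fit (PythonWorker, Option[Long]).
  def readDaemonWorkerPid(socketChannel: SocketChannel): Long =
    new DataInputStream(Channels.newInputStream(socketChannel)).readInt().toLong
}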
@@ -165,7 +165,7 @@ private[spark] class PythonWorkerFactory(
/**
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
*/
- private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = {
+ private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Long]) = {
var serverSocketChannel: ServerSocketChannel = null
try {
serverSocketChannel = ServerSocketChannel.open()
@@ -209,8 +209,7 @@ private[spark] class PythonWorkerFactory(
"Timed out while waiting for the Python worker to connect back")
}
authHelper.authClient(socketChannel.socket())
- // TODO: When we drop JDK 8, we can just use workerProcess.pid()
- val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
+ val pid = workerProcess.toHandle.pid()
if (pid < 0) {
throw new IllegalStateException("Python failed to launch worker with code " + pid)
}
4 changes: 2 additions & 2 deletions python/pyspark/daemon.py
@@ -28,7 +28,7 @@
from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT

- from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
+ from pyspark.serializers import read_long, write_int, write_with_length, UTF8Deserializer

if len(sys.argv) > 1:
import importlib
@@ -139,7 +139,7 @@ def handle_sigterm(*args):

if 0 in ready_fds:
try:
- worker_pid = read_int(stdin_bin)
+ worker_pid = read_long(stdin_bin)
except EOFError:
# Spark told us to exit by closing stdin
shutdown(0)
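
read_long here is the daemon's side of the cleanup protocol: the JVM tells the daemon which worker to release by writing that worker's PID to the daemon's stdin, and since PIDs are now tracked as Long they must be framed as 8 bytes. The matching JVM-side write would look roughly like the Scala sketch below (object and method names are illustrative, not Spark's exact code; DataOutputStream.writeLong emits big-endian bytes, matching read_long's 8-byte big-endian unpack in pyspark.serializers):

import java.io.DataOutputStream

object DaemonControl {
  // Sketch of the assumed JVM-side counterpart: tell the daemon which worker to
  // clean up. writeLong emits 8 big-endian bytes, which is what daemon.py's
  // read_long(stdin_bin) now expects.
  def tellDaemonToReleaseWorker(daemon: Process, workerPid: Long): Unit = {
    val out = new DataOutputStream(daemon.getOutputStream)
    out.writeLong(workerPid)
    out.flush()
  }
}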
@@ -96,6 +96,4 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
# There could be a long time between each micro batch.
sock.settimeout(None)
- write_int(os.getpid(), sock_file)
- sock_file.flush()
main(sock_file, sock_file)
@@ -110,6 +110,4 @@ def process(listener_event_str, listener_event_type):  # type: ignore[no-untyped
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
# There could be a long time between each listener event.
sock.settimeout(None)
- write_int(os.getpid(), sock_file)
- sock_file.flush()
main(sock_file, sock_file)
3 changes: 0 additions & 3 deletions python/pyspark/sql/worker/analyze_udtf.py
@@ -264,7 +264,4 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
- # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
- write_int(os.getpid(), sock_file)
- sock_file.flush()
main(sock_file, sock_file)
3 changes: 0 additions & 3 deletions python/pyspark/worker.py
@@ -1868,7 +1868,4 @@ def process():
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
- # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
- write_int(os.getpid(), sock_file)
- sock_file.flush()
main(sock_file, sock_file)
@@ -49,7 +49,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Int],
+ pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT] = {

@@ -80,7 +80,7 @@ abstract class BasePythonUDFRunner(
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Int],
+ pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(
