From d98ed8ef00979e4e7262c5bf7a0158288a158a69 Mon Sep 17 00:00:00 2001 From: Danijar Hafner Date: Thu, 12 Sep 2024 05:06:09 +0000 Subject: [PATCH] Add proc_alive() util that handles zombies --- portal/__init__.py | 3 ++- portal/process.py | 22 ++++++++-------------- portal/thread.py | 6 ++++-- portal/utils.py | 17 ++++++++++++++++- tests/test_errfile.py | 17 ++++------------- tests/test_process.py | 25 +++++++++---------------- 6 files changed, 43 insertions(+), 47 deletions(-) diff --git a/portal/__init__.py b/portal/__init__.py index 2d9c69d..8ba781a 100644 --- a/portal/__init__.py +++ b/portal/__init__.py @@ -1,4 +1,4 @@ -__version__ = '3.2.1' +__version__ = '3.2.0' import multiprocessing as mp try: @@ -29,4 +29,5 @@ from .sharray import SharedArray from .utils import free_port +from .utils import proc_alive from .utils import run diff --git a/portal/process.py b/portal/process.py index 8e01887..2f2f50b 100644 --- a/portal/process.py +++ b/portal/process.py @@ -37,8 +37,10 @@ def __init__(self, fn, *args, name=None, start=False): name = name or getattr(fn, '__name__', 'process') fn = cloudpickle.dumps(fn) options = contextlib.context.options() + self.ready = contextlib.context.mp.Barrier(2) self.process = contextlib.context.mp.Process( - target=self._wrapper, name=name, args=(options, name, fn, args)) + target=self._wrapper, name=name, + args=(options, self.ready, name, fn, args)) self.started = False self.killed = False self.thepid = None @@ -60,17 +62,10 @@ def running(self): return False if not self.process.is_alive(): return False - try: - os.kill(self.pid, 0) - except OSError as err: - if err.errno == errno.ESRCH: - return False - return True + return utils.proc_alive(self.pid) @property def exitcode(self): - if not self.started or self.running: - return None exitcode = self.process.exitcode if self.killed and exitcode is None: return -9 @@ -80,20 +75,18 @@ def start(self): assert not self.started self.started = True self.process.start() + self.ready.wait() self.thepid = self.process.pid assert self.thepid is not None return self def join(self, timeout=None): assert self.started - if self.running: - self.process.join(timeout) + self.process.join(timeout) return self def kill(self, timeout=1): assert self.started - if not self.running: - return self try: children = list(psutil.Process(self.pid).children(recursive=True)) except psutil.NoSuchProcess: @@ -116,10 +109,11 @@ def __repr__(self): return 'Process(' + ', '.join(attrs) + ')' @staticmethod - def _wrapper(options, name, fn, args): + def _wrapper(options, ready, name, fn, args): exitcode = 0 try: contextlib.setup(**options) + ready.wait() fn = cloudpickle.loads(fn) exitcode = fn(*args) exitcode = exitcode if isinstance(exitcode, int) else 0 diff --git a/portal/thread.py b/portal/thread.py index 32652dd..b5e9028 100644 --- a/portal/thread.py +++ b/portal/thread.py @@ -29,6 +29,7 @@ def __init__(self, fn, *args, name=None, start=False): target=self._wrapper, args=args, name=name, daemon=True) self.thread.children = [] self.started = False + self.ready = threading.Barrier(2) contextlib.context.add_worker(self) start and self.start() @@ -54,11 +55,11 @@ def start(self): assert not self.started self.started = True self.thread.start() + self.ready.wait() return self def join(self, timeout=None): - if self.running: - self.thread.join(timeout) + self.thread.join(timeout) return self def kill(self, timeout=1.0): @@ -75,6 +76,7 @@ def __repr__(self): def _wrapper(self, *args): try: + self.ready.wait() exitcode = self.fn(*args) exitcode = exitcode if isinstance(exitcode, int) else 0 self.excode = exitcode diff --git a/portal/utils.py b/portal/utils.py index 6eb7958..5a1ebb7 100644 --- a/portal/utils.py +++ b/portal/utils.py @@ -1,4 +1,6 @@ import ctypes +import errno +import os import socket import sys import threading @@ -73,9 +75,22 @@ def eachproc(fn, procs): # Should never happen but print warning if any survived. eachproc(lambda p: ( print('Killed subprocess is still alive.') - if p.status() != psutil.STATUS_ZOMBIE else None), procs) + if proc_alive(p.pid) else None), procs) +def proc_alive(pid): + try: + if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE: + return False + except psutil.NoSuchProcess: + return False + try: + os.kill(pid, 0) + except OSError as e: + if e.errno == errno.ESRCH: + return False + assert True + def free_port(): # Return a port that is currently free. This function is not thread or diff --git a/tests/test_errfile.py b/tests/test_errfile.py index 8b01fd3..a67a8a3 100644 --- a/tests/test_errfile.py +++ b/tests/test_errfile.py @@ -107,16 +107,7 @@ def hang_process(ready, queue): assert not worker.running pids = [queue.get() for _ in range(4)] time.sleep(2.0) # On some systems this can take a while. - assert not alive(pids[0]) - assert not alive(pids[1]) - assert not alive(pids[2]) - assert not alive(pids[3]) - - -def alive(pid): - try: - os.kill(pid, 0) - except OSError: - assert True - else: - assert False + assert not portal.proc_alive(pids[0]) + assert not portal.proc_alive(pids[1]) + assert not portal.proc_alive(pids[2]) + assert not portal.proc_alive(pids[3]) diff --git a/tests/test_process.py b/tests/test_process.py index b86b74f..acd8d94 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -48,27 +48,29 @@ def fn(): @pytest.mark.parametrize('repeat', range(5)) def test_kill_with_subproc(self, repeat): - ready = portal.context.mp.Semaphore(0) + ready = portal.context.mp.Barrier(3) queue = portal.context.mp.Queue() + def outer(ready, queue): queue.put(os.getpid()) portal.Process(inner, ready, queue, start=True) - ready.release() + ready.wait() while True: time.sleep(0.1) + def inner(ready, queue): queue.put(os.getpid()) - ready.release() + ready.wait() while True: time.sleep(0.1) + worker = portal.Process(outer, ready, queue, start=True) - ready.acquire() - ready.acquire() + ready.wait() worker.kill() assert not worker.running assert worker.exitcode < 0 - assert not alive(queue.get()) - assert not alive(queue.get()) + assert not portal.proc_alive(queue.get()) + assert not portal.proc_alive(queue.get()) @pytest.mark.parametrize('repeat', range(5)) def test_kill_with_subthread(self, repeat): @@ -103,12 +105,3 @@ def inner(ready): ready.wait() assert ready.is_set() portal.context.initfns.clear() - - -def alive(pid): - try: - os.kill(pid, 0) - except OSError: - assert True - else: - assert False