Skip to content

Commit

Permalink
Add proc_alive() util that handles zombies
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 12, 2024
1 parent 6770465 commit 4cdf974
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 56 deletions.
3 changes: 2 additions & 1 deletion portal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '3.2.1'
__version__ = '3.2.0'

import multiprocessing as mp
try:
Expand Down Expand Up @@ -29,4 +29,5 @@
from .sharray import SharedArray

from .utils import free_port
from .utils import proc_alive
from .utils import run
18 changes: 10 additions & 8 deletions portal/contextlib.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import multiprocessing as mp
import os
import pathlib
import signal
import sys
import threading
import traceback

import cloudpickle
import psutil

from . import process
from . import utils


Expand Down Expand Up @@ -122,12 +121,15 @@ def error(self, e, name=None):
print(f'Wrote errorfile: {self.errfile}', file=sys.stderr)

def shutdown(self, exitcode):
if exitcode == 0:
for child in self.children(threading.main_thread()):
child.kill()
os._exit(0)
else:
os._exit(exitcode)
children = list(psutil.Process(os.getpid()).children(recursive=True))
utils.kill_proc(children, timeout=1)
# TODO
# if exitcode == 0:
# for child in self.children(threading.main_thread()):
# child.kill()
# os._exit(0)
# else:
os._exit(exitcode)

def close(self):
self.done.set()
Expand Down
18 changes: 4 additions & 14 deletions portal/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,10 @@ def running(self):
return False
if not self.process.is_alive():
return False
try:
os.kill(self.pid, 0)
except OSError as err:
if err.errno == errno.ESRCH:
return False
return True
return utils.proc_alive(self.pid)

@property
def exitcode(self):
if not self.started or self.running:
return None
exitcode = self.process.exitcode
if self.killed and exitcode is None:
return -9
Expand All @@ -86,14 +79,13 @@ def start(self):

def join(self, timeout=None):
assert self.started
if self.running:
self.process.join(timeout)
self.process.join(timeout)
return self

def kill(self, timeout=1):
# Cannot early exit if process is not running, because it may just be
# starting up.
assert self.started
if not self.running:
return self
try:
children = list(psutil.Process(self.pid).children(recursive=True))
except psutil.NoSuchProcess:
Expand Down Expand Up @@ -132,6 +124,4 @@ def _wrapper(options, name, fn, args):
contextlib.context.error(e, name)
exitcode = 1
finally:
children = list(psutil.Process(os.getpid()).children(recursive=True))
utils.kill_proc(children, timeout=1)
contextlib.context.shutdown(exitcode)
6 changes: 4 additions & 2 deletions portal/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, fn, *args, name=None, start=False):
target=self._wrapper, args=args, name=name, daemon=True)
self.thread.children = []
self.started = False
self.ready = threading.Barrier(2)
contextlib.context.add_worker(self)
start and self.start()

Expand All @@ -54,11 +55,11 @@ def start(self):
assert not self.started
self.started = True
self.thread.start()
self.ready.wait()
return self

def join(self, timeout=None):
if self.running:
self.thread.join(timeout)
self.thread.join(timeout)
return self

def kill(self, timeout=1.0):
Expand All @@ -75,6 +76,7 @@ def __repr__(self):

def _wrapper(self, *args):
try:
self.ready.wait()
exitcode = self.fn(*args)
exitcode = exitcode if isinstance(exitcode, int) else 0
self.excode = exitcode
Expand Down
17 changes: 16 additions & 1 deletion portal/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ctypes
import errno
import os
import socket
import sys
import threading
Expand Down Expand Up @@ -73,9 +75,22 @@ def eachproc(fn, procs):
# Should never happen but print warning if any survived.
eachproc(lambda p: (
print('Killed subprocess is still alive.')
if p.status() != psutil.STATUS_ZOMBIE else None), procs)
if proc_alive(p.pid) else None), procs)


def proc_alive(pid):
try:
if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE:
return False
except psutil.NoSuchProcess:
return False
try:
os.kill(pid, 0)
except OSError as e:
if e.errno == errno.ESRCH:
return False
return True


def free_port():
# Return a port that is currently free. This function is not thread or
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def fn(x):
future.result(timeout=0.01)
with pytest.raises(TimeoutError):
future.result(timeout=0)
assert future.result(timeout=0.2) == 42
assert future.result(timeout=1) == 42
client.close()
server.close()

Expand Down
17 changes: 4 additions & 13 deletions tests/test_errfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,7 @@ def hang_process(ready, queue):
assert not worker.running
pids = [queue.get() for _ in range(4)]
time.sleep(2.0) # On some systems this can take a while.
assert not alive(pids[0])
assert not alive(pids[1])
assert not alive(pids[2])
assert not alive(pids[3])


def alive(pid):
try:
os.kill(pid, 0)
except OSError:
assert True
else:
assert False
assert not portal.proc_alive(pids[0])
assert not portal.proc_alive(pids[1])
assert not portal.proc_alive(pids[2])
assert not portal.proc_alive(pids[3])
25 changes: 9 additions & 16 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,29 @@ def fn():

@pytest.mark.parametrize('repeat', range(5))
def test_kill_with_subproc(self, repeat):
ready = portal.context.mp.Semaphore(0)
ready = portal.context.mp.Barrier(3)
queue = portal.context.mp.Queue()

def outer(ready, queue):
queue.put(os.getpid())
portal.Process(inner, ready, queue, start=True)
ready.release()
ready.wait()
while True:
time.sleep(0.1)

def inner(ready, queue):
queue.put(os.getpid())
ready.release()
ready.wait()
while True:
time.sleep(0.1)

worker = portal.Process(outer, ready, queue, start=True)
ready.acquire()
ready.acquire()
ready.wait()
worker.kill()
assert not worker.running
assert worker.exitcode < 0
assert not alive(queue.get())
assert not alive(queue.get())
assert not portal.proc_alive(queue.get())
assert not portal.proc_alive(queue.get())

@pytest.mark.parametrize('repeat', range(5))
def test_kill_with_subthread(self, repeat):
Expand Down Expand Up @@ -103,12 +105,3 @@ def inner(ready):
ready.wait()
assert ready.is_set()
portal.context.initfns.clear()


def alive(pid):
try:
os.kill(pid, 0)
except OSError:
assert True
else:
assert False

0 comments on commit 4cdf974

Please sign in to comment.