Skip to content

Commit

Permalink
Use barrier to ensure process is running
Browse files: browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 12, 2024
1 parent 242ebaa commit b119cac
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
11 changes: 6 additions & 5 deletions portal/process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import atexit
import errno
import os
import traceback

import cloudpickle
Expand Down Expand Up @@ -37,8 +36,10 @@ def __init__(self, fn, *args, name=None, start=False):
name = name or getattr(fn, '__name__', 'process')
fn = cloudpickle.dumps(fn)
options = contextlib.context.options()
self.ready = contextlib.context.mp.Barrier(2)
self.process = contextlib.context.mp.Process(
target=self._wrapper, name=name, args=(options, name, fn, args))
target=self._wrapper, name=name,
args=(options, self.ready, name, fn, args))
self.started = False
self.killed = False
self.thepid = None
Expand Down Expand Up @@ -73,6 +74,7 @@ def start(self):
assert not self.started
self.started = True
self.process.start()
self.ready.wait()
self.thepid = self.process.pid
assert self.thepid is not None
return self
Expand All @@ -83,8 +85,6 @@ def join(self, timeout=None):
return self

def kill(self, timeout=1):
# Cannot early exit if process is not running, because it may just be
# starting up.
assert self.started
try:
children = list(psutil.Process(self.pid).children(recursive=True))
Expand All @@ -108,9 +108,10 @@ def __repr__(self):
return 'Process(' + ', '.join(attrs) + ')'

@staticmethod
def _wrapper(options, name, fn, args):
def _wrapper(options, ready, name, fn, args):
exitcode = 0
try:
ready.wait()
contextlib.setup(**options)
fn = cloudpickle.loads(fn)
exitcode = fn(*args)
Expand Down
21 changes: 19 additions & 2 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,39 @@ def test_exitcode(self):
assert worker.exitcode == 42

def test_error(self):

def fn():
raise KeyError('foo')

Check failure on line 21 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.9)

foo

Check failure on line 21 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.10)

foo

Check failure on line 21 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.11)

foo

Check failure on line 21 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.12)

foo

worker = portal.Process(fn, start=True)
worker.join()
assert not worker.running
assert worker.exitcode == 1

def test_error_with_children(self):

def hang():
while True:
time.sleep(0.1)

def fn():
portal.Process(hang, start=True)
portal.Thread(hang, start=True)
time.sleep(0.1)
raise KeyError('foo')

Check failure on line 38 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.9)

foo

Check failure on line 38 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.10)

foo

Check failure on line 38 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.11)

foo

Check failure on line 38 in tests/test_process.py

View workflow job for this annotation

GitHub Actions / build (3.12)

foo

worker = portal.Process(fn, start=True)
worker.join()
assert not worker.running
assert worker.exitcode == 1

def test_kill(self):
@pytest.mark.parametrize('repeat', range(5))
def test_kill_basic(self, repeat):

def fn():
while True:
time.sleep(0.1)

worker = portal.Process(fn, start=True)
worker.kill()
assert not worker.running
Expand All @@ -52,14 +60,18 @@ def test_kill_with_subproc(self, repeat):
queue = portal.context.mp.Queue()

def outer(ready, queue):
queue.put(os.getpid())
portal.Process(inner, ready, queue, start=True)
queue.put(os.getpid())
queue.close()
queue.join_thread()
ready.wait()
while True:
time.sleep(0.1)

def inner(ready, queue):
queue.put(os.getpid())
queue.close()
queue.join_thread()
ready.wait()
while True:
time.sleep(0.1)
Expand Down Expand Up @@ -90,17 +102,22 @@ def inner(ready):
assert worker.exitcode < 0

def test_initfn(self):

def init():
portal.foo = 42

portal.initfn(init)
ready = portal.context.mp.Event()
assert portal.foo == 42

def outer(ready):
assert portal.foo == 42
portal.Process(inner, ready, start=True).join()

def inner(ready):
assert portal.foo == 42
ready.set()

portal.Process(outer, ready, start=True).join()
ready.wait()
assert ready.is_set()
Expand Down

0 comments on commit b119cac

Please sign in to comment.