Skip to content

Commit

Permalink
Use barrier to ensure process is running
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 12, 2024
1 parent 242ebaa commit d3bfbba
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
11 changes: 6 additions & 5 deletions portal/process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import atexit
import errno
import os
import traceback

import cloudpickle
Expand Down Expand Up @@ -37,8 +36,10 @@ def __init__(self, fn, *args, name=None, start=False):
name = name or getattr(fn, '__name__', 'process')
fn = cloudpickle.dumps(fn)
options = contextlib.context.options()
self.ready = contextlib.context.mp.Barrier(2)
self.process = contextlib.context.mp.Process(
target=self._wrapper, name=name, args=(options, name, fn, args))
target=self._wrapper, name=name,
args=(options, self.ready, name, fn, args))
self.started = False
self.killed = False
self.thepid = None
Expand Down Expand Up @@ -73,6 +74,7 @@ def start(self):
assert not self.started
self.started = True
self.process.start()
self.ready.wait()
self.thepid = self.process.pid
assert self.thepid is not None
return self
Expand All @@ -83,8 +85,6 @@ def join(self, timeout=None):
return self

def kill(self, timeout=1):
# Cannot early exit if process is not running, because it may just be
# starting up.
assert self.started
try:
children = list(psutil.Process(self.pid).children(recursive=True))
Expand All @@ -108,9 +108,10 @@ def __repr__(self):
return 'Process(' + ', '.join(attrs) + ')'

@staticmethod
def _wrapper(options, name, fn, args):
def _wrapper(options, ready, name, fn, args):
exitcode = 0
try:
ready.wait()
contextlib.setup(**options)
fn = cloudpickle.loads(fn)
exitcode = fn(*args)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def fn():
assert not worker.running
assert worker.exitcode == 1

def test_kill(self):
@pytest.mark.parametrize('repeat', range(5))
def test_kill_basic(self, repeat):
def fn():
while True:
time.sleep(0.1)
Expand Down

0 comments on commit d3bfbba

Please sign in to comment.