manually handle process management
guipenedo committed Jan 2, 2025
1 parent e1c7cec commit 2548cdf
Showing 2 changed files with 76 additions and 30 deletions.
104 changes: 75 additions & 29 deletions src/datatrove/pipeline/extractors/base.py
@@ -1,6 +1,5 @@
 from abc import abstractmethod
-from concurrent.futures import ProcessPoolExecutor
-from concurrent.futures.process import BrokenProcessPool
+from multiprocessing import Pipe, Process
 
 from datatrove.data import DocumentsPipeline
 from datatrove.pipeline.base import PipelineStep
@@ -14,7 +13,7 @@ class BaseExtractor(PipelineStep):
     type = "🛢 - EXTRAC"
 
     @abstractmethod
-    def __init__(self, timeout: float = 0.1):
+    def __init__(self, timeout: float = 1):
         """
 
         Args:
@@ -47,42 +46,28 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
         Returns:
 
         """
-        executor = ProcessPoolExecutor(max_workers=1)
-        try:
+        with ExtractorSandbox(timeout=self.timeout) as extractor:
             for doc in data:
                 self.stat_update(StatHints.total)
                 with self.track_time():
-                    # If submit fails, the pool was already broken from previous task
-                    try:
-                        future = executor.submit(self.extract, doc.text)
-                    except BrokenProcessPool:
-                        logger.warning(
-                            "Found broken process pool, creating new executor and retrying current document"
-                        )
-                        executor.shutdown(wait=False)
-                        executor = ProcessPoolExecutor(max_workers=1)
-                        self.stat_update("broken_pool")
-                        try:
-                            future = executor.submit(self.extract, doc.text)
-                        except BrokenProcessPool:
-                            logger.error("New pool also broke, skipping document")
-                            continue
-
                     try:
-                        doc.text = future.result(timeout=self.timeout)
+                        doc.text = extractor.process_document(doc.text, self.extract)
                         self.stat_update("extracted")
                     except TimeoutError:
-                        future.cancel()
-                        logger.warning("⏰ Timeout while cleaning record text. Skipping record.")
                         self.stat_update("timeout")
+                        logger.warning("⏰ Timeout while cleaning record text. Skipping record.")
                         continue
+                    except EOFError:
+                        # Process died unexpectedly
+                        self.stat_update("broken_process")
+                        logger.warning("Process died unexpectedly, will create new process for next document")
+                        continue
                     except Exception as e:
-                        future.cancel()
                         self.stat_update("clean_error")
                         if not self._warned_error:
                             logger.warning(
-                                f'❌ Error "{e}" while cleaning record text. Skipping record. This message will only '
-                                f"appear once."
+                                f'❌ Error "{e}" while cleaning record text. Skipping record. '
+                                f"This message will only appear once."
                             )
                             self._warned_error = True
                         continue
@@ -93,5 +78,66 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
                     yield doc
                 else:
                     self.stat_update(StatHints.dropped)
-        finally:
-            executor.shutdown(wait=False, cancel_futures=True)
+
+
+class ExtractorSandbox:
+    def __init__(self, timeout):
+        self.timeout = timeout
+        self.process = None
+        self.parent_conn = None
+        self.child_conn = None
+
+    def _cleanup_process(self):
+        if self.process is not None:
+            self.parent_conn.close()
+            self.child_conn.close()
+            self.process.terminate()
+            self.process.join(timeout=0.1)  # small clean up window
+            if self.process.is_alive():
+                self.process.kill()
+            self.process = None
+            self.parent_conn = None
+            self.child_conn = None
+
+    def _worker(self, conn, extract_fn):
+        extract_fn("")  # "warmup"
+        conn.send(None)  # ready
+        while True:
+            try:
+                text = conn.recv()
+                result = extract_fn(text)
+                conn.send(result)
+            except EOFError:
+                break
+
+    def process_document(self, text, extract_fn):
+        self._ensure_process(extract_fn)
+        try:
+            self.parent_conn.send(text)
+            if self.parent_conn.poll(timeout=self.timeout):
+                result = self.parent_conn.recv()
+                if isinstance(result, Exception):
+                    raise result
+                return result
+            else:
+                raise TimeoutError("Document extraction timed out")
+        except (TimeoutError, EOFError):
+            self._cleanup_process()
+            raise
+
+    def _ensure_process(self, extract_fn):
+        if self.process is None or not self.process.is_alive():
+            if self.process is not None:
+                self._cleanup_process()
+
+            self.parent_conn, self.child_conn = Pipe()
+            self.process = Process(target=self._worker, args=(self.child_conn, extract_fn))
+            self.process.start()
+            self.parent_conn.recv()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._cleanup_process()
+        return False
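
Usage note: the sandbox above can also be driven on its own, outside run(). Below is a minimal sketch; the html_to_text helper and the sample input are hypothetical stand-ins for a real extraction function, and it assumes Linux's default "fork" start method (the bound _worker method and extract_fn have to reach the child process).

from datatrove.pipeline.extractors.base import ExtractorSandbox

# hypothetical stand-in for a real extraction function such as self.extract
def html_to_text(text: str) -> str:
    return text.replace("<p>", "").replace("</p>", "")

if __name__ == "__main__":
    with ExtractorSandbox(timeout=1) as sandbox:
        try:
            # html_to_text runs in a dedicated worker process; a hung call is
            # abandoned after `timeout` seconds and the worker is recycled
            print(sandbox.process_document("<p>hello</p>", html_to_text))
        except TimeoutError:
            print("extraction timed out; the worker was torn down")
        except EOFError:
            print("worker died unexpectedly; the next call spawns a fresh one")

Exiting the with block terminates the worker, taking over the cleanup that the removed finally: executor.shutdown(...) branch used to perform.
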
2 changes: 1 addition & 1 deletion src/datatrove/pipeline/extractors/trafilatura.py
@@ -23,7 +23,7 @@ def __init__(
         self,
         favour_precision: bool = True,
         include_images: bool = False,
-        timeout: float = 0.1,
+        timeout: float = 1,
         deduplicate: bool = True,
         **kwargs,
     ):
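The same relaxed default applies on the extractor side. A short configuration sketch with illustrative values (the import path assumes datatrove's usual package layout):

from datatrove.pipeline.extractors import Trafilatura

# per-document extraction budget in seconds; 1 is now the default,
# stricter values can still be passed explicitly
extractor = Trafilatura(favour_precision=True, timeout=0.5)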
