From 2548cdfaacb5e77d1fc4cf1403c8d643c7996090 Mon Sep 17 00:00:00 2001
From: guipenedo
Date: Thu, 2 Jan 2025 19:15:14 +0000
Subject: [PATCH] manually handle process management

---
 src/datatrove/pipeline/extractors/base.py    | 104 +++++++++++++-----
 .../pipeline/extractors/trafilatura.py       |   2 +-
 2 files changed, 76 insertions(+), 30 deletions(-)

diff --git a/src/datatrove/pipeline/extractors/base.py b/src/datatrove/pipeline/extractors/base.py
index 6aa44b78..bf289799 100644
--- a/src/datatrove/pipeline/extractors/base.py
+++ b/src/datatrove/pipeline/extractors/base.py
@@ -1,6 +1,5 @@
 from abc import abstractmethod
-from concurrent.futures import ProcessPoolExecutor
-from concurrent.futures.process import BrokenProcessPool
+from multiprocessing import Pipe, Process
 
 from datatrove.data import DocumentsPipeline
 from datatrove.pipeline.base import PipelineStep
@@ -14,7 +13,7 @@ class BaseExtractor(PipelineStep):
     type = "🛢 - EXTRAC"
 
     @abstractmethod
-    def __init__(self, timeout: float = 0.1):
+    def __init__(self, timeout: float = 1):
         """
 
         Args:
@@ -47,42 +46,28 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
         Returns:
 
         """
-        executor = ProcessPoolExecutor(max_workers=1)
-        try:
+        with ExtractorSandbox(timeout=self.timeout) as extractor:
             for doc in data:
                 self.stat_update(StatHints.total)
                 with self.track_time():
-                    # If submit fails, the pool was already broken from previous task
-                    try:
-                        future = executor.submit(self.extract, doc.text)
-                    except BrokenProcessPool:
-                        logger.warning(
-                            "Found broken process pool, creating new executor and retrying current document"
-                        )
-                        executor.shutdown(wait=False)
-                        executor = ProcessPoolExecutor(max_workers=1)
-                        self.stat_update("broken_pool")
-                        try:
-                            future = executor.submit(self.extract, doc.text)
-                        except BrokenProcessPool:
-                            logger.error("New pool also broke, skipping document")
-                            continue
-
                     try:
-                        doc.text = future.result(timeout=self.timeout)
+                        doc.text = extractor.process_document(doc.text, self.extract)
                         self.stat_update("extracted")
                     except TimeoutError:
-                        future.cancel()
-                        logger.warning("⏰ Timeout while cleaning record text. Skipping record.")
                         self.stat_update("timeout")
+                        logger.warning("⏰ Timeout while cleaning record text. Skipping record.")
+                        continue
+                    except EOFError:
+                        # Process died unexpectedly
+                        self.stat_update("broken_process")
+                        logger.warning("Process died unexpectedly, will create new process for next document")
                         continue
                     except Exception as e:
-                        future.cancel()
                         self.stat_update("clean_error")
                         if not self._warned_error:
                             logger.warning(
-                                f'❌ Error "{e}" while cleaning record text. Skipping record. This message will only '
-                                f"appear once."
+                                f'❌ Error "{e}" while cleaning record text. Skipping record. '
+                                f"This message will only appear once."
                             )
                             self._warned_error = True
                         continue
@@ -93,5 +78,66 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do
                     yield doc
                 else:
                     self.stat_update(StatHints.dropped)
-        finally:
-            executor.shutdown(wait=False, cancel_futures=True)
+
+
+class ExtractorSandbox:
+    def __init__(self, timeout):
+        self.timeout = timeout
+        self.process = None
+        self.parent_conn = None
+        self.child_conn = None
+
+    def _cleanup_process(self):
+        if self.process is not None:
+            self.parent_conn.close()
+            self.child_conn.close()
+            self.process.terminate()
+            self.process.join(timeout=0.1)  # small clean up window
+            if self.process.is_alive():
+                self.process.kill()
+            self.process = None
+            self.parent_conn = None
+            self.child_conn = None
+
+    def _worker(self, conn, extract_fn):
+        extract_fn("")  # "warmup"
+        conn.send(None)  # ready
+        while True:
+            try:
+                text = conn.recv()
+                result = extract_fn(text)
+                conn.send(result)
+            except EOFError:
+                break
+
+    def process_document(self, text, extract_fn):
+        self._ensure_process(extract_fn)
+        try:
+            self.parent_conn.send(text)
+            if self.parent_conn.poll(timeout=self.timeout):
+                result = self.parent_conn.recv()
+                if isinstance(result, Exception):
+                    raise result
+                return result
+            else:
+                raise TimeoutError("Document extraction timed out")
+        except (TimeoutError, EOFError):
+            self._cleanup_process()
+            raise
+
+    def _ensure_process(self, extract_fn):
+        if self.process is None or not self.process.is_alive():
+            if self.process is not None:
+                self._cleanup_process()
+
+            self.parent_conn, self.child_conn = Pipe()
+            self.process = Process(target=self._worker, args=(self.child_conn, extract_fn))
+            self.process.start()
+            self.parent_conn.recv()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._cleanup_process()
+        return False
diff --git a/src/datatrove/pipeline/extractors/trafilatura.py b/src/datatrove/pipeline/extractors/trafilatura.py
index dbebd62c..bebaddb8 100644
--- a/src/datatrove/pipeline/extractors/trafilatura.py
+++ b/src/datatrove/pipeline/extractors/trafilatura.py
@@ -23,7 +23,7 @@ def __init__(
         self,
         favour_precision: bool = True,
         include_images: bool = False,
-        timeout: float = 0.1,
+        timeout: float = 1,
         deduplicate: bool = True,
         **kwargs,
     ):