diff --git a/src/datatrove/executor/slurm.py b/src/datatrove/executor/slurm.py
index 3e74a662..a8133f68 100644
--- a/src/datatrove/executor/slurm.py
+++ b/src/datatrove/executor/slurm.py
@@ -112,6 +112,7 @@ def __init__(
         mail_type: str = "ALL",
         mail_user: str = None,
         requeue: bool = True,
+        srun_args: dict = None,
         tasks_per_job: int = 1,
     ):
         super().__init__(pipeline, logging_dir, skip_completed)
@@ -139,6 +140,7 @@ def __init__(
         self.requeue_signals = requeue_signals
         self.mail_type = mail_type
         self.mail_user = mail_user
+        self.srun_args = srun_args
         self.slurm_logs_folder = (
             slurm_logs_folder
             if slurm_logs_folder
@@ -256,9 +258,10 @@ def launch_job(self):
         max_array = min(nb_jobs_to_launch, self.max_array_size) if self.max_array_size != -1 else nb_jobs_to_launch
 
         # create the actual sbatch script
+        srun_args_str = " ".join([f"--{k}={v}" for k, v in self.srun_args.items()]) if self.srun_args else ""
         launch_file_contents = self.get_launch_file_contents(
             self.get_sbatch_args(max_array),
-            f"srun -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
+            f"srun {srun_args_str} -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}",
         )
         # save it
         with self.logging_dir.open("launch_script.slurm", "w") as launchscript_f:
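
Usage sketch (not part of the diff): a minimal example of how the new `srun_args` parameter could be passed, assuming a typical `SlurmPipelineExecutor` setup. The reader, paths, partition, time, and task count below are illustrative placeholders; only `srun_args` comes from this change. Each key/value pair in the dict is rendered as an `--key=value` flag on the generated `srun` command line.

```python
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.readers import JsonlReader  # illustrative pipeline step

executor = SlurmPipelineExecutor(
    pipeline=[JsonlReader("s3://my-bucket/data")],  # hypothetical input location
    tasks=4,
    time="01:00:00",
    partition="cpu",                    # hypothetical partition name
    logging_dir="s3://my-bucket/logs",  # hypothetical logging location
    # New in this diff: extra flags forwarded to srun. This dict yields:
    #   srun --cpu-bind=none -l launch_pickled_pipeline <logging_dir>/executor.pik
    srun_args={"cpu-bind": "none"},
)
executor.run()
```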