From b4751feb95a8c1528b020b2eed60a9628727823b Mon Sep 17 00:00:00 2001 From: Muhammad Nabi Yasinzai Date: Mon, 28 Oct 2024 23:02:41 +1300 Subject: [PATCH] Reformatted code using black to fix linting issues --- .../brain_extraction/brain_extractor.py | 17 ++++--- brainles_preprocessing/defacing/defacer.py | 17 ++++--- brainles_preprocessing/modality.py | 50 ++++++++++++------- brainles_preprocessing/preprocessor.py | 41 +++++++++------ .../registration/ANTs/ANTs.py | 21 ++++---- 5 files changed, 88 insertions(+), 58 deletions(-) diff --git a/brainles_preprocessing/brain_extraction/brain_extractor.py b/brainles_preprocessing/brain_extraction/brain_extractor.py index 03118aa..b90441e 100644 --- a/brainles_preprocessing/brain_extraction/brain_extractor.py +++ b/brainles_preprocessing/brain_extraction/brain_extractor.py @@ -1,5 +1,5 @@ # TODO add typing and docs -import shutil +import shutil from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Union @@ -10,8 +10,9 @@ class Mode(Enum): - FAST = 'fast' - ACCURATE = 'accurate' + FAST = "fast" + ACCURATE = "accurate" + class BrainExtractor: @abstractmethod @@ -51,7 +52,6 @@ def apply_mask( mask_path (str or Path): Path to the brain mask image (NIfTI format). bet_image_path (str or Path): Path to save the resulting masked image (NIfTI format). """ - try: # Read data @@ -105,11 +105,11 @@ def extract( device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU). do_tta (bool): whether to do test time data augmentation by mirroring along all axes. """ - + # Ensure mode is a Mode enum instance if isinstance(mode, str): try: - mode_enum = Mode(mode.lower()) + mode_enum = Mode(mode.lower()) except ValueError: raise ValueError(f"'{mode}' is not a valid Mode.") elif isinstance(mode, Mode): @@ -132,7 +132,9 @@ def extract( # Construct the path to the generated mask masked_image_path = Path(masked_image_path) - hdbet_mask_path = masked_image_path.with_name(masked_image_path.name.replace('.nii.gz', '_mask.nii.gz')) + hdbet_mask_path = masked_image_path.with_name( + masked_image_path.name.replace(".nii.gz", "_mask.nii.gz") + ) if hdbet_mask_path.resolve() != Path(brain_mask_path).resolve(): try: @@ -142,4 +144,3 @@ def extract( ) except Exception as e: raise RuntimeError(f"Error copying mask file: {e}") from e - diff --git a/brainles_preprocessing/defacing/defacer.py b/brainles_preprocessing/defacing/defacer.py index 9114073..454cf93 100644 --- a/brainles_preprocessing/defacing/defacer.py +++ b/brainles_preprocessing/defacing/defacer.py @@ -12,6 +12,7 @@ class Defacer(ABC): Subclasses should implement the `deface` method to generate a defaced image based on the provided input image and mask. """ + @abstractmethod def deface( self, @@ -41,23 +42,27 @@ def apply_mask( mask_path (str or Path): Path to the brain mask image (NIfTI format). defaced_image_path (str or Path): Path to save the resulting defaced image (NIfTI format). """ - + if not input_image_path.is_file(): - raise FileNotFoundError(f"Input image file does not exist: {input_image_path}") + raise FileNotFoundError( + f"Input image file does not exist: {input_image_path}" + ) if not mask_path.is_file(): raise FileNotFoundError(f"Mask file does not exist: {mask_path}") - + try: # Read data input_data = read_nifti(str(input_image_path)) mask_data = read_nifti(str(mask_path)) except Exception as e: - raise RuntimeError(f"An error occurred while reading input files: {e}") from e - + raise RuntimeError( + f"An error occurred while reading input files: {e}" + ) from e + # Check that the input and mask have the same shape if input_data.shape != mask_data.shape: raise ValueError("Input image and mask must have the same dimensions.") - + # Apply mask (element-wise multiplication) masked_data = input_data * mask_data diff --git a/brainles_preprocessing/modality.py b/brainles_preprocessing/modality.py index 57de37f..4e311bd 100644 --- a/brainles_preprocessing/modality.py +++ b/brainles_preprocessing/modality.py @@ -58,12 +58,12 @@ def __init__( modality_name: str, input_path: Union[str, Path], normalizer: Optional[Normalizer] = None, - raw_bet_output_path: Optional[Union[str, Path]] = None, - raw_skull_output_path: Optional[Union[str, Path]] = None, - raw_defaced_output_path: Optional[Union[str, Path]] = None, - normalized_bet_output_path: Optional[Union[str, Path]] = None, - normalized_skull_output_path: Optional[Union[str, Path]] = None, - normalized_defaced_output_path: Optional[Union[str, Path]] = None, + raw_bet_output_path: Optional[Union[str, Path]] = None, + raw_skull_output_path: Optional[Union[str, Path]] = None, + raw_defaced_output_path: Optional[Union[str, Path]] = None, + normalized_bet_output_path: Optional[Union[str, Path]] = None, + normalized_skull_output_path: Optional[Union[str, Path]] = None, + normalized_defaced_output_path: Optional[Union[str, Path]] = None, atlas_correction: bool = True, ) -> None: # Basics @@ -89,9 +89,15 @@ def __init__( ) # handle input paths - self.raw_bet_output_path = Path(raw_bet_output_path) if raw_bet_output_path else None - self.raw_skull_output_path = Path(raw_skull_output_path) if raw_skull_output_path else None - self.raw_defaced_output_path = Path(raw_defaced_output_path) if raw_defaced_output_path else None + self.raw_bet_output_path = ( + Path(raw_bet_output_path) if raw_bet_output_path else None + ) + self.raw_skull_output_path = ( + Path(raw_skull_output_path) if raw_skull_output_path else None + ) + self.raw_defaced_output_path = ( + Path(raw_defaced_output_path) if raw_defaced_output_path else None + ) if normalized_bet_output_path: if normalizer is None: @@ -163,7 +169,9 @@ def normalize( store_unnormalized.mkdir(parents=True, exist_ok=True) shutil.copyfile( src=str(self.current), - dst=str(store_unnormalized / f"unnormalized__{self.modality_name}.nii.gz"), + dst=str( + store_unnormalized / f"unnormalized__{self.modality_name}.nii.gz" + ), ) if temporary_directory: @@ -171,7 +179,9 @@ def normalize( unnormalized_dir.mkdir(parents=True, exist_ok=True) shutil.copyfile( src=str(self.current), - dst=str(unnormalized_dir / f"unnormalized__{self.modality_name}.nii.gz"), + dst=str( + unnormalized_dir / f"unnormalized__{self.modality_name}.nii.gz" + ), ) # Normalize the image @@ -208,13 +218,13 @@ def register( """ fixed_image_path = Path(fixed_image_path) registration_dir = Path(registration_dir) - + registered = registration_dir / f"{moving_image_name}.nii.gz" registered_log = registration_dir / f"{moving_image_name}.log" - + # Note, add file ending depending on registration backend! registered_matrix = registration_dir / f"{moving_image_name}" - + registrator.register( fixed_image_path=fixed_image_path, moving_image_path=self.current, @@ -277,9 +287,11 @@ def apply_deface_mask( deface_dir = Path(deface_dir) defaced_img = deface_dir / f"atlas__{self.modality_name}_defaced.nii.gz" input_img = self.steps[ - PreprocessorSteps.ATLAS_CORRECTED - if self.atlas_correction - else PreprocessorSteps.ATLAS_REGISTERED + ( + PreprocessorSteps.ATLAS_CORRECTED + if self.atlas_correction + else PreprocessorSteps.ATLAS_REGISTERED + ) ] defacer.apply_mask( input_image_path=input_img, @@ -381,7 +393,7 @@ def deface( """ if isinstance(defacer, QuickshearDefacer): - + defaced_dir_path = Path(defaced_dir_path) atlas_mask_path = ( defaced_dir_path / f"atlas__{self.modality_name}_deface_mask.nii.gz" @@ -430,4 +442,4 @@ def save_current_image( shutil.copyfile( src=str(self.current), dst=str(output_path), - ) \ No newline at end of file + ) diff --git a/brainles_preprocessing/preprocessor.py b/brainles_preprocessing/preprocessor.py index 3e2d8f9..7115977 100644 --- a/brainles_preprocessing/preprocessor.py +++ b/brainles_preprocessing/preprocessor.py @@ -55,12 +55,14 @@ def __init__( self.center_modality = center_modality self.moving_modalities = moving_modalities - + if atlas_image_path is None: - self.atlas_image_path = Path(__file__).parent / "registration" / "atlas" / "t1_brats_space.nii" + self.atlas_image_path = ( + Path(__file__).parent / "registration" / "atlas" / "t1_brats_space.nii" + ) else: self.atlas_image_path = Path(atlas_image_path) - + self.registrator = registrator if self.registrator is None: logger.warning( @@ -86,7 +88,7 @@ def __init__( self.atlas_dir = self.temp_folder / "atlas-space" self.atlas_dir.mkdir(exist_ok=True, parents=True) - + def _configure_gpu( self, use_gpu: Optional[bool], limit_cuda_visible_devices: Optional[str] = None ) -> None: @@ -249,7 +251,9 @@ def run( self._set_log_file(log_file=log_file) logger.info(f"{' Starting preprocessing ':=^80}") logger.info(f"Logs are saved to {self.log_file_handler.baseFilename}") - modality_names = ', '.join([modality.modality_name for modality in self.moving_modalities]) + modality_names = ", ".join( + [modality.modality_name for modality in self.moving_modalities] + ) logger.info( f"Received center modality: {self.center_modality.modality_name} " f"and moving modalities: {modality_names}" @@ -298,17 +302,19 @@ def run( self.run_brain_extraction( save_dir_brain_extraction=save_dir_brain_extraction, ) - + # Defacing logger.info(f"{' Checking optional defacing ':-^80}") self.run_defacing( save_dir_defacing=save_dir_defacing, ) - + # End logger.info(f"{' Preprocessing complete ':=^80}") - def run_coregistration(self, save_dir_coregistration: Optional[Union[str, Path]] = None) -> None: + def run_coregistration( + self, save_dir_coregistration: Optional[Union[str, Path]] = None + ) -> None: """ Coregister moving modalities to center modality. @@ -336,8 +342,10 @@ def run_coregistration(self, save_dir_coregistration: Optional[Union[str, Path]] shutil.copyfile( src=str(self.center_modality.input_path), - dst=str(coregistration_dir / - f"native__{self.center_modality.modality_name}.nii.gz"), + dst=str( + coregistration_dir + / f"native__{self.center_modality.modality_name}.nii.gz" + ), ) self._save_output( @@ -417,8 +425,11 @@ def run_atlas_correction( ) if self.center_modality.atlas_correction: - center_atlas_corrected_path = atlas_correction_dir / f"atlas_corrected__{self.center_modality.modality_name}.nii.gz" - + center_atlas_corrected_path = ( + atlas_correction_dir + / f"atlas_corrected__{self.center_modality.modality_name}.nii.gz" + ) + shutil.copyfile( src=str(self.center_modality.current), dst=str(center_atlas_corrected_path), @@ -503,7 +514,9 @@ def run_brain_extraction( normalization=True, ) - def run_defacing(self, save_dir_defacing: Optional[Union[str, Path]] = None) -> None: + def run_defacing( + self, save_dir_defacing: Optional[Union[str, Path]] = None + ) -> None: """Deface images to remove facial features using specified Defacer. Args: @@ -563,7 +576,7 @@ def run_defacing(self, save_dir_defacing: Optional[Union[str, Path]] = None) -> def _save_output( self, - src: Union[str, Path], + src: Union[str, Path], save_dir: Optional[Union[str, Path]], ): """ diff --git a/brainles_preprocessing/registration/ANTs/ANTs.py b/brainles_preprocessing/registration/ANTs/ANTs.py index bc73a8e..5bde6d5 100644 --- a/brainles_preprocessing/registration/ANTs/ANTs.py +++ b/brainles_preprocessing/registration/ANTs/ANTs.py @@ -60,11 +60,11 @@ def register( **kwargs: Additional registration parameters to update the instantiated defaults. """ start_time = datetime.datetime.now() - + # TODO - self.registration_params # We update the registration parameters with the provided kwargs registration_kwargs = {**self.registration_params, **kwargs} - + # Convert all paths to Path objects fixed_image_path = Path(fixed_image_path) moving_image_path = Path(moving_image_path) @@ -76,7 +76,7 @@ def register( raise FileNotFoundError(f"Fixed image not found: {fixed_image_path}") if not moving_image_path.is_file(): raise FileNotFoundError(f"Moving image not found: {moving_image_path}") - + # Ensure matrix_path has .mat suffix if matrix_path.suffix != ".mat": matrix_path = matrix_path.with_suffix(".mat") @@ -89,17 +89,17 @@ def register( **registration_kwargs, ) transformed_image = registration_result["warpedmovout"] - + # Ensure output directories exist transformed_image_path.parent.mkdir(parents=True, exist_ok=True) matrix_path.parent.mkdir(parents=True, exist_ok=True) - + ants.image_write(transformed_image, str(transformed_image_path)) shutil.copyfile( - src=registration_result["fwdtransforms"][0], - dst=str(matrix_path), - ) + src=registration_result["fwdtransforms"][0], + dst=str(matrix_path), + ) end_time = datetime.datetime.now() @@ -142,7 +142,7 @@ def transform( # TODO - self.transformation_params # we update the transformation parameters with the provided kwargs transform_kwargs = {**self.transformation_params, **kwargs} - + # Convert all paths to Path objects fixed_image_path = Path(fixed_image_path) moving_image_path = Path(moving_image_path) @@ -155,10 +155,9 @@ def transform( if not moving_image_path.is_file(): raise FileNotFoundError(f"Moving image not found: {moving_image_path}") - fixed_image = ants.image_read(str(fixed_image_path)) moving_image = ants.image_read(str(moving_image_path)) - + # Ensure output directory exist transformed_image_path.parent.mkdir(parents=True, exist_ok=True)