diff --git a/brainles_preprocessing/brain_extraction/brain_extractor.py b/brainles_preprocessing/brain_extraction/brain_extractor.py index 3a65426..eec84ad 100644 --- a/brainles_preprocessing/brain_extraction/brain_extractor.py +++ b/brainles_preprocessing/brain_extraction/brain_extractor.py @@ -23,32 +23,29 @@ def extract( def apply_mask( self, - input_image_path: str, - mask_image_path: str, - masked_image_path: str, + input_image_path: Path, + mask_path: Path, + bet_image_path: Path, ) -> None: """ Apply a brain mask to an input image. - Parameters: - - input_image_path (str): Path to the input image (NIfTI format). - - mask_image_path (str): Path to the brain mask image (NIfTI format). - - masked_image_path (str): Path to save the resulting masked image (NIfTI format). - - Returns: - - str: Path to the saved masked image. + Args: + input_image_path (str): Path to the input image (NIfTI format). + mask_path (str): Path to the brain mask image (NIfTI format). + bet_image_path (str): Path to save the resulting masked image (NIfTI format). """ # read data input_data = read_nifti(input_image_path) - mask_data = read_nifti(mask_image_path) + mask_data = read_nifti(mask_path) # mask and save it masked_data = input_data * mask_data write_nifti( input_array=masked_data, - output_nifti_path=masked_image_path, + output_nifti_path=bet_image_path, reference_nifti_path=input_image_path, create_parent_directory=True, ) diff --git a/brainles_preprocessing/constants.py b/brainles_preprocessing/constants.py new file mode 100644 index 0000000..83844b9 --- /dev/null +++ b/brainles_preprocessing/constants.py @@ -0,0 +1,10 @@ +from enum import IntEnum + + +class PreprocessorSteps(IntEnum): + INPUT = 0 + COREGISTERED = 1 + ATLAS_REGISTERED = 2 + ATLAS_CORRECTED = 3 + BET = 4 + DEFACED = 5 diff --git a/brainles_preprocessing/defacing/__init__.py b/brainles_preprocessing/defacing/__init__.py new file mode 100644 index 0000000..4a93c40 --- /dev/null +++ b/brainles_preprocessing/defacing/__init__.py @@ -0,0 +1,2 @@ +from .defacer import Defacer +from .quickshear.quickshear import QuickshearDefacer diff --git a/brainles_preprocessing/defacing/defacer.py b/brainles_preprocessing/defacing/defacer.py new file mode 100644 index 0000000..e3fed5f --- /dev/null +++ b/brainles_preprocessing/defacing/defacer.py @@ -0,0 +1,43 @@ +from abc import abstractmethod +from pathlib import Path + +from auxiliary.nifti.io import read_nifti, write_nifti + + +class Defacer: + @abstractmethod + def deface( + self, + input_image_path: Path, + mask_image_path: Path, + ) -> None: + pass + + def apply_mask( + self, + input_image_path: str, + mask_path: str, + defaced_image_path: str, + ) -> None: + """ + Apply a brain mask to an input image. + + Args: + input_image_path (str): Path to the input image (NIfTI format). + mask_path (str): Path to the brain mask image (NIfTI format). + defaced_image_path (str): Path to save the resulting defaced image (NIfTI format). + """ + + # read data + input_data = read_nifti(input_image_path) + mask_data = read_nifti(mask_path) + + # mask and save it + masked_data = input_data * mask_data + + write_nifti( + input_array=masked_data, + output_nifti_path=defaced_image_path, + reference_nifti_path=input_image_path, + create_parent_directory=True, + ) diff --git a/brainles_preprocessing/defacing/quickshear/__init__.py b/brainles_preprocessing/defacing/quickshear/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/brainles_preprocessing/defacing/quickshear/nipy_quickshear.py b/brainles_preprocessing/defacing/quickshear/nipy_quickshear.py new file mode 100644 index 0000000..da9f2a9 --- /dev/null +++ b/brainles_preprocessing/defacing/quickshear/nipy_quickshear.py @@ -0,0 +1,223 @@ +# Code adapted from: https://github.com/nipy/quickshear/blob/master/quickshear.py (23.10.2024) +# Minor adaptions in terms of parameters and return values +# Original Author': Copyright (c) 2011, Nakeisha Schimke. All rights reserved. + +import argparse +import logging + +#!/usr/bin/python +import sys + +import nibabel as nb +import numpy as np +from numpy.typing import NDArray + +try: + from duecredit import BibTeX, due +except ImportError: + # Adapted from + # https://github.com/duecredit/duecredit/blob/2221bfd/duecredit/stub.py + class InactiveDueCreditCollector: + """Just a stub at the Collector which would not do anything""" + + def _donothing(self, *args, **kwargs): + """Perform no good and no bad""" + pass + + def dcite(self, *args, **kwargs): + """If I could cite I would""" + + def nondecorating_decorator(func): + return func + + return nondecorating_decorator + + cite = load = add = _donothing + + def __repr__(self): + return self.__class__.__name__ + "()" + + due = InactiveDueCreditCollector() + + def BibTeX(*args, **kwargs): + pass + + +citation_text = """@inproceedings{Schimke2011, +abstract = {Data sharing offers many benefits to the neuroscience research +community. It encourages collaboration and interorganizational research +efforts, enables reproducibility and peer review, and allows meta-analysis and +data reuse. However, protecting subject privacy and implementing HIPAA +compliance measures can be a burdensome task. For high resolution structural +neuroimages, subject privacy is threatened by the neuroimage itself, which can +contain enough facial features to re-identify an individual. To sufficiently +de-identify an individual, the neuroimage pixel data must also be removed. +Quickshear Defacing accomplishes this task by effectively shearing facial +features while preserving desirable brain tissue.}, +address = {San Francisco}, +author = {Schimke, Nakeisha and Hale, John}, +booktitle = {Proceedings of the 2nd USENIX Conference on Health Security and Privacy}, +title = {{Quickshear Defacing for Neuroimages}}, +year = {2011}, +month = sep +} +""" +# __version__ = "1.3.0.dev0" + + +def edge_mask(mask): + """Find the edges of a mask or masked image + + Parameters + ---------- + mask : 3D array + Binary mask (or masked image) with axis orientation LPS or RPS, and the + non-brain region set to 0 + + Returns + ------- + 2D array + Outline of sagittal profile (PS orientation) of mask + """ + # Sagittal profile + brain = mask.any(axis=0) + + # Simple edge detection + edgemask = ( + 4 * brain + - np.roll(brain, 1, 0) + - np.roll(brain, -1, 0) + - np.roll(brain, 1, 1) + - np.roll(brain, -1, 1) + != 0 + ) + return edgemask.astype("uint8") + + +def convex_hull(brain): + """Find the lower half of the convex hull of non-zero points + + Implements Andrew's monotone chain algorithm [0]. + + [0] https://en.wikibooks.org/wiki/Algorithm_Implementation/Geometry/Convex_hull/Monotone_chain + + Parameters + ---------- + brain : 2D array + 2D array in PS axis ordering + + Returns + ------- + (2, N) array + Sequence of points in the lower half of the convex hull of brain + """ + # convert brain to a list of points in an n x 2 matrix where n_i = (x,y) + pts = np.vstack(np.nonzero(brain)).T + + def cross(o, a, b): + return np.cross(a - o, b - o) + + lower = [] + for p in pts: + while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0: + lower.pop() + lower.append(p) + + return np.array(lower).T + + +@due.dcite( + BibTeX(citation_text), + description="Geometric neuroimage defacer", + path="quickshear", +) +def run_quickshear(bet_img: nb.nifti1.Nifti1Image, buffer: int = 10) -> NDArray: + """Deface image using Quickshear algorithm + + Parameters + ---------- + bet_img : Nifti1Image + Nibabel image of skull-stripped brain mask or masked anatomical + buffer : int + Distance from mask to set shearing plane + + Returns + ------- + defaced_mask: NDArray + Defaced image mask + """ + src_ornt = nb.io_orientation(bet_img.affine) + tgt_ornt = nb.orientations.axcodes2ornt("RPS") + to_RPS = nb.orientations.ornt_transform(src_ornt, tgt_ornt) + from_RPS = nb.orientations.ornt_transform(tgt_ornt, src_ornt) + + mask_RPS = nb.orientations.apply_orientation(bet_img.dataobj, to_RPS) + + edgemask = edge_mask(mask_RPS) + low = convex_hull(edgemask) + xdiffs, ydiffs = np.diff(low) + slope = ydiffs[0] / xdiffs[0] + + yint = low[1][0] - (low[0][0] * slope) - buffer + ys = np.arange(0, mask_RPS.shape[2]) * slope + yint + defaced_mask_RPS = np.ones(mask_RPS.shape, dtype="bool") + + for x, y in zip(np.nonzero(ys > 0)[0], ys.astype(int)): + defaced_mask_RPS[:, x, :y] = 0 + + defaced_mask = nb.orientations.apply_orientation(defaced_mask_RPS, from_RPS) + + # return anat_img.__class__( + # np.asanyarray(anat_img.dataobj) * defaced_mask, + # anat_img.affine, + # anat_img.header, + # ) + + return defaced_mask + + +# def main(): +# logger = logging.getLogger(__name__) +# logger.setLevel(logging.DEBUG) +# ch = logging.StreamHandler() +# ch.setLevel(logging.DEBUG) +# logger.addHandler(ch) + +# parser = argparse.ArgumentParser( +# description="Quickshear defacing for neuroimages", +# formatter_class=argparse.ArgumentDefaultsHelpFormatter, +# ) +# parser.add_argument("anat_file", type=str, help="filename of neuroimage to deface") +# parser.add_argument("mask_file", type=str, help="filename of brain mask") +# parser.add_argument( +# "defaced_file", type=str, help="filename of defaced output image" +# ) +# parser.add_argument( +# "buffer", +# type=float, +# nargs="?", +# default=10.0, +# help="buffer size (in voxels) between shearing plane and the brain", +# ) + +# opts = parser.parse_args() + +# anat_img = nb.load(opts.anat_file) +# bet_img = nb.load(opts.mask_file) + +# if not ( +# anat_img.shape == bet_img.shape +# and np.allclose(anat_img.affine, bet_img.affine) +# ): +# logger.warning( +# "Anatomical and mask images do not have the same shape and affine." +# ) +# return -1 + +# new_anat = quickshear(anat_img, bet_img, opts.buffer) +# new_anat.to_filename(opts.defaced_file) +# logger.info(f"Defaced file: {opts.defaced_file}") + + +# if __name__ == "__main__": +# sys.exit(main()) diff --git a/brainles_preprocessing/defacing/quickshear/quickshear.py b/brainles_preprocessing/defacing/quickshear/quickshear.py new file mode 100644 index 0000000..0bfc10f --- /dev/null +++ b/brainles_preprocessing/defacing/quickshear/quickshear.py @@ -0,0 +1,53 @@ +from pathlib import Path + +import nibabel as nib +from brainles_preprocessing.defacing.defacer import Defacer +from brainles_preprocessing.defacing.quickshear.nipy_quickshear import run_quickshear +from numpy.typing import NDArray +from auxiliary.nifti.io import write_nifti + + +class QuickshearDefacer(Defacer): + """ + Defacer using Quickshear algorithm. + + Quickshear uses a skull stripped version of an anatomical images as a reference to deface the unaltered anatomical image. + + Base publication: + - PDF: https://www.researchgate.net/profile/J-Hale/publication/262319696_Quickshear_defacing_for_neuroimages/links/570b97ee08aed09e917516b1/Quickshear-defacing-for-neuroimages.pdf + - Bibtex: + ``` + @article{schimke2011quickshear, + title={Quickshear Defacing for Neuroimages.}, + author={Schimke, Nakeisha and Hale, John}, + journal={HealthSec}, + volume={11}, + pages={11}, + year={2011} + } + ``` + """ + + def __init__(self, buffer: float = 10.0): + """Initialize Quickshear defacer + + Args: + buffer (float, optional): buffer parameter from quickshear algorithm. Defaults to 10.0. + """ + super().__init__() + self.buffer = buffer + + def deface(self, mask_image_path: Path, bet_img_path: Path) -> None: + """Deface image using Quickshear algorithm + + Args: + bet_img_path (Path): Path to the brain extracted image + """ + + bet_img = nib.load(bet_img_path) + mask = run_quickshear(bet_img=bet_img, buffer=self.buffer) + write_nifti( + input_array=mask, + output_nifti_path=mask_image_path, + reference_nifti_path=bet_img_path, + ) diff --git a/brainles_preprocessing/modality.py b/brainles_preprocessing/modality.py index caa540a..6af5730 100644 --- a/brainles_preprocessing/modality.py +++ b/brainles_preprocessing/modality.py @@ -1,13 +1,19 @@ +import logging import os import shutil -from typing import List, Optional +from pathlib import Path +from typing import Optional from auxiliary.nifti.io import read_nifti, write_nifti from auxiliary.turbopath import turbopath from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor +from brainles_preprocessing.constants import PreprocessorSteps +from brainles_preprocessing.defacing import Defacer, QuickshearDefacer from brainles_preprocessing.normalization.normalizer_base import Normalizer from brainles_preprocessing.registration.registrator import Registrator +logger = logging.getLogger(__name__) + class Modality: """ @@ -16,23 +22,35 @@ class Modality: Args: modality_name (str): Name of the modality, e.g., "T1", "T2", "FLAIR". input_path (str): Path to the input modality data. - output_path (str): Path to save the preprocessed modality data. - bet (bool): Indicates whether brain extraction should be performed (True) or not (False). normalizer (Normalizer, optional): An optional normalizer for intensity normalization. + raw_bet_output_path (str, optional): Path to save the raw brain extracted modality data. + raw_skull_output_path (str, optional): Path to save the raw modality data with skull. + raw_defaced_output_path (str, optional): Path to save the raw defaced modality data. + normalized_bet_output_path (str, optional): Path to save the normalized brain extracted modality data. Requires a normalizer. + normalized_skull_output_path (str, optional): Path to save the normalized modality data with skull. Requires a normalizer. + normalized_defaced_output_path (str, optional): Path to save the normalized defaced modality data. Requires a normalizer. + atlas_correction (bool, optional): Indicates whether atlas correction should be performed. Attributes: modality_name (str): Name of the modality. input_path (str): Path to the input modality data. - output_path (str): Path to save the preprocessed modality data. - bet (bool): Indicates whether brain extraction is enabled. normalizer (Normalizer, optional): An optional normalizer for intensity normalization. + raw_bet_output_path (str, optional): Path to save the raw brain extracted modality data. + raw_skull_output_path (str, optional): Path to save the raw modality data with skull. + raw_defaced_output_path (str, optional): Path to save the raw defaced modality data. + normalized_bet_output_path (str, optional): Path to save the normalized brain extracted modality data. Requires a normalizer. + normalized_skull_output_path (str, optional): Path to save the normalized modality data with skull. Requires a normalizer. + normalized_defaced_output_path (str, optional): Path to save the normalized defaced modality data. Requires a normalizer. + bet (bool): Indicates whether brain extraction is enabled. + atlas_correction (bool): Indicates whether atlas correction should be performed. Example: >>> t1_modality = Modality( ... modality_name="T1", ... input_path="/path/to/input_t1.nii", - ... output_path="/path/to/preprocessed_t1.nii", - ... bet=True + ... normalizer=PercentileNormalizer(), + ... raw_bet_output_path="/path/to/raw_bet_t1.nii", + ... normalized_bet_output_path="/path/to/norm_bet_t1.nii", ... ) """ @@ -41,11 +59,13 @@ def __init__( self, modality_name: str, input_path: str, + normalizer: Optional[Normalizer] = None, raw_bet_output_path: Optional[str] = None, raw_skull_output_path: Optional[str] = None, + raw_defaced_output_path: Optional[str] = None, normalized_bet_output_path: Optional[str] = None, normalized_skull_output_path: Optional[str] = None, - normalizer: Optional[Normalizer] = None, + normalized_defaced_output_path: Optional[str] = None, atlas_correction: bool = True, ) -> None: # basics @@ -63,9 +83,11 @@ def __init__( and normalized_bet_output_path is None and raw_skull_output_path is None and normalized_skull_output_path is None + and raw_defaced_output_path is None + and normalized_defaced_output_path is None ): raise ValueError( - "All output paths are None. At least one output path must be provided." + "All output paths are None. At least one output paths must be provided." ) # handle input paths @@ -79,6 +101,11 @@ def __init__( else: self.raw_skull_output_path = raw_skull_output_path + if raw_defaced_output_path is not None: + self.raw_defaced_output_path = turbopath(raw_defaced_output_path) + else: + self.raw_defaced_output_path = raw_defaced_output_path + if normalized_bet_output_path is not None: if normalizer is None: raise ValueError( @@ -97,17 +124,46 @@ def __init__( else: self.normalized_skull_output_path = normalized_skull_output_path - # print("self.raw_bet_output_path", self.raw_bet_output_path) - # print("self.normalized_skull_output_path", self.normalized_skull_output_path) - # print("self.bet:", self.bet) + if normalized_defaced_output_path is not None: + if normalizer is None: + raise ValueError( + "A normalizer must be provided if normalized_defaced_output_path is not None." + ) + self.normalized_defaced_output_path = turbopath( + normalized_defaced_output_path + ) + else: + self.normalized_defaced_output_path = normalized_defaced_output_path + + self.steps = {k: None for k in PreprocessorSteps} @property def bet(self) -> bool: + """Check if any brain extraction output is specified. + + Returns: + bool: True if any brain extraction output is specified, False otherwise. + """ return any( path is not None for path in [self.raw_bet_output_path, self.normalized_bet_output_path] ) + @property + def requires_deface(self) -> bool: + """Check if any defacing output is specified. + + Returns: + bool: True if any defacing output is specified, False otherwise. + """ + return any( + path is not None + for path in [ + self.raw_defaced_output_path, + self.normalized_defaced_output_path, + ] + ) + def normalize( self, temporary_directory: str, @@ -155,6 +211,7 @@ def register( fixed_image_path: str, registration_dir: str, moving_image_name: str, + step: PreprocessorSteps, ) -> str: """ Register the current modality to a fixed image using the specified registrator. @@ -182,36 +239,67 @@ def register( log_file_path=registered_log, ) self.current = registered + self.steps[step] = registered return registered_matrix - def apply_mask( + def apply_bet_mask( self, brain_extractor: BrainExtractor, - brain_masked_dir_path: str, - atlas_mask_path: str, + mask_path: Path, + bet_dir: Path, ) -> None: """ Apply a brain mask to the current modality using the specified brain extractor. Args: brain_extractor (BrainExtractor): The brain extractor object. - brain_masked_dir_path (str): Directory to store masked images. - atlas_mask_path (str): Path to the brain mask. + mask_path (str): Path to the brain mask. + bet_dir (str): Directory to store computed bet images. Returns: None """ if self.bet: - brain_masked = os.path.join( - brain_masked_dir_path, - f"brain_masked__{self.modality_name}.nii.gz", - ) + bet_img = bet_dir / f"atlas__{self.modality_name}_bet.nii.gz" + brain_extractor.apply_mask( input_image_path=self.current, - mask_image_path=atlas_mask_path, - masked_image_path=brain_masked, + mask_path=mask_path, + bet_image_path=bet_img, ) - self.current = brain_masked + self.current = bet_img + self.steps[PreprocessorSteps.BET] = bet_img + + def apply_deface_mask( + self, + defacer: Defacer, + mask_path: str, + deface_dir: str, + ) -> None: + """ + Apply a deface mask to the current modality using the specified brain extractor. + + Args: + defacer (Defacer): The Defacer object. + mask_path (str): Path to the deface mask. + defaced_masked_dir_path (str): Directory to store masked images. + """ + if self.deface: + defaced_img = deface_dir / f"atlas__{self.modality_name}_defaced.nii.gz" + input_img = self.steps[ + ( + PreprocessorSteps.ATLAS_CORRECTED + if self.atlas_correction + else PreprocessorSteps.ATLAS_REGISTERED + ) + ] + defacer.apply_mask( + input_image_path=input_img, + mask_path=mask_path, + defaced_image_path=defaced_img, + ) + self.current = defaced_img + self.steps[PreprocessorSteps.DEFACED] = defaced_img def transform( self, @@ -220,6 +308,7 @@ def transform( registration_dir_path: str, moving_image_name: str, transformation_matrix_path: str, + step: PreprocessorSteps, ) -> None: """ Transform the current modality using the specified registrator and transformation matrix. @@ -247,6 +336,7 @@ def transform( log_file_path=transformed_log, ) self.current = transformed + self.steps[step] = transformed def extract_brain_region( self, @@ -263,23 +353,56 @@ def extract_brain_region( Returns: str: Path to the extracted brain mask. """ - bet_log = os.path.join(bet_dir_path, "brain-extraction.log") - atlas_bet_cm = os.path.join( - bet_dir_path, f"atlas_bet_{self.modality_name}.nii.gz" - ) - atlas_mask_path = os.path.join( - bet_dir_path, f"atlas_bet_{self.modality_name}_mask.nii.gz" - ) + bet_log = bet_dir_path / "brain-extraction.log" + + atlas_bet_cm = bet_dir_path / f"atlas__{self.modality_name}_bet.nii.gz" + mask_path = bet_dir_path / f"atlas__{self.modality_name}_brain_mask.nii.gz" brain_extractor.extract( input_image_path=self.current, masked_image_path=atlas_bet_cm, - brain_mask_path=atlas_mask_path, + brain_mask_path=mask_path, log_file_path=bet_log, ) + + # always temporarily store bet image for center modality, since e.g. quickshear defacing could require it + # down the line even if the user does not wish to save the bet image + self.steps[PreprocessorSteps.BET] = atlas_bet_cm + if self.bet is True: self.current = atlas_bet_cm - return atlas_mask_path + return mask_path + + def deface( + self, + defacer, + defaced_dir_path: str, + ) -> str: + """ + Deface the current modality using the specified defacer. + + Args: + defacer (Defacer): The defacer object. + defaced_dir_path (str): Directory to store defacing results. + + Returns: + str: Path to the extracted brain mask. + """ + + if isinstance(defacer, QuickshearDefacer): + atlas_mask_path = os.path.join( + defaced_dir_path, f"atlas__{self.modality_name}_deface_mask.nii.gz" + ) + defacer.deface( + mask_image_path=atlas_mask_path, + bet_img_path=self.steps[PreprocessorSteps.BET], + ) + return atlas_mask_path + else: + logger.warning( + "Defacing method not implemented yet. Skipping defacing for this modality." + ) + pass def save_current_image( self, @@ -295,7 +418,7 @@ def save_current_image( ) elif normalization is True: image = read_nifti(self.current) - print("current image", self.current) + # print("current image", self.current) normalized_image = self.normalizer.normalize(image=image) write_nifti( input_array=normalized_image, diff --git a/brainles_preprocessing/preprocessor.py b/brainles_preprocessing/preprocessor.py index 96060fd..3f73397 100644 --- a/brainles_preprocessing/preprocessor.py +++ b/brainles_preprocessing/preprocessor.py @@ -1,7 +1,5 @@ -from functools import wraps import logging import os -from pathlib import Path import shutil import signal import subprocess @@ -9,12 +7,17 @@ import tempfile import traceback from datetime import datetime +from functools import wraps +from pathlib import Path from typing import List, Optional from auxiliary.turbopath import turbopath +from brainles_preprocessing.constants import PreprocessorSteps +from brainles_preprocessing.defacing import Defacer, QuickshearDefacer -from .brain_extraction.brain_extractor import BrainExtractor +from .brain_extraction.brain_extractor import BrainExtractor, HDBetExtractor from .modality import Modality +from .registration import ANTsRegistrator from .registration.registrator import Registrator logger = logging.getLogger(__name__) @@ -28,9 +31,10 @@ class Preprocessor: center_modality (Modality): The central modality for coregistration. moving_modalities (List[Modality]): List of modalities to be coregistered to the central modality. registrator (Registrator): The registrator object for coregistration and registration to the atlas. - brain_extractor (BrainExtractor): The brain extractor object for brain extraction. - atlas_image_path (str, optional): Path to the atlas image for registration (default is the T1 atlas). - temp_folder (str, optional): Path to a temporary folder for storing intermediate results. + brain_extractor (Optional[BrainExtractor]): The brain extractor object for brain extraction. + defacer (Optional[Defacer]): The defacer object for defacing images. + atlas_image_path (Optional[str]): Path to the atlas image for registration (default is the T1 atlas). + temp_folder (Optional[str]): Path to a folder for storing intermediate results. use_gpu (Optional[bool]): Use GPU for processing if True, CPU if False, or automatically detect if None. limit_cuda_visible_devices (Optional[str]): Limit CUDA visible devices to a specific GPU ID. @@ -40,8 +44,9 @@ def __init__( self, center_modality: Modality, moving_modalities: List[Modality], - registrator: Registrator, - brain_extractor: BrainExtractor, + registrator: Registrator = None, + brain_extractor: Optional[BrainExtractor] = None, + defacer: Optional[Defacer] = None, atlas_image_path: str = turbopath(__file__).parent + "/registration/atlas/t1_brats_space.nii", temp_folder: Optional[str] = None, @@ -54,7 +59,14 @@ def __init__( self.moving_modalities = moving_modalities self.atlas_image_path = turbopath(atlas_image_path) self.registrator = registrator + if self.registrator is None: + logger.warning( + "No registrator provided, using default ANTsRegistrator for registration." + ) + self.registrator = ANTsRegistrator() + self.brain_extractor = brain_extractor + self.defacer = defacer self._configure_gpu( use_gpu=use_gpu, limit_cuda_visible_devices=limit_cuda_visible_devices @@ -184,6 +196,10 @@ def signal_handler(sig, frame): def all_modalities(self): return [self.center_modality] + self.moving_modalities + @property + def requires_defacing(self): + return any(modality.requires_deface for modality in self.all_modalities) + @ensure_remove_log_file_handler def run( self, @@ -191,6 +207,7 @@ def run( save_dir_atlas_registration: Optional[str] = None, save_dir_atlas_correction: Optional[str] = None, save_dir_brain_extraction: Optional[str] = None, + save_dir_defacing: Optional[str] = None, log_file: Optional[str] = None, ): """ @@ -198,21 +215,24 @@ def run( atlas correction, and optional brain extraction. Args: - save_dir_coregistration (str, optional): Directory path to save coregistration results. - save_dir_atlas_registration (str, optional): Directory path to save atlas registration results. - save_dir_atlas_correction (str, optional): Directory path to save atlas correction results. - save_dir_brain_extraction (str, optional): Directory path to save brain extraction results. + save_dir_coregistration (str, optional): Directory path to save intermediate coregistration results. + save_dir_atlas_registration (str, optional): Directory path to save intermediate atlas registration results. + save_dir_atlas_correction (str, optional): Directory path to save intermediate atlas correction results. + save_dir_brain_extraction (str, optional): Directory path to save intermediate brain extraction results. + save_dir_defacing (str, optional): Directory path to save intermediate defacing results. log_file (str, optional): Path to save the log file. Defaults to a timestamped file in the current directory. This method orchestrates the entire preprocessing workflow by sequentially performing: - 1. Coregistration: Aligning moving modalities to the central modality. + 1. Co-registration: Aligning moving modalities to the central modality. 2. Atlas Registration: Aligning the central modality to a predefined atlas. 3. Atlas Correction: Applying additional correction in atlas space if specified. - 4. Brain Extraction: Optionally extracting brain regions using specified masks. + 4. Brain Extraction: Optionally extracting brain regions using specified masks. Only executed if any modality requires a brain extraction output (or a defacing output that requires prior brain extraction). + 5. Defacing: Optionally deface images to remove facial features. Only executed if any modality requires a defacing output. Results are saved in the specified directories, allowing for modular and configurable output storage. """ + self._set_log_file(log_file=log_file) logger.info(f"{' Starting preprocessing ':=^80}") logger.info(f"Logs are saved to {self.log_file_handler.baseFilename}") @@ -220,11 +240,66 @@ def run( f"Received center modality: {self.center_modality.modality_name} and moving modalities: {', '.join([modality.modality_name for modality in self.moving_modalities])}" ) + # Co-register moving modalities to center modality logger.info(f"{' Starting Coregistration ':-^80}") + self.run_coregistration( + save_dir_coregistration=save_dir_coregistration, + ) + logger.info( + f"Coregistration complete. Output saved to {save_dir_coregistration}" + ) + + # Register center modality to atlas + logger.info(f"{' Starting atlas registration ':-^80}") + self.run_atlas_registration( + save_dir_atlas_registration=save_dir_atlas_registration, + ) + logger.info( + f"Transformations complete. Output saved to {save_dir_atlas_registration}" + ) + + # Optional: additional correction in atlas space + logger.info(f"{' Checking optional atlas correction ':-^80}") + self.run_atlas_correction( + save_dir_atlas_correction=save_dir_atlas_correction, + ) + + # now we save images that are not skullstripped (current image = atlas registered or atlas registered + corrected) + logger.info("Saving non skull-stripped images...") + for modality in self.all_modalities: + if modality.raw_skull_output_path is not None: + modality.save_current_image( + modality.raw_skull_output_path, + normalization=False, + ) + if modality.normalized_skull_output_path is not None: + modality.save_current_image( + modality.normalized_skull_output_path, + normalization=True, + ) + + # Optional: Brain extraction + logger.info(f"{' Checking optional brain extraction ':-^80}") + self.run_brain_extraction( + save_dir_brain_extraction=save_dir_brain_extraction, + ) + # ## Defacing + logger.info(f"{' Checking optional defacing ':-^80}") + self.run_defacing( + save_dir_defacing=save_dir_defacing, + ) + ## end + logger.info(f"{' Preprocessing complete ':=^80}") + + def run_coregistration(self, save_dir_coregistration: Optional[str] = None) -> None: + """Coregister moving modalities to center modality. + + Args: + save_dir_coregistration (str, optional): Directory path to save intermediate coregistration results. + """ + coregistration_dir = Path(os.path.join(self.temp_folder, "coregistration")) + coregistration_dir.mkdir(exist_ok=True, parents=True) - # Coregister moving modalities to center modality - coregistration_dir = os.path.join(self.temp_folder, "coregistration") - os.makedirs(coregistration_dir, exist_ok=True) logger.info( f"Coregistering {len(self.moving_modalities)} moving modalities to center modality..." ) @@ -238,6 +313,7 @@ def run( fixed_image_path=self.center_modality.current, registration_dir=coregistration_dir, moving_image_name=file_name, + step=PreprocessorSteps.COREGISTERED, ) shutil.copyfile( @@ -252,12 +328,15 @@ def run( src=coregistration_dir, save_dir=save_dir_coregistration, ) - logger.info( - f"Coregistration complete. Output saved to {save_dir_coregistration}" - ) - # Register center modality to atlas - logger.info(f"{' Starting atlas registration ':-^80}") + def run_atlas_registration( + self, save_dir_atlas_registration: Optional[str] = None + ) -> None: + """Register center modality to atlas. + + Args: + save_dir_atlas_registration (Optional[str], optional): Directory path to save intermediate atlas registration results. Defaults to None. + """ logger.info(f"Registering center modality to atlas...") center_file_name = f"atlas__{self.center_modality.modality_name}" transformation_matrix = self.center_modality.register( @@ -265,6 +344,7 @@ def run( fixed_image_path=self.atlas_image_path, registration_dir=self.atlas_dir, moving_image_name=center_file_name, + step=PreprocessorSteps.ATLAS_REGISTERED, ) logger.info(f"Atlas registration complete. Output saved to {self.atlas_dir}") @@ -280,22 +360,27 @@ def run( moving_modality.transform( registrator=self.registrator, fixed_image_path=self.atlas_image_path, - registration_dir_path=self.atlas_dir, + registration_dir_path=Path(self.atlas_dir), moving_image_name=moving_file_name, transformation_matrix_path=transformation_matrix, + step=PreprocessorSteps.ATLAS_REGISTERED, ) self._save_output( src=self.atlas_dir, save_dir=save_dir_atlas_registration, ) - logger.info( - f"Transformations complete. Output saved to {save_dir_atlas_registration}" - ) - # Optional: additional correction in atlas space - logger.info(f"{' Checking optional atlas correction ':-^80}") - atlas_correction_dir = os.path.join(self.temp_folder, "atlas-correction") - os.makedirs(atlas_correction_dir, exist_ok=True) + def run_atlas_correction( + self, + save_dir_atlas_correction: Optional[str] = None, + ) -> None: + """Apply optional atlas correction to moving modalities. + + Args: + save_dir_atlas_correction (Optional[str], optional): Directory path to save intermediate atlas correction results. Defaults to None. + """ + atlas_correction_dir = Path(os.path.join(self.temp_folder, "atlas-correction")) + atlas_correction_dir.mkdir(exist_ok=True, parents=True) for moving_modality in self.moving_modalities: if moving_modality.atlas_correction: @@ -308,20 +393,25 @@ def run( fixed_image_path=self.center_modality.current, registration_dir=atlas_correction_dir, moving_image_name=moving_file_name, + step=PreprocessorSteps.ATLAS_CORRECTED, ) else: - logger.info("Skipping optional atlas correction.") + logger.info( + f"Skipping optional atlas correction for Modality {moving_modality.modality_name}." + ) if self.center_modality.atlas_correction: + center_atlas_corrected_path = os.path.join( + atlas_correction_dir, + f"atlas_corrected__{self.center_modality.modality_name}.nii.gz", + ) shutil.copyfile( src=self.center_modality.current, - dst=os.path.join( - atlas_correction_dir, - f"atlas_corrected__{self.center_modality.modality_name}.nii.gz", - ), + dst=center_atlas_corrected_path, ) - logger.info( - f"Atlas correction complete. Output saved to {save_dir_atlas_correction}" + # save step result + self.center_modality.steps[PreprocessorSteps.ATLAS_CORRECTED] = ( + center_atlas_corrected_path ) self._save_output( @@ -329,68 +419,133 @@ def run( save_dir=save_dir_atlas_correction, ) - # now we save images that are not skullstripped - logger.info("Saving non skull-stripped images...") + def run_brain_extraction( + self, save_dir_brain_extraction: Optional[str] = None + ) -> None: + """Extract brain regions using specified BrainExtractor. + + Args: + save_dir_brain_extraction (Optional[str], optional): Directory path to save intermediate brain extraction results. Defaults to None. + """ + # check if any bet output paths are requested + brain_extraction = any(modality.bet for modality in self.all_modalities) + + # check if any downstream task (e.g. QuickShear) requires brain extraction. + # Quickshear is the default defacer so we also require bet if no defacer is specified + required_downstream = self.requires_defacing and ( + isinstance(self.defacer, QuickshearDefacer) or self.defacer is None + ) + + # skip if no brain extraction is required + if not brain_extraction and not required_downstream: + logger.info("Skipping brain extraction.") + return + + logger.info( + f"Starting brain extraction{' (for downstream defacing task)' if (required_downstream and not brain_extraction) else ''}..." + ) + + # setup output dirs + bet_dir = self.temp_folder / "brain-extraction" + os.makedirs(bet_dir, exist_ok=True) + + logger.info("Extracting brain region for center modality...") + + # Assert that a brain extractor is specified (since the arg is optional) + if self.brain_extractor is None: + logger.warning( + "Brain extraction is required to compute specified outputs but no brain extractor was specified during class initialization." + + " Using default `brainles_preprocessing.brain_extraction.HDBetExtractor`" + ) + self.brain_extractor = HDBetExtractor() + + atlas_mask = self.center_modality.extract_brain_region( + brain_extractor=self.brain_extractor, bet_dir_path=bet_dir + ) + for moving_modality in self.moving_modalities: + logger.info(f"Applying brain mask to {moving_modality.modality_name}...") + moving_modality.apply_bet_mask( + brain_extractor=self.brain_extractor, + mask_path=atlas_mask, + bet_dir=bet_dir, + ) + + self._save_output( + src=bet_dir, + save_dir=save_dir_brain_extraction, + ) + + # now we save images that are skullstripped + logger.info("Saving brain extracted (bet), i.e. skull-stripped images...") for modality in self.all_modalities: - if modality.raw_skull_output_path is not None: + if modality.raw_bet_output_path is not None: modality.save_current_image( - modality.raw_skull_output_path, + modality.raw_bet_output_path, normalization=False, ) - if modality.normalized_skull_output_path is not None: + if modality.normalized_bet_output_path is not None: modality.save_current_image( - modality.normalized_skull_output_path, + modality.normalized_bet_output_path, normalization=True, ) - # Optional: Brain extraction - logger.info(f"{' Checking optional brain extraction ':-^80}") - brain_extraction = any(modality.bet for modality in self.all_modalities) - # print("brain extraction: ", brain_extraction) - - if brain_extraction: - logger.info("Starting brain extraction...") - bet_dir = os.path.join(self.temp_folder, "brain-extraction") - os.makedirs(bet_dir, exist_ok=True) - brain_masked_dir = os.path.join(bet_dir, "brain_masked") - os.makedirs(brain_masked_dir, exist_ok=True) - logger.info("Extracting brain region for center modality...") - atlas_mask = self.center_modality.extract_brain_region( - brain_extractor=self.brain_extractor, bet_dir_path=bet_dir - ) - for moving_modality in self.moving_modalities: - logger.info( - f"Applying brain mask to {moving_modality.modality_name}..." - ) - moving_modality.apply_mask( - brain_extractor=self.brain_extractor, - brain_masked_dir_path=brain_masked_dir, - atlas_mask_path=atlas_mask, - ) - self._save_output( - src=bet_dir, - save_dir=save_dir_brain_extraction, + def run_defacing(self, save_dir_defacing: Optional[str] = None) -> None: + """Deface images to remove facial features using specified Defacer. + + Args: + save_dir_defacing (Optional[str], optional): Directory path to save intermediate defacing results. Defaults to None. + """ + + # skip if no defacing is required + if not self.requires_defacing: + logger.info("Skipping optional defacing.") + return + + logger.info("Starting defacing...") + + # setup output dir + deface_dir = self.temp_folder / "deface" + os.makedirs(deface_dir, exist_ok=True) + + logger.info("Defacing center modality...") + + # Assert that a defacer is specified (since the arg is optional) + if self.defacer is None: + logger.warning( + "Requested defacing but no defacer was specified during class initialization." + + " Using default `brainles_preprocessing.defacing.QuickshearDefacer`" ) - logger.info( - f"Brain extraction complete. Output saved to {save_dir_brain_extraction}" + self.defacer = QuickshearDefacer() + + atlas_mask = self.center_modality.deface( + defacer=self.defacer, defaced_dir_path=deface_dir + ) + # looping over _all_ modalities since .deface is no applying the computed mask + for moving_modality in self.all_modalities: + logger.info(f"Applying deface mask to {moving_modality.modality_name}...") + moving_modality.apply_deface_mask( + defacer=self.defacer, + mask_path=atlas_mask, + deface_dir=deface_dir, ) - else: - logger.info("Skipping optional brain extraction.") - # now we save images that are skullstripped - logger.info("Saving skull-stripped images...") + self._save_output( + src=deface_dir, + save_dir=save_dir_defacing, + ) + # now we save images that are skull-stripped + logger.info("Saving defaced images...") for modality in self.all_modalities: - if modality.raw_bet_output_path is not None: + if modality.raw_defaced_output_path is not None: modality.save_current_image( - modality.raw_bet_output_path, + modality.raw_defaced_output_path, normalization=False, ) - if modality.normalized_bet_output_path is not None: + if modality.normalized_defaced_output_path is not None: modality.save_current_image( - modality.normalized_bet_output_path, + modality.normalized_defaced_output_path, normalization=True, ) - logger.info(f"{' Preprocessing complete ':=^80}") def _save_output( self, diff --git a/tests/test_hdbet_brain_extractor.py b/tests/test_hdbet_brain_extractor.py index a79d2a5..d0e6b76 100644 --- a/tests/test_hdbet_brain_extractor.py +++ b/tests/test_hdbet_brain_extractor.py @@ -55,8 +55,8 @@ def test_extract_creates_output_files(self): def test_apply_mask_creates_output_file(self): self.brain_extractor.apply_mask( input_image_path=self.input_image_path, - mask_image_path=self.input_brain_mask_path, - masked_image_path=self.masked_again_image_path, + mask_path=self.input_brain_mask_path, + bet_image_path=self.masked_again_image_path, ) self.assertTrue( os.path.exists(self.masked_again_image_path),