Skip to content

Commit

Permalink
Reformatted code using black to fix linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
yasinzaii committed Oct 28, 2024
1 parent 1cde873 commit b4751fe
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 58 deletions.
17 changes: 9 additions & 8 deletions brainles_preprocessing/brain_extraction/brain_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# TODO add typing and docs
import shutil
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
Expand All @@ -10,8 +10,9 @@


class Mode(Enum):
FAST = 'fast'
ACCURATE = 'accurate'
FAST = "fast"
ACCURATE = "accurate"


class BrainExtractor:
@abstractmethod
Expand Down Expand Up @@ -51,7 +52,6 @@ def apply_mask(
mask_path (str or Path): Path to the brain mask image (NIfTI format).
bet_image_path (str or Path): Path to save the resulting masked image (NIfTI format).
"""


try:
# Read data
Expand Down Expand Up @@ -105,11 +105,11 @@ def extract(
device (str or int): Device to use for computation (e.g., 0 for GPU 0, 'cpu' for CPU).
do_tta (bool): whether to do test time data augmentation by mirroring along all axes.
"""

# Ensure mode is a Mode enum instance
if isinstance(mode, str):
try:
mode_enum = Mode(mode.lower())
mode_enum = Mode(mode.lower())
except ValueError:
raise ValueError(f"'{mode}' is not a valid Mode.")
elif isinstance(mode, Mode):
Expand All @@ -132,7 +132,9 @@ def extract(

# Construct the path to the generated mask
masked_image_path = Path(masked_image_path)
hdbet_mask_path = masked_image_path.with_name(masked_image_path.name.replace('.nii.gz', '_mask.nii.gz'))
hdbet_mask_path = masked_image_path.with_name(
masked_image_path.name.replace(".nii.gz", "_mask.nii.gz")
)

if hdbet_mask_path.resolve() != Path(brain_mask_path).resolve():
try:
Expand All @@ -142,4 +144,3 @@ def extract(
)
except Exception as e:
raise RuntimeError(f"Error copying mask file: {e}") from e

17 changes: 11 additions & 6 deletions brainles_preprocessing/defacing/defacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class Defacer(ABC):
Subclasses should implement the `deface` method to generate a defaced image
based on the provided input image and mask.
"""

@abstractmethod
def deface(
self,
Expand Down Expand Up @@ -41,23 +42,27 @@ def apply_mask(
mask_path (str or Path): Path to the brain mask image (NIfTI format).
defaced_image_path (str or Path): Path to save the resulting defaced image (NIfTI format).
"""

if not input_image_path.is_file():
raise FileNotFoundError(f"Input image file does not exist: {input_image_path}")
raise FileNotFoundError(
f"Input image file does not exist: {input_image_path}"
)
if not mask_path.is_file():
raise FileNotFoundError(f"Mask file does not exist: {mask_path}")

try:
# Read data
input_data = read_nifti(str(input_image_path))
mask_data = read_nifti(str(mask_path))
except Exception as e:
raise RuntimeError(f"An error occurred while reading input files: {e}") from e

raise RuntimeError(
f"An error occurred while reading input files: {e}"
) from e

# Check that the input and mask have the same shape
if input_data.shape != mask_data.shape:
raise ValueError("Input image and mask must have the same dimensions.")

# Apply mask (element-wise multiplication)
masked_data = input_data * mask_data

Expand Down
50 changes: 31 additions & 19 deletions brainles_preprocessing/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def __init__(
modality_name: str,
input_path: Union[str, Path],
normalizer: Optional[Normalizer] = None,
raw_bet_output_path: Optional[Union[str, Path]] = None,
raw_skull_output_path: Optional[Union[str, Path]] = None,
raw_defaced_output_path: Optional[Union[str, Path]] = None,
normalized_bet_output_path: Optional[Union[str, Path]] = None,
normalized_skull_output_path: Optional[Union[str, Path]] = None,
normalized_defaced_output_path: Optional[Union[str, Path]] = None,
raw_bet_output_path: Optional[Union[str, Path]] = None,
raw_skull_output_path: Optional[Union[str, Path]] = None,
raw_defaced_output_path: Optional[Union[str, Path]] = None,
normalized_bet_output_path: Optional[Union[str, Path]] = None,
normalized_skull_output_path: Optional[Union[str, Path]] = None,
normalized_defaced_output_path: Optional[Union[str, Path]] = None,
atlas_correction: bool = True,
) -> None:
# Basics
Expand All @@ -89,9 +89,15 @@ def __init__(
)

# handle input paths
self.raw_bet_output_path = Path(raw_bet_output_path) if raw_bet_output_path else None
self.raw_skull_output_path = Path(raw_skull_output_path) if raw_skull_output_path else None
self.raw_defaced_output_path = Path(raw_defaced_output_path) if raw_defaced_output_path else None
self.raw_bet_output_path = (
Path(raw_bet_output_path) if raw_bet_output_path else None
)
self.raw_skull_output_path = (
Path(raw_skull_output_path) if raw_skull_output_path else None
)
self.raw_defaced_output_path = (
Path(raw_defaced_output_path) if raw_defaced_output_path else None
)

if normalized_bet_output_path:
if normalizer is None:
Expand Down Expand Up @@ -163,15 +169,19 @@ def normalize(
store_unnormalized.mkdir(parents=True, exist_ok=True)
shutil.copyfile(
src=str(self.current),
dst=str(store_unnormalized / f"unnormalized__{self.modality_name}.nii.gz"),
dst=str(
store_unnormalized / f"unnormalized__{self.modality_name}.nii.gz"
),
)

if temporary_directory:
unnormalized_dir = Path(temporary_directory) / "unnormalized"
unnormalized_dir.mkdir(parents=True, exist_ok=True)
shutil.copyfile(
src=str(self.current),
dst=str(unnormalized_dir / f"unnormalized__{self.modality_name}.nii.gz"),
dst=str(
unnormalized_dir / f"unnormalized__{self.modality_name}.nii.gz"
),
)

# Normalize the image
Expand Down Expand Up @@ -208,13 +218,13 @@ def register(
"""
fixed_image_path = Path(fixed_image_path)
registration_dir = Path(registration_dir)

registered = registration_dir / f"{moving_image_name}.nii.gz"
registered_log = registration_dir / f"{moving_image_name}.log"

# Note, add file ending depending on registration backend!
registered_matrix = registration_dir / f"{moving_image_name}"

registrator.register(
fixed_image_path=fixed_image_path,
moving_image_path=self.current,
Expand Down Expand Up @@ -277,9 +287,11 @@ def apply_deface_mask(
deface_dir = Path(deface_dir)
defaced_img = deface_dir / f"atlas__{self.modality_name}_defaced.nii.gz"
input_img = self.steps[
PreprocessorSteps.ATLAS_CORRECTED
if self.atlas_correction
else PreprocessorSteps.ATLAS_REGISTERED
(
PreprocessorSteps.ATLAS_CORRECTED
if self.atlas_correction
else PreprocessorSteps.ATLAS_REGISTERED
)
]
defacer.apply_mask(
input_image_path=input_img,
Expand Down Expand Up @@ -381,7 +393,7 @@ def deface(
"""

if isinstance(defacer, QuickshearDefacer):

defaced_dir_path = Path(defaced_dir_path)
atlas_mask_path = (
defaced_dir_path / f"atlas__{self.modality_name}_deface_mask.nii.gz"
Expand Down Expand Up @@ -430,4 +442,4 @@ def save_current_image(
shutil.copyfile(
src=str(self.current),
dst=str(output_path),
)
)
41 changes: 27 additions & 14 deletions brainles_preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ def __init__(

self.center_modality = center_modality
self.moving_modalities = moving_modalities

if atlas_image_path is None:
self.atlas_image_path = Path(__file__).parent / "registration" / "atlas" / "t1_brats_space.nii"
self.atlas_image_path = (
Path(__file__).parent / "registration" / "atlas" / "t1_brats_space.nii"
)
else:
self.atlas_image_path = Path(atlas_image_path)

self.registrator = registrator
if self.registrator is None:
logger.warning(
Expand All @@ -86,7 +88,7 @@ def __init__(

self.atlas_dir = self.temp_folder / "atlas-space"
self.atlas_dir.mkdir(exist_ok=True, parents=True)

def _configure_gpu(
self, use_gpu: Optional[bool], limit_cuda_visible_devices: Optional[str] = None
) -> None:
Expand Down Expand Up @@ -249,7 +251,9 @@ def run(
self._set_log_file(log_file=log_file)
logger.info(f"{' Starting preprocessing ':=^80}")
logger.info(f"Logs are saved to {self.log_file_handler.baseFilename}")
modality_names = ', '.join([modality.modality_name for modality in self.moving_modalities])
modality_names = ", ".join(
[modality.modality_name for modality in self.moving_modalities]
)
logger.info(
f"Received center modality: {self.center_modality.modality_name} "
f"and moving modalities: {modality_names}"
Expand Down Expand Up @@ -298,17 +302,19 @@ def run(
self.run_brain_extraction(
save_dir_brain_extraction=save_dir_brain_extraction,
)

# Defacing
logger.info(f"{' Checking optional defacing ':-^80}")
self.run_defacing(
save_dir_defacing=save_dir_defacing,
)

# End
logger.info(f"{' Preprocessing complete ':=^80}")

def run_coregistration(self, save_dir_coregistration: Optional[Union[str, Path]] = None) -> None:
def run_coregistration(
self, save_dir_coregistration: Optional[Union[str, Path]] = None
) -> None:
"""
Coregister moving modalities to center modality.
Expand Down Expand Up @@ -336,8 +342,10 @@ def run_coregistration(self, save_dir_coregistration: Optional[Union[str, Path]]

shutil.copyfile(
src=str(self.center_modality.input_path),
dst=str(coregistration_dir /
f"native__{self.center_modality.modality_name}.nii.gz"),
dst=str(
coregistration_dir
/ f"native__{self.center_modality.modality_name}.nii.gz"
),
)

self._save_output(
Expand Down Expand Up @@ -417,8 +425,11 @@ def run_atlas_correction(
)

if self.center_modality.atlas_correction:
center_atlas_corrected_path = atlas_correction_dir / f"atlas_corrected__{self.center_modality.modality_name}.nii.gz"

center_atlas_corrected_path = (
atlas_correction_dir
/ f"atlas_corrected__{self.center_modality.modality_name}.nii.gz"
)

shutil.copyfile(
src=str(self.center_modality.current),
dst=str(center_atlas_corrected_path),
Expand Down Expand Up @@ -503,7 +514,9 @@ def run_brain_extraction(
normalization=True,
)

def run_defacing(self, save_dir_defacing: Optional[Union[str, Path]] = None) -> None:
def run_defacing(
self, save_dir_defacing: Optional[Union[str, Path]] = None
) -> None:
"""Deface images to remove facial features using specified Defacer.
Args:
Expand Down Expand Up @@ -563,7 +576,7 @@ def run_defacing(self, save_dir_defacing: Optional[Union[str, Path]] = None) ->

def _save_output(
self,
src: Union[str, Path],
src: Union[str, Path],
save_dir: Optional[Union[str, Path]],
):
"""
Expand Down
21 changes: 10 additions & 11 deletions brainles_preprocessing/registration/ANTs/ANTs.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def register(
**kwargs: Additional registration parameters to update the instantiated defaults.
"""
start_time = datetime.datetime.now()

# TODO - self.registration_params
# We update the registration parameters with the provided kwargs
registration_kwargs = {**self.registration_params, **kwargs}

# Convert all paths to Path objects
fixed_image_path = Path(fixed_image_path)
moving_image_path = Path(moving_image_path)
Expand All @@ -76,7 +76,7 @@ def register(
raise FileNotFoundError(f"Fixed image not found: {fixed_image_path}")
if not moving_image_path.is_file():
raise FileNotFoundError(f"Moving image not found: {moving_image_path}")

# Ensure matrix_path has .mat suffix
if matrix_path.suffix != ".mat":
matrix_path = matrix_path.with_suffix(".mat")
Expand All @@ -89,17 +89,17 @@ def register(
**registration_kwargs,
)
transformed_image = registration_result["warpedmovout"]

# Ensure output directories exist
transformed_image_path.parent.mkdir(parents=True, exist_ok=True)
matrix_path.parent.mkdir(parents=True, exist_ok=True)

ants.image_write(transformed_image, str(transformed_image_path))

shutil.copyfile(
src=registration_result["fwdtransforms"][0],
dst=str(matrix_path),
)
src=registration_result["fwdtransforms"][0],
dst=str(matrix_path),
)

end_time = datetime.datetime.now()

Expand Down Expand Up @@ -142,7 +142,7 @@ def transform(
# TODO - self.transformation_params
# we update the transformation parameters with the provided kwargs
transform_kwargs = {**self.transformation_params, **kwargs}

# Convert all paths to Path objects
fixed_image_path = Path(fixed_image_path)
moving_image_path = Path(moving_image_path)
Expand All @@ -155,10 +155,9 @@ def transform(
if not moving_image_path.is_file():
raise FileNotFoundError(f"Moving image not found: {moving_image_path}")


fixed_image = ants.image_read(str(fixed_image_path))
moving_image = ants.image_read(str(moving_image_path))

# Ensure output directory exist
transformed_image_path.parent.mkdir(parents=True, exist_ok=True)

Expand Down

0 comments on commit b4751fe

Please sign in to comment.