Skip to content

Commit

Permalink
RF: Fix ITK warp conversion to nitransforms format (#3300)
Browse files Browse the repository at this point in the history
Alternative to (and builds on) #3296.

Closes #3296.

---------

Co-authored-by: mathiasg <mathiasg@stanford.edu>
Co-authored-by: Mathias Goncalves <goncalves.mathias@gmail.com>
  • Loading branch information
3 people authored Jun 5, 2024
1 parent 604eeef commit 6c4cb04
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 32 deletions.
66 changes: 34 additions & 32 deletions fmriprep/utils/transforms.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Utilities for loading transforms for resampling"""

import warnings
from pathlib import Path

import h5py
import nibabel as nb
import nitransforms as nt
import numpy as np
from nitransforms.io.itk import ITKCompositeH5
from transforms3d.affines import compose as compose_affine


def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase:
Expand Down Expand Up @@ -38,16 +38,6 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
return chain


FIXED_PARAMS = np.array([
193.0, 229.0, 193.0, # Size
96.0, 132.0, -78.0, # Origin
1.0, 1.0, 1.0, # Spacing
-1.0, 0.0, 0.0, # Directions
0.0, -1.0, 0.0,
0.0, 0.0, 1.0,
]) # fmt:skip


def load_ants_h5(filename: Path) -> nt.base.TransformBase:
"""Load ANTs H5 files as a nitransforms TransformChain"""
# Borrowed from https://github.com/feilong/process
Expand All @@ -56,7 +46,8 @@ def load_ants_h5(filename: Path) -> nt.base.TransformBase:
# Changes:
# * Tolerate a missing displacement field
# * Return the original affine without a round-trip
# * Always return a nitransforms TransformChain
# * Always return a nitransforms TransformBase
# * Construct warp affine from fixed parameters
#
# This should be upstreamed into nitransforms
h = h5py.File(filename)
Expand All @@ -80,24 +71,35 @@ def load_ants_h5(filename: Path) -> nt.base.TransformBase:
msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n'
raise ValueError(msg)

fixed_params = transform2['TransformFixedParameters'][:]
if not np.array_equal(fixed_params, FIXED_PARAMS):
msg = 'Unexpected fixed parameters\n'
msg += f'Expected: {FIXED_PARAMS}\n'
msg += f'Found: {fixed_params}'
if not np.array_equal(fixed_params[6:], FIXED_PARAMS[6:]):
raise ValueError(msg)
warnings.warn(msg, stacklevel=1)

shape = tuple(fixed_params[:3].astype(int))
warp = h['TransformGroup']['2']['TransformParameters'][:]
warp = warp.reshape((*shape, 3)).transpose(2, 1, 0, 3)
warp *= np.array([-1, -1, 1])

warp_affine = np.eye(4)
warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3))
warp_affine[:3, 3] = fixed_params[3:6]
lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1])
warp_affine = lps_to_ras @ warp_affine
transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)))
# Warp field fixed parameters as defined in
# https://itk.org/Doxygen/html/classitk_1_1DisplacementFieldTransform.html
shape = transform2['TransformFixedParameters'][:3]
origin = transform2['TransformFixedParameters'][3:6]
spacing = transform2['TransformFixedParameters'][6:9]
direction = transform2['TransformFixedParameters'][9:].reshape((3, 3))

# We are not yet confident that we handle non-unit spacing
# or direction cosine ordering correctly.
# If we confirm or fix, we can remove these checks.
if not np.allclose(spacing, 1):
raise ValueError(f'Unexpected spacing: {spacing}')
if not np.allclose(direction, direction.T):
raise ValueError(f'Asymmetric direction matrix: {direction}')

# ITK uses LPS affines
lps_affine = compose_affine(T=origin, R=direction, Z=spacing)
ras_affine = np.diag([-1, -1, 1, 1]) @ lps_affine

# ITK stores warps in Fortran-order, where the vector components change fastest
# Vectors are in mm LPS
itk_warp = np.reshape(
transform2['TransformParameters'],
(3, *shape.astype(int)),
order='F',
)

# Nitransforms warps are in RAS, with the vector components changing slowest
nt_warp = itk_warp.transpose(1, 2, 3, 0) * np.array([-1, -1, 1])

transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(nt_warp, ras_affine)))
return nt.TransformChain(transforms)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"smriprep @ git+https://github.com/nipreps/smriprep.git@master",
"tedana >= 23.0.2",
"templateflow >= 24.1.0",
"transforms3d",
"toml",
"codecarbon",
"APScheduler",
Expand Down

0 comments on commit 6c4cb04

Please sign in to comment.