Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Introduce mri_robust_template as option for longitudinal template generation #2165

Draft
wants to merge 12 commits into
base: prep_for/mri_robust_template
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Required positional parameter "wf" in input and output of `ingress_pipeconfig_paths` function, where a node to reorient templates is added to the `wf`.
- Required positional parameter "orientation" to `resolve_resolution`.
- Optional positional argument "cfg" to `create_lesion_preproc`.
- `mri_robust_template` for longitudinal template generation.
- `max_iter` parameter for longitudinal template generation.

### Changed

- Moved `pygraphviz` from requirements to `graphviz` optional dependencies group.
- Automatically tag untagged `subject_id` and `unique_id` as `!!str` when loading data config files.
- Made orientation configurable (was hard-coded as "RPI").
- Disabled variant image builds.
- Made `mri_robust_template` default implementation for longitudinal template generation.

### Fixed

Expand Down
4 changes: 2 additions & 2 deletions CPAC/anat_preproc/anat_preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
wb_command,
)
from CPAC.pipeline import nipype_pipeline_engine as pe
from CPAC.pipeline.nodeblock import nodeblock
from CPAC.pipeline.nodeblock import nodeblock, NODEBLOCK_RETURN
from CPAC.utils.interfaces import Function
from CPAC.utils.interfaces.fsl import Merge as fslMerge

Expand Down Expand Up @@ -1447,7 +1447,7 @@ def mask_T2(wf_name="mask_T2"):
inputs=["T1w"],
outputs=["desc-preproc_T1w", "desc-reorient_T1w", "desc-head_T1w"],
)
def anatomical_init(wf, cfg, strat_pool, pipe_num, opt=None):
def anatomical_init(wf, cfg, strat_pool, pipe_num, opt=None) -> NODEBLOCK_RETURN:
anat_deoblique = pe.Node(interface=afni.Refit(), name=f"anat_deoblique_{pipe_num}")
anat_deoblique.inputs.deoblique = True

Expand Down
1 change: 1 addition & 0 deletions CPAC/longitudinal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from CPAC.utils.docs import DOCS_URL_PREFIX

assert isinstance(__doc__, str)
__doc__ += f"""

See {DOCS_URL_PREFIX}/user/longitudinal
Expand Down
122 changes: 67 additions & 55 deletions CPAC/longitudinal/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from collections import Counter
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing.pool import Pool
import os
from typing import Literal, Optional

import numpy as np
import nibabel as nib
Expand Down Expand Up @@ -131,27 +133,23 @@ def norm_transformation(input_mat):


def template_convergence(
mat_file, mat_type="matrix", convergence_threshold=np.finfo(np.float64).eps
):
mat_file: str,
mat_type: Literal["matrix", "ITK"] = "matrix",
convergence_threshold: float | np.float64 = np.finfo(np.float64).eps,
) -> bool:
"""Check that the deistance between matrices is smaller than the threshold.

Calculate the distance between transformation matrix with a matrix of no transformation.

Parameters
----------
mat_file : str
mat_file
path to an fsl flirt matrix
mat_type : str
'matrix'(default), 'ITK'
mat_type
The type of matrix used to represent the transformations
convergence_threshold : float
(numpy.finfo(np.float64).eps (default)) threshold for the convergence
convergence_threshold
The threshold is how different from no transformation is the
transformation matrix.

Returns
-------
bool
"""
if mat_type == "matrix":
translation, oth_transform = read_mat(mat_file)
Expand Down Expand Up @@ -346,51 +344,68 @@ def flirt_node(in_img, output_img, output_mat):
return node_list


def check_convergence(mat_list, mat_type, convergence_threshold) -> bool:
"""Test if every transformation matrix has reached the convergence threshold."""
convergence_list = [
template_convergence(mat, mat_type, convergence_threshold) for mat in mat_list
]
return all(convergence_list)


@Function.sig_imports(
[
"from multiprocessing.pool import Pool",
"from typing import Literal, Optional",
"from nipype.pipeline import engine as pe",
"from CPAC.longitudinal.preproc import check_convergence",
]
)
def template_creation_flirt(
input_brain_list,
input_skull_list,
init_reg=None,
avg_method="median",
dof=12,
interp="trilinear",
cost="corratio",
mat_type="matrix",
convergence_threshold=-1,
thread_pool=2,
unique_id_list=None,
):
input_brain_list: list[str],
input_skull_list: list[str],
init_reg: Optional[list[pe.Node]] = None,
avg_method: Literal["median", "mean", "std"] = "median",
dof: Literal[12, 9, 7, 6] = 12,
interp: Literal["trilinear", "nearestneighbour", "sinc", "spline"] = "trilinear",
cost: Literal[
"corratio", "mutualinfo", "normmi", "normcorr", "leastsq", "labeldiff", "bbr"
] = "corratio",
mat_type: Literal["matrix", "ITK"] = "matrix",
convergence_threshold: float | np.float64 = -1,
max_iter: int = 5,
thread_pool: int | Pool = 2,
unique_id_list: Optional[list[str]] = None,
) -> tuple[str, str, list[str], list[str], list[str]]:
"""Create a temporary template from a list of images.

Parameters
----------
input_brain_list : list of str
input_brain_list
list of brain images paths
input_skull_list : list of str
input_skull_list
list of skull images paths
init_reg : list of Node
init_reg
(default None so no initial registration performed)
the output of the function register_img_list with another reference
Reuter et al. 2012 (NeuroImage) section "Improved template estimation"
doi:10.1016/j.neuroimage.2012.02.084 uses a ramdomly
selected image from the input dataset
avg_method : str
function names from numpy library such as 'median', 'mean', 'std' ...
dof : integer (int of long)
number of transform degrees of freedom (FLIRT) (12 by default)
interp : str
('trilinear' (default) or 'nearestneighbour' or 'sinc' or 'spline')
avg_method
function names from numpy library
dof
number of transform degrees of freedom (FLIRT)
interp
final interpolation method used in reslicing
cost : str
('mutualinfo' or 'corratio' (default) or 'normcorr' or 'normmi' or
'leastsq' or 'labeldiff' or 'bbr')
cost
cost function
mat_type : str
'matrix'(default), 'ITK'
mat_type
The type of matrix used to represent the transformations
convergence_threshold : float
convergence_threshold
(numpy.finfo(np.float64).eps (default)) threshold for the convergence
The threshold is how different from no transformation is the
transformation matrix.
max_iter
Maximum number of iterations if transformation does not converge
thread_pool : int or multiprocessing.dummy.Pool
(default 2) number of threads. You can also provide a Pool so the
node will be added to it to be run.
Expand Down Expand Up @@ -463,6 +478,8 @@ def template_creation_flirt(
warp_list,
)

output_brain_list = list(input_brain_list)
output_skull_list = list(input_skull_list)
# Chris: I added this part because it is mentioned in the paper but I actually never used it
# You could run a first register_img_list() with a selected image as starting point and
# give the output to this function
Expand All @@ -471,18 +488,11 @@ def template_creation_flirt(
output_brain_list = [node.inputs.out_file for node in init_reg]
mat_list = [node.inputs.out_matrix_file for node in init_reg]
warp_list = mat_list
# test if every transformation matrix has reached the convergence
convergence_list = [
template_convergence(mat, mat_type, convergence_threshold)
for mat in mat_list
]
converged = all(convergence_list)
converged = check_convergence(mat_list, mat_type, convergence_threshold)
else:
msg = "init_reg must be a list of FLIRT nipype nodes files"
raise ValueError(msg)
else:
output_brain_list = input_brain_list
output_skull_list = input_skull_list
converged = False

temporary_brain_template = os.path.join(
Expand All @@ -496,7 +506,14 @@ def template_creation_flirt(
and the loop stops when this temporary template is close enough (with a transformation
distance smaller than the threshold) to all the images of the precedent iteration.
"""
while not converged:
iterator = 1
iteration = 0
if max_iter == -1:
# make iteration < max_iter always True
iterator = 0
iteration = -2
while not converged and iteration < max_iter:
iteration += iterator
temporary_brain_template, temporary_skull_template = create_temporary_template(
input_brain_list=output_brain_list,
input_skull_list=output_skull_list,
Expand Down Expand Up @@ -551,13 +568,7 @@ def template_creation_flirt(
warp_list[index] = warp_list_filenames[index]

output_brain_list = [node.inputs.out_file for node in reg_list_node]

# test if every transformation matrix has reached the convergence
convergence_list = [
template_convergence(mat, mat_type, convergence_threshold)
for mat in mat_list
]
converged = all(convergence_list)
converged = check_convergence(mat_list, mat_type, convergence_threshold)

if isinstance(thread_pool, int):
pool.close()
Expand Down Expand Up @@ -609,7 +620,7 @@ def subject_specific_template(
"from collections import Counter",
"from multiprocessing.dummy import Pool as ThreadPool",
"from nipype.interfaces.fsl import ConvertXFM",
"from CPAC.longitudinal_pipeline.longitudinal_preproc import ("
"from CPAC.longitudinal.preproc import ("
" create_temporary_template,"
" register_img_list,"
" template_convergence"
Expand All @@ -628,6 +639,7 @@ def subject_specific_template(
"cost",
"mat_type",
"convergence_threshold",
"max_iter",
"thread_pool",
"unique_id_list",
],
Expand Down
Loading
Loading