Skip to content

Commit

Permalink
👔 Clarify longitudinal xfms vs longitudinal warps
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Nov 23, 2024
1 parent 322d660 commit 342dae2
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 62 deletions.
18 changes: 15 additions & 3 deletions CPAC/longitudinal/robust_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
traits,
)
from nipype.interfaces.freesurfer import longitudinal
from nipype.interfaces.freesurfer.preprocess import MRIConvert
from nipype.interfaces.freesurfer.utils import LTAConvert

from CPAC.pipeline import nipype_pipeline_engine as pe
Expand All @@ -51,7 +52,7 @@ class RobustTemplateInputSpec(longitudinal.RobustTemplateInputSpec): # noqa: D1

class RobustTemplateOutputSpec(longitudinal.RobustTemplateOutputSpec): # noqa: D101
mapmov = OutputMultiPath(
File(exists=True),
File(),
desc="each input mapped and resampled to longitudinal template",
)

Expand Down Expand Up @@ -127,7 +128,7 @@ def mri_robust_template(
average_metric=cfg["longitudinal_template_generation", "average_method"],
auto_detect_sensitivity=True,
mapmov=True,
out_file=f"{name}.nii.gz",
out_file=f"{name}.mgz",
transform_outputs=True,
),
name="mri_robust_template",
Expand All @@ -138,12 +139,23 @@ def mri_robust_template(
if isinstance(max_iter, int):
node.set_input("maxit", max_iter)

nifti_template = pe.Node(MRIConvert(out_type="niigz"), name="NIfTI-template")
wf.connect(node, "out_file", nifti_template, "in_file")

nifti_outputs = pe.MapNode(
MRIConvert(), name="NIfTI-mapmov", iterfield=["in_file", "out_file"]
)
wf.connect(node, "mapmov", nifti_outputs, "in_file")
nifti_outputs.set_input(
"out_file", [f"space-longitudinal{i + 1}.nii.gz" for i in range(num_sessions)]
)

convert = pe.MapNode(
LTAConvert(), name="convert-to-FSL", iterfield=["in_lta", "out_fsl"]
)
wf.connect(node, "transform_outputs", convert, "in_lta")
convert.set_input(
"out_fsl", [f"space-longitudinal{i}.mat" for i in range(num_sessions)]
"out_fsl", [f"space-longitudinal{i + 1}.mat" for i in range(num_sessions)]
)

return wf
117 changes: 76 additions & 41 deletions CPAC/longitudinal/wf/anat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from typing import cast, Optional

from networkx.classes.digraph import DiGraph
from nipype import config as nipype_config
from nipype.interfaces import fsl
from nipype.interfaces.utility import Merge
Expand All @@ -28,7 +29,7 @@
from CPAC.longitudinal.wf.utils import (
check_creds_path,
cross_graph_connections,
cross_pool_resources,
get_output_from_graph,
select_session_node,
)
from CPAC.pipeline import nipype_pipeline_engine as pe
Expand Down Expand Up @@ -191,17 +192,19 @@ def warp_longitudinal_T1w_to_template(
),
"T1w-brain-template",
],
outputs=[
"label-CSF_mask",
"label-GM_mask",
"label-WM_mask",
"label-CSF_desc-preproc_mask",
"label-GM_desc-preproc_mask",
"label-WM_desc-preproc_mask",
"label-CSF_probseg",
"label-GM_probseg",
"label-WM_probseg",
],
outputs={
"from-longitudinal_to-T1w_mode-image_desc-linear_xfm": {},
"from-longitudinal_to-T1w_mode-image_desc-linear_warp": {},
"label-CSF_mask": {},
"label-GM_mask": {},
"label-WM_mask": {},
"label-CSF_desc-preproc_mask": {},
"label-GM_desc-preproc_mask": {},
"label-WM_desc-preproc_mask": {},
"label-CSF_probseg": {},
"label-GM_probseg": {},
"label-WM_probseg": {},
},
)
def warp_longitudinal_seg_to_T1w(
wf: pe.Workflow,
Expand All @@ -211,6 +214,7 @@ def warp_longitudinal_seg_to_T1w(
opt: Optional[str] = None,
) -> NODEBLOCK_RETURN:
"""Transform anatomical images from longitudinal space template space."""
outputs = {}
if strat_pool.check_rpool("from-longitudinal_to-T1w_mode-image_desc-linear_xfm"):
xfm_prov = strat_pool.get_cpac_provenance(
"from-longitudinal_to-T1w_mode-image_desc-linear_xfm"
Expand All @@ -233,13 +237,21 @@ def warp_longitudinal_seg_to_T1w(
"in_file",
)
xfm = (invt, "out_file")
outputs["from-longitudinal_to-T1w_mode-image_desc-linear_xfm"] = xfm
if reg_tool != "fsl":
msg = f"`warp_longitudinal_seg_to_T1w` not yet implemented for {reg_tool}."
raise NotImplementedError(msg)
warp = pe.Node(
fsl.ConvertWarp(relwarp=True, out_relwarp=True), name=f"convert_warp_{pipe_num}"
)
wf.connect(*xfm, warp, "postmat")
wf.connect(
*strat_pool.get_data("space-longitudinal_desc-brain_T1w"), warp, "reference"
)
outputs["from-longitudinal_to-T1w_mode-image_desc-linear_warp"] = warp, "out_file"

num_cpus = cfg.pipeline_setup["system_config"]["max_cores_per_participant"]

num_ants_cores = cfg.pipeline_setup["system_config"]["num_ants_threads"]

outputs = {}

labels = [
"CSF_mask",
"CSF_desc-preproc_mask",
Expand All @@ -251,7 +263,6 @@ def warp_longitudinal_seg_to_T1w(
"WM_desc-preproc_mask",
"WM_probseg",
]

for label in labels:
apply_xfm = apply_transform(
f"warp_longitudinal_seg_to_T1w_{label}_{pipe_num}",
Expand All @@ -276,11 +287,10 @@ def warp_longitudinal_seg_to_T1w(
node, out = strat_pool.get_data("T1w-brain-template")
wf.connect(node, out, apply_xfm, "inputspec.reference")

wf.connect(*xfm, apply_xfm, "inputspec.transform")

wf.connect(warp, "out_file", apply_xfm, "inputspec.transform")
outputs[f"label-{label}"] = (apply_xfm, "outputspec.output_image")

return (wf, outputs)
return wf, outputs


def anat_longitudinal_wf(
Expand Down Expand Up @@ -345,11 +355,13 @@ def anat_longitudinal_wf(
for key in strats_dct.keys():
strats_dct[key].append(cast(tuple[pe.Node, str], rpool.get_data(key)))
if not dry_run:
workflow.run()
workflow_graph: DiGraph = workflow.run()
for key in strats_dct.keys(): # get the outputs from run-nodes
for index, data in enumerate(list(strats_dct[key])):
if isinstance(data, tuple):
strats_dct[key][index] = workflow.get_output(*data)
strats_dct[key][index] = get_output_from_graph(
workflow_graph, *data
)

wf = initialize_nipype_wf(
config,
Expand Down Expand Up @@ -409,7 +421,7 @@ def anat_longitudinal_wf(
wf.connect(merge_skulls, "out", wholehead_template_node, "input_skull_list")

case "mri_robust_template":
brain_output = head_output = "mri_robust_template.out_file"
brain_output = head_output = "NIfTI-template.out_file"
brain_template_node = mri_robust_template(
f"mri_robust_template_brain_{subject_id}", config, len(sub_list)
)
Expand All @@ -420,7 +432,7 @@ def anat_longitudinal_wf(
merge_brains, "out", brain_template_node, "mri_robust_template.in_files"
)
wf.connect(
merge_brains,
merge_skulls,
"out",
wholehead_template_node,
"mri_robust_template.in_files",
Expand Down Expand Up @@ -471,15 +483,14 @@ def anat_longitudinal_wf(
],
)
wf = connect_pipeline(wf, config, rpool, pipeline_blocks)
if not dry_run:
wf.run()

wf_graph: DiGraph | pe.Workflow = (
cast(DiGraph, wf.run()) if not dry_run else cast(pe.Workflow, wf)
)

# now, just write out a copy of the above to each session
config.pipeline_setup["pipeline_name"] = orig_pipe_name
longitudinal_rpool = rpool
cpr = cross_pool_resources(
f"fsl_longitudinal_{subject_id}"
) # "fsl" for check_prov_for_regtool
for i, session in enumerate(sub_list):
unique_id = session["unique_id"]
input_creds_path = check_creds_path(session.get("creds_path"), subject_id)
Expand All @@ -504,49 +515,61 @@ def anat_longitudinal_wf(

match config["longitudinal_template_generation", "using"]:
case "C-PAC legacy":
assert isinstance(brain_template_node, pe.Node)
cross_graph_connections(
wf_graph,
ses_wf,
merge_brains,
brain_template_node,
"out",
"input_brain_list",
)
cross_graph_connections(
wf_graph,
ses_wf,
merge_skulls,
brain_template_node,
"out",
"input_skull_list",
)
for input_name, output_name in [
("output_brains", "output_brain_list"),
("warps", "warp_list"),
]:
cross_graph_connections(
wf,
wf_graph,
ses_wf,
brain_template_node,
select_sess,
output_name,
input_name,
dry_run,
)

case "mri_robust_template":
assert isinstance(brain_template_node, pe.Workflow)
assert isinstance(wholehead_template_node, pe.Workflow)
index = i + 1
head_select_sess = select_session_node(unique_id, "-wholehead")
head_select_sess = select_session_node(unique_id, "wholehead")
select_sess.set_input("session", f"space-longitudinal{index}")
head_select_sess.set_input("session", f"space-longitudinal{index}")
for input_name, output_name in [
("output_brains", "mri_robust_template.mapmov"),
("output_brains", "NIfTI-mapmov_.out_file"),
("warps", "convert-to-FSL_.out_fsl"),
]:
cross_graph_connections(
wf,
wf_graph,
ses_wf,
brain_template_node,
select_sess,
output_name,
input_name,
dry_run,
)
cross_graph_connections(
wf,
wf_graph,
ses_wf,
wholehead_template_node,
head_select_sess,
output_name,
input_name,
dry_run,
)

rpool.set_data(
Expand Down Expand Up @@ -589,8 +612,20 @@ def anat_longitudinal_wf(
cross_pool_keys = ["from-longitudinal_to-template_mode-image_xfm"]
for key in cross_pool_keys:
node, out = longitudinal_rpool.get_data(key)
cross_graph_connections(wf, ses_wf, node, cpr, out, key, dry_run)
rpool.set_data(key, cpr, key, {}, "", cpr.name)
try:
json_info: dict = longitudinal_rpool.get_json(
key, next(iter(longitudinal_rpool.rpool[key].keys()))
)
except (AttributeError, KeyError, StopIteration):
json_info = {}
rpool.set_data(
key,
node,
out,
json_info,
"",
f"fsl_longitudinal_{subject_id}", # "fsl" for check_prov_for_regtool
)
if not dry_run:
ses_wf.run()

Expand All @@ -605,5 +640,5 @@ def anat_longitudinal_wf(

# this is going to run multiple times!
# once for every strategy!
if not dry_run:
if not dry_run: # check select_sess
ses_wf.run()
Loading

0 comments on commit 342dae2

Please sign in to comment.